Use the non-JIT model by default; compatibility fix for PyTorch 1.8+

This commit is contained in:
Jong Wook Kim 2021-07-18 18:45:21 -07:00
parent cfcffb90e6
commit db20393f4a
2 changed files with 26 additions and 7 deletions

View File

@ -12,6 +12,17 @@ from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
if torch.__version__.split(".") < ["1", "7", "1"]:
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
@ -57,7 +68,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
def _transform(n_px):
return Compose([
Resize(n_px, interpolation=Image.BICUBIC),
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
lambda image: image.convert("RGB"),
ToTensor(),
@ -70,7 +81,7 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
"""Load a CLIP model
Parameters
@ -82,7 +93,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
Whether to load the optimized JIT model or more hackable non-JIT model (default).
Returns
-------
@ -121,7 +132,11 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
graphs = [module.graph] if hasattr(module, "graph") else []
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
@ -141,7 +156,11 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
float_node = float_input.node()
def patch_float(module):
graphs = [module.graph] if hasattr(module, "graph") else []
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)

View File

@ -1,5 +1,5 @@
ftfy
regex
tqdm
torch~=1.7.1
torchvision~=0.8.2
torch
torchvision