From db20393f4affd4158528bd868478e516ebed0944 Mon Sep 17 00:00:00 2001
From: Jong Wook Kim
Date: Sun, 18 Jul 2021 18:45:21 -0700
Subject: [PATCH] Using non-JIT by default; compat fix with 1.8+

---
 clip/clip.py     | 29 ++++++++++++++++++++++++-----
 requirements.txt |  4 ++--
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/clip/clip.py b/clip/clip.py
index 76f241b..55e1433 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -12,6 +12,17 @@ from tqdm import tqdm
 from .model import build_model
 from .simple_tokenizer import SimpleTokenizer as _Tokenizer
 
+try:
+    from torchvision.transforms import InterpolationMode
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+
+if torch.__version__.split(".") < ["1", "7", "1"]:
+    warnings.warn("PyTorch version 1.7.1 or higher is recommended")
+
+
 __all__ = ["available_models", "load", "tokenize"]
 _tokenizer = _Tokenizer()
 
@@ -57,7 +68,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
 
 def _transform(n_px):
     return Compose([
-        Resize(n_px, interpolation=Image.BICUBIC),
+        Resize(n_px, interpolation=BICUBIC),
         CenterCrop(n_px),
         lambda image: image.convert("RGB"),
         ToTensor(),
@@ -70,7 +81,7 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())
 
 
-def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
     """Load a CLIP model
 
     Parameters
@@ -82,7 +93,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
         The device to put the loaded model
 
     jit : bool
-        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+        Whether to load the optimized JIT model or more hackable non-JIT model (default).
 
     Returns
     -------
@@ -121,7 +132,11 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
     device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
 
     def patch_device(module):
-        graphs = [module.graph] if hasattr(module, "graph") else []
+        try:
+            graphs = [module.graph] if hasattr(module, "graph") else []
+        except RuntimeError:
+            graphs = []
+
         if hasattr(module, "forward1"):
             graphs.append(module.forward1.graph)
 
@@ -141,7 +156,11 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
         float_node = float_input.node()
 
         def patch_float(module):
-            graphs = [module.graph] if hasattr(module, "graph") else []
+            try:
+                graphs = [module.graph] if hasattr(module, "graph") else []
+            except RuntimeError:
+                graphs = []
+
             if hasattr(module, "forward1"):
                 graphs.append(module.forward1.graph)
 
diff --git a/requirements.txt b/requirements.txt
index f6f8b44..6b98c33 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 ftfy
 regex
 tqdm
-torch~=1.7.1
-torchvision~=0.8.2
+torch
+torchvision
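
Note (not part of the patch): a minimal usage sketch of the changed default, assuming the patched clip module above and a published checkpoint name such as "ViT-B/32" (the model name is not defined in this diff). With this commit, clip.load returns the plain, hackable non-JIT model unless jit=True is requested explicitly.

    import torch
    import clip  # the module patched above (clip/clip.py)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # jit now defaults to False, so this returns the non-JIT model
    model, preprocess = clip.load("ViT-B/32", device=device)

    # the optimized TorchScript model can still be requested explicitly,
    # which exercises the patch_device / patch_float graph rewriting above
    jit_model, _ = clip.load("ViT-B/32", device=device, jit=True)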