Using non-JIT by default; compat fix with 1.8+
This commit is contained in:
parent
cfcffb90e6
commit
db20393f4a
29
clip/clip.py
29
clip/clip.py
|
@ -12,6 +12,17 @@ from tqdm import tqdm
|
|||
from .model import build_model
|
||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
BICUBIC = InterpolationMode.BICUBIC
|
||||
except ImportError:
|
||||
BICUBIC = Image.BICUBIC
|
||||
|
||||
|
||||
if torch.__version__.split(".") < ["1", "7", "1"]:
|
||||
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
||||
|
||||
|
||||
__all__ = ["available_models", "load", "tokenize"]
|
||||
_tokenizer = _Tokenizer()
|
||||
|
||||
|
@ -57,7 +68,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
|||
|
||||
def _transform(n_px):
|
||||
return Compose([
|
||||
Resize(n_px, interpolation=Image.BICUBIC),
|
||||
Resize(n_px, interpolation=BICUBIC),
|
||||
CenterCrop(n_px),
|
||||
lambda image: image.convert("RGB"),
|
||||
ToTensor(),
|
||||
|
@ -70,7 +81,7 @@ def available_models() -> List[str]:
|
|||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
|
@ -82,7 +93,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
The device to put the loaded model
|
||||
|
||||
jit : bool
|
||||
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
||||
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -121,7 +132,11 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
||||
|
||||
def patch_device(module):
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
try:
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
except RuntimeError:
|
||||
graphs = []
|
||||
|
||||
if hasattr(module, "forward1"):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
|
@ -141,7 +156,11 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
float_node = float_input.node()
|
||||
|
||||
def patch_float(module):
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
try:
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
except RuntimeError:
|
||||
graphs = []
|
||||
|
||||
if hasattr(module, "forward1"):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
ftfy
|
||||
regex
|
||||
tqdm
|
||||
torch~=1.7.1
|
||||
torchvision~=0.8.2
|
||||
torch
|
||||
torchvision
|
||||
|
|
Loading…
Reference in New Issue