diff --git a/clip/clip.py b/clip/clip.py index 8f37ba8..dfb3d1e 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -12,7 +12,7 @@ from tqdm import tqdm from .model import build_model from .simple_tokenizer import SimpleTokenizer as _Tokenizer -__all__ = ["available_models", "load", "tokenize"] +__all__ = ["available_models", "load", "load_from_file", "tokenize"] _tokenizer = _Tokenizer() _MODELS = { @@ -58,10 +58,22 @@ def available_models(): def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): + """ Load by model name + Kept function name 'load' for backwards compatability + """ if name not in _MODELS: raise RuntimeError(f"Model {name} not found; available models = {available_models()}") model_path = _download(_MODELS[name]) + + return load_from_file(model_path, device) + + +def load_from_file(model_path: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): + """ Load model file + Original 'load' function + """ + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() n_px = model.input_resolution.item()