Load model from path
This commit is contained in:
parent
e5347713f4
commit
6dab2f989b
14
clip/clip.py
14
clip/clip.py
|
@ -12,7 +12,7 @@ from tqdm import tqdm
|
|||
from .model import build_model
|
||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
|
||||
__all__ = ["available_models", "load", "tokenize"]
|
||||
__all__ = ["available_models", "load", "load_from_file", "tokenize"]
|
||||
_tokenizer = _Tokenizer()
|
||||
|
||||
_MODELS = {
|
||||
|
@ -58,10 +58,22 @@ def available_models():
|
|||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
    """Load a CLIP model by its registered name.

    The function name 'load' is kept for backwards compatibility; the
    actual file loading is delegated to `load_from_file`.

    Parameters
    ----------
    name : str
        Model name; must be a key of the module-level `_MODELS` registry.
    device : Union[str, torch.device]
        Device to load the model onto (defaults to CUDA when available).
    jit : bool
        Whether to load the TorchScript (JIT) form of the model;
        forwarded to `load_from_file`.

    Returns
    -------
    Whatever `load_from_file` returns for the downloaded checkpoint.

    Raises
    ------
    RuntimeError
        If `name` is not a registered model.
    """
    if name not in _MODELS:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    model_path = _download(_MODELS[name])

    # Bug fix: the original dropped `jit` here, so load(name, jit=False)
    # silently behaved like jit=True. Forward it explicitly.
    return load_from_file(model_path, device, jit)
|
||||
|
||||
|
||||
def load_from_file(model_path: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
|
||||
""" Load model file
|
||||
Original 'load' function
|
||||
"""
|
||||
|
||||
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
||||
n_px = model.input_resolution.item()
|
||||
|
||||
|
|
Loading…
Reference in New Issue