Load model from path
This commit is contained in:
parent
e5347713f4
commit
6dab2f989b
14
clip/clip.py
14
clip/clip.py
|
@ -12,7 +12,7 @@ from tqdm import tqdm
|
||||||
from .model import build_model
|
from .model import build_model
|
||||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||||
|
|
||||||
__all__ = ["available_models", "load", "tokenize"]
|
__all__ = ["available_models", "load", "load_from_file", "tokenize"]
|
||||||
_tokenizer = _Tokenizer()
|
_tokenizer = _Tokenizer()
|
||||||
|
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
|
@ -58,10 +58,22 @@ def available_models():
|
||||||
|
|
||||||
|
|
||||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
|
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
|
||||||
|
""" Load by model name
|
||||||
|
Kept function name 'load' for backwards compatability
|
||||||
|
"""
|
||||||
if name not in _MODELS:
|
if name not in _MODELS:
|
||||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||||
|
|
||||||
model_path = _download(_MODELS[name])
|
model_path = _download(_MODELS[name])
|
||||||
|
|
||||||
|
return load_from_file(model_path, device)
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_file(model_path: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
|
||||||
|
""" Load model file
|
||||||
|
Original 'load' function
|
||||||
|
"""
|
||||||
|
|
||||||
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
||||||
n_px = model.input_resolution.item()
|
n_px = model.input_resolution.item()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue