Load model from path

This commit is contained in:
Sebastian Berns 2021-02-11 21:17:11 +00:00
parent e5347713f4
commit 6dab2f989b
1 changed file with 13 additions and 1 deletion

View File

@ -12,7 +12,7 @@ from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
__all__ = ["available_models", "load", "tokenize"]
__all__ = ["available_models", "load", "load_from_file", "tokenize"]
_tokenizer = _Tokenizer()
_MODELS = {
@ -58,10 +58,22 @@ def available_models():
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
    """Load a model by its published name.

    Kept the function name 'load' for backwards compatibility; the actual
    file handling is delegated to load_from_file().

    Parameters
    ----------
    name : str
        Key into _MODELS identifying which model to download and load.
    device : Union[str, torch.device]
        Device to place the model on; defaults to CUDA when available.
    jit : bool
        Forwarded to load_from_file (controls JIT vs. CPU map_location there).

    Raises
    ------
    RuntimeError
        If `name` is not one of the known models.
    """
    if name not in _MODELS:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
    model_path = _download(_MODELS[name])
    # BUG FIX: previously `jit` was accepted but never forwarded, so
    # load(..., jit=False) silently behaved like jit=True.
    return load_from_file(model_path, device, jit)
def load_from_file(model_path: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
""" Load model file
Original 'load' function
"""
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
n_px = model.input_resolution.item()