Can specify root directory when loading model

This commit is contained in:
kcosta42 2021-08-06 10:29:40 +02:00 committed by Jong Wook Kim
parent ff339871f3
commit 53e1d0b6aa
1 changed files with 6 additions and 3 deletions

View File

@ -2,7 +2,7 @@ import hashlib
import os
import urllib
import warnings
from typing import Union, List
from typing import Any, Union, List
import torch
from PIL import Image
@ -83,7 +83,7 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, **kwargs: Any):
"""Load a CLIP model
Parameters
@ -97,6 +97,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default).
**kwargs (optional): Any
The corresponding kwargs for internal `_download` function.
Returns
-------
model : torch.nn.Module
@ -106,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name])
model_path = _download(_MODELS[name], **kwargs)
elif os.path.isfile(name):
model_path = name
else: