specifying download_root instead

This commit is contained in:
Jong Wook Kim 2021-08-08 23:30:38 -07:00
parent 53e1d0b6aa
commit 6599907d3e
1 changed files with 5 additions and 5 deletions

View File

@ -36,7 +36,7 @@ _MODELS = {
} }
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): def _download(url: str, root: str):
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
filename = os.path.basename(url) filename = os.path.basename(url)
@ -83,7 +83,7 @@ def available_models() -> List[str]:
return list(_MODELS.keys()) return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, **kwargs: Any): def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
"""Load a CLIP model """Load a CLIP model
Parameters Parameters
@ -97,8 +97,8 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
jit : bool jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default). Whether to load the optimized JIT model or more hackable non-JIT model (default).
**kwargs (optional): Any download_root: str
The corresponding kwargs for internal `_download` function. path to download the model files; by default, it uses "~/.cache/clip"
Returns Returns
------- -------
@ -109,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
""" """
if name in _MODELS: if name in _MODELS:
model_path = _download(_MODELS[name], **kwargs) model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
elif os.path.isfile(name): elif os.path.isfile(name):
model_path = name model_path = name
else: else: