diff --git a/clip/clip.py b/clip/clip.py
index 0f6c99c..6ce5565 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -2,7 +2,7 @@ import hashlib
 import os
 import urllib
 import warnings
-from typing import Union, List
+from typing import Any, Union, List
 
 import torch
 from PIL import Image
@@ -36,7 +36,7 @@ _MODELS = {
 }
 
 
-def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+def _download(url: str, root: str):
     os.makedirs(root, exist_ok=True)
     filename = os.path.basename(url)
 
@@ -83,7 +83,7 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())
 
 
-def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
     """Load a CLIP model
 
     Parameters
@@ -97,6 +97,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
     jit : bool
         Whether to load the optimized JIT model or more hackable non-JIT model (default).
 
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/clip"
+
     Returns
     -------
     model : torch.nn.Module
@@ -106,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
         A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
     """
     if name in _MODELS:
-        model_path = _download(_MODELS[name])
+        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
     elif os.path.isfile(name):
         model_path = name
     else:
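
A minimal usage sketch of the new parameter, assuming the package is installed as `clip`; the model name and cache directory below are illustrative and not part of the change:

```python
import os

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative cache directory; any writable path works.
cache_dir = os.path.expanduser("~/models/clip")

# download_root redirects the weight download/cache away from the default
# "~/.cache/clip"; omitting it (or passing None) keeps the previous behavior.
model, preprocess = clip.load("ViT-B/32", device=device, download_root=cache_dir)
```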