specifying download_root instead

This commit is contained in:
Jong Wook Kim 2021-08-08 23:30:38 -07:00
parent 53e1d0b6aa
commit 6599907d3e
1 changed files with 5 additions and 5 deletions

View File

@ -36,7 +36,7 @@ _MODELS = {
}
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
def _download(url: str, root: str):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
@ -83,7 +83,7 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, **kwargs: Any):
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
"""Load a CLIP model
Parameters
@ -97,8 +97,8 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default).
**kwargs (optional): Any
The corresponding kwargs for internal `_download` function.
download_root: str
path to download the model files; by default, it uses "~/.cache/clip"
Returns
-------
@ -109,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name], **kwargs)
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
elif os.path.isfile(name):
model_path = name
else: