From 6599907d3e3c809b8c1a85fa317774241c0b4740 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Sun, 8 Aug 2021 23:30:38 -0700 Subject: [PATCH] specifying download_root instead --- clip/clip.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/clip/clip.py b/clip/clip.py index e3202e5..6ce5565 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -36,7 +36,7 @@ _MODELS = { } -def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): +def _download(url: str, root: str): os.makedirs(root, exist_ok=True) filename = os.path.basename(url) @@ -83,7 +83,7 @@ def available_models() -> List[str]: return list(_MODELS.keys()) -def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, **kwargs: Any): +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): """Load a CLIP model Parameters @@ -97,8 +97,8 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a jit : bool Whether to load the optimized JIT model or more hackable non-JIT model (default). - **kwargs (optional): Any - The corresponding kwargs for internal `_download` function. + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" Returns ------- @@ -109,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if name in _MODELS: - model_path = _download(_MODELS[name], **kwargs) + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) elif os.path.isfile(name): model_path = name else: