specifying download_root instead
This commit is contained in:
parent
53e1d0b6aa
commit
6599907d3e
10
clip/clip.py
10
clip/clip.py
|
@ -36,7 +36,7 @@ _MODELS = {
|
|||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
||||
def _download(url: str, root: str):
|
||||
os.makedirs(root, exist_ok=True)
|
||||
filename = os.path.basename(url)
|
||||
|
||||
|
@ -83,7 +83,7 @@ def available_models() -> List[str]:
|
|||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, **kwargs: Any):
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
|
@ -97,8 +97,8 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
jit : bool
|
||||
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
||||
|
||||
**kwargs (optional): Any
|
||||
The corresponding kwargs for internal `_download` function.
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/clip"
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -109,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||
"""
|
||||
if name in _MODELS:
|
||||
model_path = _download(_MODELS[name], **kwargs)
|
||||
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
||||
elif os.path.isfile(name):
|
||||
model_path = name
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue