Can specify root directory when loading model (#136)
* Can specify root directory when loading model * specifying download_root instead Co-authored-by: Jong Wook Kim <jongwook@openai.com>
This commit is contained in:
parent
ff339871f3
commit
22fde59cbe
11
clip/clip.py
11
clip/clip.py
|
@ -2,7 +2,7 @@ import hashlib
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Union, List
|
from typing import Any, Union, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -36,7 +36,7 @@ _MODELS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
def _download(url: str, root: str):
|
||||||
os.makedirs(root, exist_ok=True)
|
os.makedirs(root, exist_ok=True)
|
||||||
filename = os.path.basename(url)
|
filename = os.path.basename(url)
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ def available_models() -> List[str]:
|
||||||
return list(_MODELS.keys())
|
return list(_MODELS.keys())
|
||||||
|
|
||||||
|
|
||||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
|
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
||||||
"""Load a CLIP model
|
"""Load a CLIP model
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -97,6 +97,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||||
jit : bool
|
jit : bool
|
||||||
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
||||||
|
|
||||||
|
download_root: str
|
||||||
|
path to download the model files; by default, it uses "~/.cache/clip"
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
model : torch.nn.Module
|
model : torch.nn.Module
|
||||||
|
@ -106,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||||
"""
|
"""
|
||||||
if name in _MODELS:
|
if name in _MODELS:
|
||||||
model_path = _download(_MODELS[name])
|
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
||||||
elif os.path.isfile(name):
|
elif os.path.isfile(name):
|
||||||
model_path = name
|
model_path = name
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue