Can specify root directory when loading model (#136)
* Can specify root directory when loading model * specifying download_root instead Co-authored-by: Jong Wook Kim <jongwook@openai.com>
This commit is contained in:
parent
ff339871f3
commit
22fde59cbe
11
clip/clip.py
11
clip/clip.py
|
@ -2,7 +2,7 @@ import hashlib
|
|||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import Union, List
|
||||
from typing import Any, Union, List
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
@ -36,7 +36,7 @@ _MODELS = {
|
|||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
||||
def _download(url: str, root: str):
|
||||
os.makedirs(root, exist_ok=True)
|
||||
filename = os.path.basename(url)
|
||||
|
||||
|
@ -83,7 +83,7 @@ def available_models() -> List[str]:
|
|||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
|
@ -97,6 +97,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
jit : bool
|
||||
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
||||
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/clip"
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : torch.nn.Module
|
||||
|
@ -106,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||
"""
|
||||
if name in _MODELS:
|
||||
model_path = _download(_MODELS[name])
|
||||
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
||||
elif os.path.isfile(name):
|
||||
model_path = name
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue