Can specify root directory when loading model (#136)

* Can specify root directory when loading model

* specifying download_root instead

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Author: Kevin Costa, 2021-08-09 08:43:22 +02:00 (committed by GitHub)
parent ff339871f3
commit 22fde59cbe
1 changed file with 7 additions and 4 deletions


@@ -2,7 +2,7 @@ import hashlib
 import os
 import urllib
 import warnings
-from typing import Union, List
+from typing import Any, Union, List
 
 import torch
 from PIL import Image
@@ -36,7 +36,7 @@ _MODELS = {
 }
 
 
-def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+def _download(url: str, root: str):
     os.makedirs(root, exist_ok=True)
     filename = os.path.basename(url)
 
@@ -83,7 +83,7 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())
 
 
-def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
     """Load a CLIP model
 
     Parameters
@@ -97,6 +97,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
     jit : bool
         Whether to load the optimized JIT model or more hackable non-JIT model (default).
 
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/clip"
+
     Returns
     -------
     model : torch.nn.Module
@@ -106,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
         A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
     """
    if name in _MODELS:
-        model_path = _download(_MODELS[name])
+        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
     elif os.path.isfile(name):
         model_path = name
     else:
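
For reference, a minimal usage sketch of the new download_root parameter added by this commit. The model name "ViT-B/32" and the cache path "/data/clip-models" are illustrative choices, not part of the change itself:

    import torch
    import clip

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Download (or reuse) the model weights under a custom directory
    # instead of the default "~/.cache/clip".
    # "/data/clip-models" is a hypothetical path used only for illustration.
    model, preprocess = clip.load("ViT-B/32", device=device, download_root="/data/clip-models")

Omitting download_root (or passing None) falls back to the previous behavior, since load() resolves the path with "download_root or os.path.expanduser('~/.cache/clip')".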