Can specify root directory when loading model (#136)

* Can specify root directory when loading model

* specifying download_root instead

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
This commit is contained in:
Kevin Costa 2021-08-09 08:43:22 +02:00 committed by GitHub
parent ff339871f3
commit 22fde59cbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 4 deletions

View File

@ -2,7 +2,7 @@ import hashlib
import os import os
import urllib import urllib
import warnings import warnings
from typing import Union, List from typing import Any, Union, List
import torch import torch
from PIL import Image from PIL import Image
@ -36,7 +36,7 @@ _MODELS = {
} }
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): def _download(url: str, root: str):
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
filename = os.path.basename(url) filename = os.path.basename(url)
@ -83,7 +83,7 @@ def available_models() -> List[str]:
return list(_MODELS.keys()) return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
"""Load a CLIP model """Load a CLIP model
Parameters Parameters
@ -97,6 +97,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
jit : bool jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default). Whether to load the optimized JIT model or more hackable non-JIT model (default).
download_root: str
path to download the model files; by default, it uses "~/.cache/clip"
Returns Returns
------- -------
model : torch.nn.Module model : torch.nn.Module
@ -106,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
""" """
if name in _MODELS: if name in _MODELS:
model_path = _download(_MODELS[name]) model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
elif os.path.isfile(name): elif os.path.isfile(name):
model_path = name model_path = name
else: else: