Can specify root directory when loading model

This commit is contained in:
kcosta42 2021-08-06 10:29:40 +02:00 committed by Jong Wook Kim
parent ff339871f3
commit 53e1d0b6aa
1 changed files with 6 additions and 3 deletions

View File

@ -2,7 +2,7 @@ import hashlib
import os import os
import urllib import urllib
import warnings import warnings
from typing import Union, List from typing import Any, Union, List
import torch import torch
from PIL import Image from PIL import Image
@ -83,7 +83,7 @@ def available_models() -> List[str]:
return list(_MODELS.keys()) return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, **kwargs: Any):
"""Load a CLIP model """Load a CLIP model
Parameters Parameters
@ -97,6 +97,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
jit : bool jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default). Whether to load the optimized JIT model or more hackable non-JIT model (default).
**kwargs (optional): Any
The corresponding kwargs for internal `_download` function.
Returns Returns
------- -------
model : torch.nn.Module model : torch.nn.Module
@ -106,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
""" """
if name in _MODELS: if name in _MODELS:
model_path = _download(_MODELS[name]) model_path = _download(_MODELS[name], **kwargs)
elif os.path.isfile(name): elif os.path.isfile(name):
model_path = name model_path = name
else: else: