Can specify root directory when loading model
This commit is contained in:
parent
ff339871f3
commit
53e1d0b6aa
|
@ -2,7 +2,7 @@ import hashlib
|
|||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import Union, List
|
||||
from typing import Any, Union, List
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
@ -83,7 +83,7 @@ def available_models() -> List[str]:
|
|||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, **kwargs: Any):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
|
@ -97,6 +97,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
jit : bool
|
||||
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
||||
|
||||
**kwargs (optional): Any
|
||||
The corresponding kwargs for internal `_download` function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : torch.nn.Module
|
||||
|
@ -106,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||
"""
|
||||
if name in _MODELS:
|
||||
model_path = _download(_MODELS[name])
|
||||
model_path = _download(_MODELS[name], **kwargs)
|
||||
elif os.path.isfile(name):
|
||||
model_path = name
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue