Can specify root directory when loading model
This commit is contained in:
parent
ff339871f3
commit
53e1d0b6aa
|
@ -2,7 +2,7 @@ import hashlib
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Union, List
|
from typing import Any, Union, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -83,7 +83,7 @@ def available_models() -> List[str]:
|
||||||
return list(_MODELS.keys())
|
return list(_MODELS.keys())
|
||||||
|
|
||||||
|
|
||||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
|
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, **kwargs: Any):
|
||||||
"""Load a CLIP model
|
"""Load a CLIP model
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -97,6 +97,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||||
jit : bool
|
jit : bool
|
||||||
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
||||||
|
|
||||||
|
**kwargs (optional): Any
|
||||||
|
The corresponding kwargs for internal `_download` function.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
model : torch.nn.Module
|
model : torch.nn.Module
|
||||||
|
@ -106,7 +109,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||||
"""
|
"""
|
||||||
if name in _MODELS:
|
if name in _MODELS:
|
||||||
model_path = _download(_MODELS[name])
|
model_path = _download(_MODELS[name], **kwargs)
|
||||||
elif os.path.isfile(name):
|
elif os.path.isfile(name):
|
||||||
model_path = name
|
model_path = name
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue