Load model from path (#41)
* Load model from path
* showing download progress in "MiB"
* clip.load() now can take a file path

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
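A minimal usage sketch of the behaviour this commit adds, assuming the clip package from this repository is installed; the local file path below is a placeholder, not a file shipped with the repo:

import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"

# Loading by name works as before (a name from clip.available_models())
# and downloads the checkpoint to ~/.cache/clip
model, preprocess = clip.load("ViT-B/32", device=device)

# New in this commit: the same call accepts a path to a local checkpoint file
model, preprocess = clip.load("/path/to/ViT-B-32.pt", device=device)  # placeholder path

# preprocess converts a PIL image into the tensor the model expects (see the load() docstring below)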
clip/clip.py (84 lines changed)
@@ -53,19 +53,8 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
     return download_target
 
 
-def available_models():
-    return list(_MODELS.keys())
-
-
-def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
-    if name not in _MODELS:
-        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
-
-    model_path = _download(_MODELS[name])
-    model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
-    n_px = model.input_resolution.item()
-
-    transform = Compose([
+def _transform(n_px):
+    return Compose([
         Resize(n_px, interpolation=Image.BICUBIC),
         CenterCrop(n_px),
         lambda image: image.convert("RGB"),
@@ -73,11 +62,57 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
     ])
 
+
+def available_models() -> List[str]:
+    """Returns the names of available CLIP models"""
+    return list(_MODELS.keys())
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
+    """Load a CLIP model
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+    device : Union[str, torch.device]
+        The device to put the loaded model
+
+    jit : bool
+        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLIP model
+
+    preprocess : Callable[[PIL.Image], torch.Tensor]
+        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+    """
+    if name in _MODELS:
+        model_path = _download(_MODELS[name])
+    elif os.path.isfile(name):
+        model_path = name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    try:
+        # loading JIT archive
+        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+        state_dict = None
+    except RuntimeError:
+        # loading saved state dict
+        if jit:
+            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+            jit = False
+        state_dict = torch.load(model_path, map_location="cpu")
+
     if not jit:
-        model = build_model(model.state_dict()).to(device)
+        model = build_model(state_dict or model.state_dict()).to(device)
         if str(device) == "cpu":
             model.float()
-        return model, transform
+        return model, _transform(model.visual.input_resolution)
 
     # patch the device names
     device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
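To make the try/except fallback in the hunk above concrete, here is a hedged sketch of the two file types load() now accepts; the file name is a placeholder, and jit=False mirrors the non-JIT branch shown in the diff:

import torch
import clip

# Build the non-JIT model and save only its weights; the resulting file is a
# plain state_dict, not a TorchScript archive.
model, _ = clip.load("ViT-B/32", jit=False)
torch.save(model.state_dict(), "clip_vit_b32_state_dict.pt")  # placeholder path

# torch.jit.load() raises RuntimeError on this file, so load() warns, flips
# jit to False, reads the file with torch.load(), and rebuilds the model
# via build_model(state_dict).
model2, preprocess2 = clip.load("clip_vit_b32_state_dict.pt")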
@@ -121,10 +156,25 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
 
         model.float()
 
-    return model, transform
+    return model, _transform(model.input_resolution.item())
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77):
+def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+    """
+    Returns the tokenized representation of given input string(s)
+
+    Parameters
+    ----------
+    texts : Union[str, List[str]]
+        An input string or a list of input strings to tokenize
+
+    context_length : int
+        The context length to use; all CLIP models use 77 as the context length
+
+    Returns
+    -------
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+    """
     if isinstance(texts, str):
         texts = [texts]
 
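And a short sketch of the tokenize() signature documented above; the strings are arbitrary examples, and the shape follows the docstring (one row per input string, context_length columns):

import clip

tokens = clip.tokenize(["a diagram", "a dog", "a cat"])
print(tokens.dtype, tokens.shape)  # torch.int64, torch.Size([3, 77]) with the default context_length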