Load model from path (#41)

* Load model from path

* showing download progress in "MiB"

* clip.load() now can take a file path

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
This commit is contained in:
Sebastian Berns
2021-02-16 11:19:42 +00:00
committed by GitHub
parent 8f6deb52a1
commit 4c0275784d
3 changed files with 71 additions and 20 deletions

View File

@ -53,19 +53,8 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
return download_target
def available_models():
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
if name not in _MODELS:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
model_path = _download(_MODELS[name])
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
n_px = model.input_resolution.item()
transform = Compose([
def _transform(n_px):
return Compose([
Resize(n_px, interpolation=Image.BICUBIC),
CenterCrop(n_px),
lambda image: image.convert("RGB"),
@ -73,11 +62,57 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
def available_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name])
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(model_path, map_location="cpu")
if not jit:
model = build_model(model.state_dict()).to(device)
model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == "cpu":
model.float()
return model, transform
return model, _transform(model.visual.input_resolution)
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
@ -121,10 +156,25 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
model.float()
return model, transform
return model, _transform(model.input_resolution.item())
def tokenize(texts: Union[str, List[str]], context_length: int = 77):
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]