Merge 424e55c2ba
into 8a665a683d
This commit is contained in:
commit
3bdc6a7ee4
|
@ -56,9 +56,9 @@ Returns the names of the available CLIP models.
|
|||
|
||||
#### `clip.load(name, device=..., jit=True)`
|
||||
|
||||
Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint.
|
||||
Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU.
|
||||
|
||||
The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded.
|
||||
When `jit` is `False`, a non-JIT version of the model will be loaded.
|
||||
|
||||
#### `clip.tokenize(text: Union[str, List[str]], context_length=77)`
|
||||
|
||||
|
|
84
clip/clip.py
84
clip/clip.py
|
@ -55,8 +55,19 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
|||
return download_target
|
||||
|
||||
|
||||
def _transform(n_px):
|
||||
return Compose([
|
||||
def available_models():
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
|
||||
if name not in _MODELS:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
|
||||
model_path = _download(_MODELS[name])
|
||||
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
||||
n_px = model.input_resolution.item()
|
||||
|
||||
transform = Compose([
|
||||
Resize(n_px, interpolation=Image.BICUBIC),
|
||||
CenterCrop(n_px),
|
||||
lambda image: image.convert("RGB"),
|
||||
|
@ -64,57 +75,11 @@ def _transform(n_px):
|
|||
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
])
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available CLIP models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
||||
|
||||
device : Union[str, torch.device]
|
||||
The device to put the loaded model
|
||||
|
||||
jit : bool
|
||||
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : torch.nn.Module
|
||||
The CLIP model
|
||||
|
||||
preprocess : Callable[[PIL.Image], torch.Tensor]
|
||||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||
"""
|
||||
if name in _MODELS:
|
||||
model_path = _download(_MODELS[name])
|
||||
elif os.path.isfile(name):
|
||||
model_path = name
|
||||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
|
||||
try:
|
||||
# loading JIT archive
|
||||
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
||||
state_dict = None
|
||||
except RuntimeError:
|
||||
# loading saved state dict
|
||||
if jit:
|
||||
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
||||
jit = False
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
if not jit:
|
||||
model = build_model(state_dict or model.state_dict()).to(device)
|
||||
model = build_model(model.state_dict()).to(device)
|
||||
if str(device) == "cpu":
|
||||
model.float()
|
||||
return model, _transform(model.visual.input_resolution)
|
||||
return model, transform
|
||||
|
||||
# patch the device names
|
||||
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||
|
@ -158,25 +123,10 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
|
||||
model.float()
|
||||
|
||||
return model, _transform(model.input_resolution.item())
|
||||
return model, transform
|
||||
|
||||
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
||||
"""
|
||||
Returns the tokenized representation of given input string(s)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts : Union[str, List[str]]
|
||||
An input string or a list of input strings to tokenize
|
||||
|
||||
context_length : int
|
||||
The context length to use; all CLIP models use 77 as the context length
|
||||
|
||||
Returns
|
||||
-------
|
||||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
||||
"""
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77):
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
|
|
|
@ -424,8 +424,7 @@ def build_model(state_dict: dict):
|
|||
)
|
||||
|
||||
for key in ["input_resolution", "context_length", "vocab_size"]:
|
||||
if key in state_dict:
|
||||
del state_dict[key]
|
||||
del state_dict[key]
|
||||
|
||||
convert_weights(model)
|
||||
model.load_state_dict(state_dict)
|
||||
|
|
Loading…
Reference in New Issue