diff --git a/README.md b/README.md index bd1a971..0bb0bf8 100644 --- a/README.md +++ b/README.md @@ -56,9 +56,9 @@ Returns the names of the available CLIP models. #### `clip.load(name, device=..., jit=True)` -Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. +Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint. -When `jit` is `False`, a non-JIT version of the model will be loaded. +The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded. #### `clip.tokenize(text: Union[str, List[str]], context_length=77)` diff --git a/clip/clip.py b/clip/clip.py index 5f16f41..5c21a21 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -53,19 +53,8 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): return download_target -def available_models(): - return list(_MODELS.keys()) - - -def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): - if name not in _MODELS: - raise RuntimeError(f"Model {name} not found; available models = {available_models()}") - - model_path = _download(_MODELS[name]) - model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() - n_px = model.input_resolution.item() - - transform = Compose([ +def _transform(n_px): + return Compose([ Resize(n_px, interpolation=Image.BICUBIC), CenterCrop(n_px), lambda image: image.convert("RGB"), @@ -73,11 +62,57 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name]) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + if not jit: - model = build_model(model.state_dict()).to(device) + model = build_model(state_dict or model.state_dict()).to(device) if str(device) == "cpu": model.float() - return model, transform + return model, _transform(model.visual.input_resolution) # patch the device names device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) @@ -121,10 +156,25 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a model.float() - return model, transform + return model, _transform(model.input_resolution.item()) -def tokenize(texts: Union[str, List[str]], context_length: int = 77): +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ if isinstance(texts, str): texts = [texts] diff --git a/clip/model.py b/clip/model.py index f2c998a..1ddd908 100644 --- a/clip/model.py +++ b/clip/model.py @@ -423,7 +423,8 @@ def build_model(state_dict: dict): ) for key in ["input_resolution", "context_length", "vocab_size"]: - del state_dict[key] + if key in state_dict: + del state_dict[key] convert_weights(model) model.load_state_dict(state_dict)