From 424e55c2bafb81ea3891fdce151420dc5f82f1a1 Mon Sep 17 00:00:00 2001
From: DeePPenetration <74469343+MaxMood69@users.noreply.github.com>
Date: Sun, 4 Apr 2021 20:56:08 +0500
Subject: [PATCH] Revert "Load model from path (#41)"

This reverts commit 4c0275784d6d9da97ca1f47eaaee31de1867da91.
---
 README.md     |  4 +--
 clip/clip.py  | 84 +++++++++++----------------------------------------
 clip/model.py |  3 +-
 3 files changed, 20 insertions(+), 71 deletions(-)

diff --git a/README.md b/README.md
index b5287fd..9d78b47 100644
--- a/README.md
+++ b/README.md
@@ -56,9 +56,9 @@ Returns the names of the available CLIP models.
 
 #### `clip.load(name, device=..., jit=True)`
 
-Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint.
+Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU.
 
-The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded.
+When `jit` is `False`, a non-JIT version of the model will be loaded.
 
 #### `clip.tokenize(text: Union[str, List[str]], context_length=77)`
 
diff --git a/clip/clip.py b/clip/clip.py
index 76f241b..94641cd 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -55,8 +55,19 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
     return download_target
 
 
-def _transform(n_px):
-    return Compose([
+def available_models():
+    return list(_MODELS.keys())
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
+    if name not in _MODELS:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    model_path = _download(_MODELS[name])
+    model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+    n_px = model.input_resolution.item()
+
+    transform = Compose([
         Resize(n_px, interpolation=Image.BICUBIC),
         CenterCrop(n_px),
         lambda image: image.convert("RGB"),
@@ -64,57 +75,11 @@ def _transform(n_px):
         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
     ])
 
-
-def available_models() -> List[str]:
-    """Returns the names of available CLIP models"""
-    return list(_MODELS.keys())
-
-
-def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
-    """Load a CLIP model
-
-    Parameters
-    ----------
-    name : str
-        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
-
-    device : Union[str, torch.device]
-        The device to put the loaded model
-
-    jit : bool
-        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
-
-    Returns
-    -------
-    model : torch.nn.Module
-        The CLIP model
-
-    preprocess : Callable[[PIL.Image], torch.Tensor]
-        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
-    """
-    if name in _MODELS:
-        model_path = _download(_MODELS[name])
-    elif os.path.isfile(name):
-        model_path = name
-    else:
-        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
-
-    try:
-        # loading JIT archive
-        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
-        state_dict = None
-    except RuntimeError:
-        # loading saved state dict
-        if jit:
-            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
-            jit = False
-        state_dict = torch.load(model_path, map_location="cpu")
-
     if not jit:
-        model = build_model(state_dict or model.state_dict()).to(device)
+        model = build_model(model.state_dict()).to(device)
         if str(device) == "cpu":
             model.float()
-        return model, _transform(model.visual.input_resolution)
+        return model, transform
 
     # patch the device names
     device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
@@ -158,25 +123,10 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
 
         model.float()
 
-    return model, _transform(model.input_resolution.item())
+    return model, transform
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
-    """
-    Returns the tokenized representation of given input string(s)
-
-    Parameters
-    ----------
-    texts : Union[str, List[str]]
-        An input string or a list of input strings to tokenize
-
-    context_length : int
-        The context length to use; all CLIP models use 77 as the context length
-
-    Returns
-    -------
-    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
-    """
+def tokenize(texts: Union[str, List[str]], context_length: int = 77):
     if isinstance(texts, str):
         texts = [texts]
 
diff --git a/clip/model.py b/clip/model.py
index 422a34a..fde63f2 100644
--- a/clip/model.py
+++ b/clip/model.py
@@ -424,8 +424,7 @@ def build_model(state_dict: dict):
     )
 
     for key in ["input_resolution", "context_length", "vocab_size"]:
-        if key in state_dict:
-            del state_dict[key]
+        del state_dict[key]
 
     convert_weights(model)
     model.load_state_dict(state_dict)
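For context, after this revert `clip.load` once again accepts only a name listed by `clip.available_models()`; passing a checkpoint path raises `RuntimeError`. A minimal usage sketch of the reverted API, along the lines of the repository README (the `"CLIP.png"` file and the prompt strings are placeholders):

```python
import torch
import clip
from PIL import Image

# Default device matches the expression used in clip.load():
# the first CUDA device if available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# `name` must be one of the listed model names after the revert.
print(clip.available_models())
model, preprocess = clip.load("ViT-B/32", device=device, jit=True)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)  # placeholder image file
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)    # shape [3, 77]

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)
```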