Load model from path (#41)
* Load model from path * showing download progress in "MiB" * clip.load() now can take a file path Co-authored-by: Jong Wook Kim <jongwook@openai.com>
This commit is contained in:
		| @ -56,9 +56,9 @@ Returns the names of the available CLIP models. | |||||||
|  |  | ||||||
| #### `clip.load(name, device=..., jit=True)` | #### `clip.load(name, device=..., jit=True)` | ||||||
|  |  | ||||||
| Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. | Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint. | ||||||
|  |  | ||||||
| When `jit` is `False`, a non-JIT version of the model will be loaded. | The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded. | ||||||
|  |  | ||||||
| #### `clip.tokenize(text: Union[str, List[str]], context_length=77)` | #### `clip.tokenize(text: Union[str, List[str]], context_length=77)` | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										84
									
								
								clip/clip.py
									
									
									
									
									
								
							
							
						
						
									
										84
									
								
								clip/clip.py
									
									
									
									
									
								
							| @ -53,19 +53,8 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): | |||||||
|     return download_target |     return download_target | ||||||
|  |  | ||||||
|  |  | ||||||
| def available_models(): | def _transform(n_px): | ||||||
|     return list(_MODELS.keys()) |     return Compose([ | ||||||
|  |  | ||||||
|  |  | ||||||
| def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): |  | ||||||
|     if name not in _MODELS: |  | ||||||
|         raise RuntimeError(f"Model {name} not found; available models = {available_models()}") |  | ||||||
|  |  | ||||||
|     model_path = _download(_MODELS[name]) |  | ||||||
|     model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() |  | ||||||
|     n_px = model.input_resolution.item() |  | ||||||
|  |  | ||||||
|     transform = Compose([ |  | ||||||
|         Resize(n_px, interpolation=Image.BICUBIC), |         Resize(n_px, interpolation=Image.BICUBIC), | ||||||
|         CenterCrop(n_px), |         CenterCrop(n_px), | ||||||
|         lambda image: image.convert("RGB"), |         lambda image: image.convert("RGB"), | ||||||
| @ -73,11 +62,57 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a | |||||||
|         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), |         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), | ||||||
|     ]) |     ]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def available_models() -> List[str]: | ||||||
|  |     """Returns the names of available CLIP models""" | ||||||
|  |     return list(_MODELS.keys()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): | ||||||
|  |     """Load a CLIP model | ||||||
|  |  | ||||||
|  |     Parameters | ||||||
|  |     ---------- | ||||||
|  |     name : str | ||||||
|  |         A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict | ||||||
|  |  | ||||||
|  |     device : Union[str, torch.device] | ||||||
|  |         The device to put the loaded model | ||||||
|  |  | ||||||
|  |     jit : bool | ||||||
|  |         Whether to load the optimized JIT model (default) or more hackable non-JIT model. | ||||||
|  |  | ||||||
|  |     Returns | ||||||
|  |     ------- | ||||||
|  |     model : torch.nn.Module | ||||||
|  |         The CLIP model | ||||||
|  |  | ||||||
|  |     preprocess : Callable[[PIL.Image], torch.Tensor] | ||||||
|  |         A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input | ||||||
|  |     """ | ||||||
|  |     if name in _MODELS: | ||||||
|  |         model_path = _download(_MODELS[name]) | ||||||
|  |     elif os.path.isfile(name): | ||||||
|  |         model_path = name | ||||||
|  |     else: | ||||||
|  |         raise RuntimeError(f"Model {name} not found; available models = {available_models()}") | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         # loading JIT archive | ||||||
|  |         model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() | ||||||
|  |         state_dict = None | ||||||
|  |     except RuntimeError: | ||||||
|  |         # loading saved state dict | ||||||
|  |         if jit: | ||||||
|  |             warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") | ||||||
|  |             jit = False | ||||||
|  |         state_dict = torch.load(model_path, map_location="cpu") | ||||||
|  |  | ||||||
|     if not jit: |     if not jit: | ||||||
|         model = build_model(model.state_dict()).to(device) |         model = build_model(state_dict or model.state_dict()).to(device) | ||||||
|         if str(device) == "cpu": |         if str(device) == "cpu": | ||||||
|             model.float() |             model.float() | ||||||
|         return model, transform |         return model, _transform(model.visual.input_resolution) | ||||||
|  |  | ||||||
|     # patch the device names |     # patch the device names | ||||||
|     device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) |     device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) | ||||||
| @ -121,10 +156,25 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a | |||||||
|  |  | ||||||
|         model.float() |         model.float() | ||||||
|  |  | ||||||
|     return model, transform |     return model, _transform(model.input_resolution.item()) | ||||||
|  |  | ||||||
|  |  | ||||||
| def tokenize(texts: Union[str, List[str]], context_length: int = 77): | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: | ||||||
|  |     """ | ||||||
|  |     Returns the tokenized representation of given input string(s) | ||||||
|  |  | ||||||
|  |     Parameters | ||||||
|  |     ---------- | ||||||
|  |     texts : Union[str, List[str]] | ||||||
|  |         An input string or a list of input strings to tokenize | ||||||
|  |  | ||||||
|  |     context_length : int | ||||||
|  |         The context length to use; all CLIP models use 77 as the context length | ||||||
|  |  | ||||||
|  |     Returns | ||||||
|  |     ------- | ||||||
|  |     A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] | ||||||
|  |     """ | ||||||
|     if isinstance(texts, str): |     if isinstance(texts, str): | ||||||
|         texts = [texts] |         texts = [texts] | ||||||
|  |  | ||||||
|  | |||||||
| @ -423,6 +423,7 @@ def build_model(state_dict: dict): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     for key in ["input_resolution", "context_length", "vocab_size"]: |     for key in ["input_resolution", "context_length", "vocab_size"]: | ||||||
|  |         if key in state_dict: | ||||||
|             del state_dict[key] |             del state_dict[key] | ||||||
|  |  | ||||||
|     convert_weights(model) |     convert_weights(model) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user