added RN50 checkpoint and non-JIT model implementation
This commit is contained in:
parent
ebd0e35aac
commit
6bc0bd8873
|
@ -29,9 +29,9 @@ import clip
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model, transform = clip.load("ViT-B/32", device=device)
|
model, preprocess = clip.load("ViT-B/32", device=device)
|
||||||
|
|
||||||
image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
|
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
|
||||||
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
|
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -53,10 +53,12 @@ The CLIP module `clip` provides the following methods:
|
||||||
|
|
||||||
Returns the name(s) of the available CLIP models.
|
Returns the name(s) of the available CLIP models.
|
||||||
|
|
||||||
#### `clip.load(name, device=...)`
|
#### `clip.load(name, device=..., jit=True)`
|
||||||
|
|
||||||
Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU.
|
Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU.
|
||||||
|
|
||||||
|
When `jit` is `False`, a non-JIT version of the model will be loaded.
|
||||||
|
|
||||||
#### `clip.tokenize(text: Union[str, List[str]], context_length=77)`
|
#### `clip.tokenize(text: Union[str, List[str]], context_length=77)`
|
||||||
|
|
||||||
Returns a LongTensor containing tokenized sequences of given text input(s). This can be used as the input to the model
|
Returns a LongTensor containing tokenized sequences of given text input(s). This can be used as the input to the model
|
||||||
|
|
27
clip.py
27
clip.py
|
@ -9,13 +9,14 @@ from PIL import Image
|
||||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from model import build_model
|
||||||
from simple_tokenizer import SimpleTokenizer as _Tokenizer
|
from simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||||
|
|
||||||
__all__ = ["available_models", "load", "tokenize"]
|
__all__ = ["available_models", "load", "tokenize"]
|
||||||
_tokenizer = _Tokenizer()
|
_tokenizer = _Tokenizer()
|
||||||
|
|
||||||
|
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
|
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
||||||
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,14 +57,26 @@ def available_models():
|
||||||
return list(_MODELS.keys())
|
return list(_MODELS.keys())
|
||||||
|
|
||||||
|
|
||||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu"):
|
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
|
||||||
if name not in _MODELS:
|
if name not in _MODELS:
|
||||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||||
|
|
||||||
model_path = _download(_MODELS[name])
|
model_path = _download(_MODELS[name])
|
||||||
model = torch.jit.load(model_path, map_location=device).eval()
|
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
||||||
n_px = model.input_resolution.item()
|
n_px = model.input_resolution.item()
|
||||||
|
|
||||||
|
transform = Compose([
|
||||||
|
Resize(n_px, interpolation=Image.BICUBIC),
|
||||||
|
CenterCrop(n_px),
|
||||||
|
lambda image: image.convert("RGB"),
|
||||||
|
ToTensor(),
|
||||||
|
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||||
|
])
|
||||||
|
|
||||||
|
if not jit:
|
||||||
|
model = build_model(model.state_dict()).to(device)
|
||||||
|
return model, transform
|
||||||
|
|
||||||
# patch the device names
|
# patch the device names
|
||||||
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||||
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
||||||
|
@ -106,14 +119,6 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||||
|
|
||||||
model.float()
|
model.float()
|
||||||
|
|
||||||
transform = Compose([
|
|
||||||
Resize(n_px, interpolation=Image.BICUBIC),
|
|
||||||
CenterCrop(n_px),
|
|
||||||
lambda image: image.convert("RGB"),
|
|
||||||
ToTensor(),
|
|
||||||
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
|
||||||
])
|
|
||||||
|
|
||||||
return model, transform
|
return model, transform
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ Please see the paper linked below for further details about their specification.
|
||||||
### Documents
|
### Documents
|
||||||
|
|
||||||
- [Blog Post](https://openai.com/blog/clip/)
|
- [Blog Post](https://openai.com/blog/clip/)
|
||||||
- [CLIP Paper](https://cdn.openai.com/papers/Learning_Transferable_Visual_Models_From_Natural_Language.pdf)
|
- [CLIP Paper](https://cdn.openai.com/papers/Learning_Transferable_Visual_Models_From_Natural_Language_Supervision.pdf)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,399 @@
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
||||||
|
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
|
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
||||||
|
|
||||||
|
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||||
|
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.downsample = None
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
||||||
|
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
||||||
|
self.downsample = nn.Sequential(OrderedDict([
|
||||||
|
("-1", nn.AvgPool2d(stride)),
|
||||||
|
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
||||||
|
("1", nn.BatchNorm2d(planes * self.expansion))
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = self.relu(self.bn1(self.conv1(x)))
|
||||||
|
out = self.relu(self.bn2(self.conv2(out)))
|
||||||
|
out = self.avgpool(out)
|
||||||
|
out = self.bn3(self.conv3(out))
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
out += identity
|
||||||
|
out = self.relu(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPool2d(nn.Module):
|
||||||
|
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
||||||
|
super().__init__()
|
||||||
|
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
||||||
|
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
||||||
|
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
||||||
|
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
||||||
|
x, _ = F.multi_head_attention_forward(
|
||||||
|
query=x, key=x, value=x,
|
||||||
|
embed_dim_to_check=x.shape[-1],
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
q_proj_weight=self.q_proj.weight,
|
||||||
|
k_proj_weight=self.k_proj.weight,
|
||||||
|
v_proj_weight=self.v_proj.weight,
|
||||||
|
in_proj_weight=None,
|
||||||
|
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||||
|
bias_k=None,
|
||||||
|
bias_v=None,
|
||||||
|
add_zero_attn=False,
|
||||||
|
dropout_p=0,
|
||||||
|
out_proj_weight=self.c_proj.weight,
|
||||||
|
out_proj_bias=self.c_proj.bias,
|
||||||
|
use_separate_proj_weight=True,
|
||||||
|
training=self.training,
|
||||||
|
need_weights=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return x[0]
|
||||||
|
|
||||||
|
|
||||||
|
class ModifiedResNet(nn.Module):
|
||||||
|
"""
|
||||||
|
A ResNet class that is similar to torchvision's but contains the following changes:
|
||||||
|
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
||||||
|
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
||||||
|
- The final pooling layer is a QKV attention instead of an average pool
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
||||||
|
super().__init__()
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.input_resolution = input_resolution
|
||||||
|
|
||||||
|
# the 3-layer stem
|
||||||
|
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(width // 2)
|
||||||
|
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(width // 2)
|
||||||
|
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(width)
|
||||||
|
self.avgpool = nn.AvgPool2d(2)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
# residual layers
|
||||||
|
self._inplanes = width # this is a *mutable* variable used during construction
|
||||||
|
self.layer1 = self._make_layer(width, layers[0])
|
||||||
|
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
||||||
|
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
||||||
|
|
||||||
|
embed_dim = width * 32 # the ResNet feature dimension
|
||||||
|
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
||||||
|
|
||||||
|
def _make_layer(self, planes, blocks, stride=1):
|
||||||
|
layers = [Bottleneck(self._inplanes, planes, stride)]
|
||||||
|
|
||||||
|
self._inplanes = planes * Bottleneck.expansion
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
layers.append(Bottleneck(self._inplanes, planes))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
def stem(x):
|
||||||
|
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
||||||
|
x = self.relu(bn(conv(x)))
|
||||||
|
x = self.avgpool(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
x = x.type(self.conv1.weight.dtype)
|
||||||
|
x = stem(x)
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
x = self.attnpool(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.LayerNorm):
|
||||||
|
"""Subclass torch's LayerNorm to handle fp16."""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
orig_type = x.dtype
|
||||||
|
ret = super().forward(x.type(torch.float32))
|
||||||
|
return ret.type(orig_type)
|
||||||
|
|
||||||
|
|
||||||
|
class QuickGELU(nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAttentionBlock(nn.Module):
|
||||||
|
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||||
|
self.ln_1 = LayerNorm(d_model)
|
||||||
|
self.mlp = nn.Sequential(OrderedDict([
|
||||||
|
("c_fc", nn.Linear(d_model, d_model * 4)),
|
||||||
|
("gelu", QuickGELU()),
|
||||||
|
("c_proj", nn.Linear(d_model * 4, d_model))
|
||||||
|
]))
|
||||||
|
self.ln_2 = LayerNorm(d_model)
|
||||||
|
self.attn_mask = attn_mask
|
||||||
|
|
||||||
|
def attention(self, x: torch.Tensor):
|
||||||
|
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
||||||
|
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = x + self.attention(self.ln_1(x))
|
||||||
|
x = x + self.mlp(self.ln_2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
||||||
|
super().__init__()
|
||||||
|
self.width = width
|
||||||
|
self.layers = layers
|
||||||
|
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return self.resblocks(x)
|
||||||
|
|
||||||
|
|
||||||
|
class VisualTransformer(nn.Module):
|
||||||
|
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.input_resolution = input_resolution
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||||
|
|
||||||
|
scale = width ** -0.5
|
||||||
|
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||||
|
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
||||||
|
self.ln_pre = LayerNorm(width)
|
||||||
|
|
||||||
|
self.transformer = Transformer(width, layers, heads)
|
||||||
|
|
||||||
|
self.ln_post = LayerNorm(width)
|
||||||
|
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
||||||
|
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||||
|
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||||
|
x = x + self.positional_embedding.to(x.dtype)
|
||||||
|
x = self.ln_pre(x)
|
||||||
|
|
||||||
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
|
x = self.transformer(x)
|
||||||
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
|
||||||
|
x = self.ln_post(x[:, 0, :])
|
||||||
|
|
||||||
|
if self.proj is not None:
|
||||||
|
x = x @ self.proj
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CLIP(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
embed_dim: int,
|
||||||
|
# vision
|
||||||
|
image_resolution: int,
|
||||||
|
vision_layers: Union[Tuple[int, int, int, int], int],
|
||||||
|
vision_width: int,
|
||||||
|
vision_patch_size: int,
|
||||||
|
# text
|
||||||
|
context_length: int,
|
||||||
|
vocab_size: int,
|
||||||
|
transformer_width: int,
|
||||||
|
transformer_heads: int,
|
||||||
|
transformer_layers: int
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.context_length = context_length
|
||||||
|
|
||||||
|
if isinstance(vision_layers, (tuple, list)):
|
||||||
|
vision_heads = vision_width * 32 // 64
|
||||||
|
self.visual = ModifiedResNet(
|
||||||
|
layers=vision_layers,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
heads=vision_heads,
|
||||||
|
input_resolution=image_resolution,
|
||||||
|
width=vision_width
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vision_heads = vision_width // 64
|
||||||
|
self.visual = VisualTransformer(
|
||||||
|
input_resolution=image_resolution,
|
||||||
|
patch_size=vision_patch_size,
|
||||||
|
width=vision_width,
|
||||||
|
layers=vision_layers,
|
||||||
|
heads=vision_heads,
|
||||||
|
output_dim=embed_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
self.transformer = Transformer(
|
||||||
|
width=transformer_width,
|
||||||
|
layers=transformer_layers,
|
||||||
|
heads=transformer_heads,
|
||||||
|
attn_mask=self.build_attention_mask()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
||||||
|
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
||||||
|
self.ln_final = LayerNorm(transformer_width)
|
||||||
|
|
||||||
|
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
||||||
|
self.logit_scale = nn.Parameter(torch.ones([]))
|
||||||
|
|
||||||
|
def build_attention_mask(self):
|
||||||
|
# lazily create causal attention mask, with full attention between the vision tokens
|
||||||
|
# pytorch uses additive attention mask; fill with -inf
|
||||||
|
mask = torch.empty(self.context_length, self.context_length)
|
||||||
|
mask.fill_(float("-inf"))
|
||||||
|
mask.triu_(1) # zero out the lower diagonal
|
||||||
|
return mask
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self.visual.conv1.weight.dtype
|
||||||
|
|
||||||
|
def encode_image(self, image):
|
||||||
|
return self.visual(image.type(self.dtype))
|
||||||
|
|
||||||
|
def encode_text(self, text):
|
||||||
|
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
||||||
|
|
||||||
|
x = x + self.positional_embedding.type(self.dtype)
|
||||||
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
|
x = self.transformer(x)
|
||||||
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
x = self.ln_final(x).type(self.dtype)
|
||||||
|
|
||||||
|
# x.shape = [batch_size, n_ctx, transformer.width]
|
||||||
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
|
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, image, text):
|
||||||
|
image_features = self.encode_image(image)
|
||||||
|
text_features = self.encode_text(text)
|
||||||
|
|
||||||
|
# normalized features
|
||||||
|
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
||||||
|
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# cosine similarity as logits
|
||||||
|
logit_scale = self.logit_scale.exp()
|
||||||
|
logits_per_iamge = logit_scale * image_features @ text_features.t()
|
||||||
|
logits_per_text = logit_scale * text_features @ image_features.t()
|
||||||
|
|
||||||
|
# shape = [global_batch_size, global_batch_size]
|
||||||
|
return logits_per_iamge, logits_per_text
|
||||||
|
|
||||||
|
|
||||||
|
def convert_weights(model: nn.Module):
|
||||||
|
"""Convert applicable model parameters to fp16"""
|
||||||
|
|
||||||
|
def _convert_weights_to_fp16(l):
|
||||||
|
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||||
|
l.weight.data = l.weight.data.half()
|
||||||
|
if l.bias is not None:
|
||||||
|
l.bias.data = l.bias.data.half()
|
||||||
|
|
||||||
|
if isinstance(l, nn.MultiheadAttention):
|
||||||
|
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||||
|
tensor = getattr(l, attr)
|
||||||
|
if tensor is not None:
|
||||||
|
tensor.data = tensor.data.half()
|
||||||
|
|
||||||
|
for name in ["text_projection", "proj"]:
|
||||||
|
if hasattr(l, name):
|
||||||
|
attr = getattr(l, name)
|
||||||
|
if attr is not None:
|
||||||
|
attr.data = attr.data.half()
|
||||||
|
|
||||||
|
model.apply(_convert_weights_to_fp16)
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(state_dict: dict):
|
||||||
|
vit = "visual.proj" in state_dict
|
||||||
|
|
||||||
|
if vit:
|
||||||
|
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
||||||
|
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
||||||
|
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
||||||
|
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||||
|
image_resolution = vision_patch_size * grid_size
|
||||||
|
else:
|
||||||
|
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
||||||
|
vision_layers = tuple(counts)
|
||||||
|
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
||||||
|
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||||
|
vision_patch_size = None
|
||||||
|
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
||||||
|
image_resolution = output_width * 32
|
||||||
|
|
||||||
|
embed_dim = state_dict["text_projection"].shape[1]
|
||||||
|
context_length = state_dict["positional_embedding"].shape[0]
|
||||||
|
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
||||||
|
transformer_width = state_dict["ln_final.weight"].shape[0]
|
||||||
|
transformer_heads = transformer_width // 64
|
||||||
|
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
||||||
|
|
||||||
|
model = CLIP(
|
||||||
|
embed_dim,
|
||||||
|
image_resolution, vision_layers, vision_width, vision_patch_size,
|
||||||
|
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in ["input_resolution", "context_length", "vocab_size"]:
|
||||||
|
del state_dict[key]
|
||||||
|
|
||||||
|
convert_weights(model)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
return model.eval()
|
Loading…
Reference in New Issue