diff --git a/clip/model.py b/clip/model.py
index 1ddd908..9cf262a 100644
--- a/clip/model.py
+++ b/clip/model.py
@@ -1,6 +1,7 @@
 from collections import OrderedDict
 from typing import Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -287,7 +288,9 @@ class CLIP(nn.Module):
         self.ln_final = LayerNorm(transformer_width)
 
         self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
-        self.logit_scale = nn.Parameter(torch.ones([]))
+        # Start the logit scale at log(1 / 0.07) while keeping it 0-dim, the same
+        # shape as the torch.ones([]) parameter this line replaces.
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
 
         self.initialize_parameters()
 