diff --git a/clip/model.py b/clip/model.py index 9cf262a..422a34a 100644 --- a/clip/model.py +++ b/clip/model.py @@ -288,7 +288,7 @@ class CLIP(nn.Module): self.ln_final = LayerNorm(transformer_width) self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - self.logit_scale = nn.Parameter(torch.FloatTensor([np.log(1/0.07)])) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.initialize_parameters()