diff --git a/clip/model.py b/clip/model.py index f2c95c4..f7958f1 100644 --- a/clip/model.py +++ b/clip/model.py @@ -362,7 +362,7 @@ class CLIP(nn.Module): # cosine similarity as logits logit_scale = self.logit_scale.exp() logits_per_image = logit_scale * image_features @ text_features.t() - logits_per_text = logit_scale * text_features @ image_features.t() + logits_per_text = logits_per_image.t() # shape = [global_batch_size, global_batch_size] return logits_per_image, logits_per_text