diff --git a/clip/model.py b/clip/model.py index 422a34a..7c7832d 100644 --- a/clip/model.py +++ b/clip/model.py @@ -365,7 +365,7 @@ class CLIP(nn.Module): logits_per_text = logit_scale * text_features @ image_features.t() # shape = [global_batch_size, global_batch_size] - return logits_per_image, logits_per_text + return logits_per_image, logits_per_text, logit_scale def convert_weights(model: nn.Module):