Update model.py
This commit is contained in:
parent
cfcffb90e6
commit
d038cfaf59
|
@ -365,7 +365,7 @@ class CLIP(nn.Module):
|
||||||
logits_per_text = logit_scale * text_features @ image_features.t()
|
logits_per_text = logit_scale * text_features @ image_features.t()
|
||||||
|
|
||||||
# shape = [global_batch_size, global_batch_size]
|
# shape = [global_batch_size, global_batch_size]
|
||||||
return logits_per_image, logits_per_text
|
return logits_per_image, logits_per_text, logit_scale
|
||||||
|
|
||||||
|
|
||||||
def convert_weights(model: nn.Module):
|
def convert_weights(model: nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue