correctly tokenizing SOT/EOT tokens (fixes #8)
This commit is contained in:
parent
c89e0c16de
commit
ebd0e35aac
File diff suppressed because one or more lines are too long
README.md (14 changed lines)
@@ -41,7 +41,7 @@ with torch.no_grad():
     logits_per_image, logits_per_text = model(image, text)
     probs = logits_per_image.softmax(dim=-1).cpu().numpy()

-print("Label probs:", probs)  # prints: [[0.9956 0.002144 0.002213]]
+print("Label probs:", probs)  # prints: [[0.9927937 0.00421068 0.00299572]]
 ```

@@ -83,7 +83,7 @@ Given a batch of images and a batch of text tokens, returns two Tensors, contain

 ### Zero-Shot Prediction

-The code below performs of zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset.
+The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset.

 ```python
 import os
@@ -125,11 +125,11 @@ The output will look like the following (the exact numbers may be slightly diffe
 ```
 Top predictions:

-snake: 41.53%
-turtle: 24.04%
-sweet_pepper: 4.18%
-lizard: 3.92%
-leopard: 3.69%
+snake: 65.31%
+turtle: 12.29%
+sweet_pepper: 3.83%
+lizard: 1.88%
+crocodile: 1.75%
 ```

 Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs.
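The zero-shot scoring that produces the "Top predictions" list above follows the README's recipe: L2-normalize the image and label features, take cosine similarities scaled by 100, and softmax them into label probabilities. The sketch below reproduces that arithmetic in plain Python on toy 2-D vectors; the feature values and dimensionality are assumptions standing in for real `encode_image()` / `encode_text()` outputs.

```python
import math


def zero_shot_probs(image_feat, text_feats):
    """Sketch of CLIP-style zero-shot scoring: cosine similarity between an
    L2-normalized image feature and each normalized label feature, scaled by
    100 and softmax-ed into per-label probabilities."""
    def normalize(v):
        n = math.sqrt(sum(x * x for x in v))
        return [x / n for x in v]

    img = normalize(image_feat)
    # Cosine similarity with each label feature, scaled by 100 as in the README.
    sims = [100.0 * sum(a * b for a, b in zip(img, normalize(t))) for t in text_feats]
    # Numerically stable softmax over the similarity scores.
    m = max(sims)
    exps = [math.exp(s - m) for s in sims]
    total = sum(exps)
    return [e / total for e in exps]


# Toy 2-D features standing in for encoder outputs (an assumption).
probs = zero_shot_probs([1.0, 0.1], [[1.0, 0.0], [0.0, 1.0]])
print(probs[0] > 0.99)  # the well-aligned label dominates after the 100x scaling
```

The 100x scale factor is why the probabilities concentrate so sharply on one label: even a modest gap in cosine similarity becomes a large gap in logits.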
clip.py (4 changed lines)
@@ -121,7 +121,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77):
     if isinstance(texts, str):
         texts = [texts]

-    all_tokens = [_tokenizer.encode(text + "<|endoftext|>") for text in texts]
+    sot_token = _tokenizer.encoder["<|startoftext|>"]
+    eot_token = _tokenizer.encoder["<|endoftext|>"]
+    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

     for i, tokens in enumerate(all_tokens):
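The effect of this fix can be illustrated without CLIP's real BPE tokenizer. The sketch below uses a toy vocabulary and a toy `encode` function (both assumptions standing in for CLIP's `_tokenizer`) to show what the patched `tokenize` now does: look up the SOT/EOT ids in the encoder vocabulary, wrap each encoded text in them, and zero-pad every row to `context_length` (the real implementation fills a `torch.long` tensor rather than Python lists).

```python
from typing import Callable, Dict, List, Union


def tokenize_sketch(texts: Union[str, List[str]],
                    encoder: Dict[str, int],
                    encode: Callable[[str], List[int]],
                    context_length: int = 77) -> List[List[int]]:
    """Toy version of the patched clip.tokenize: wrap each encoded text in
    SOT/EOT token ids, then zero-pad each row to context_length."""
    if isinstance(texts, str):
        texts = [texts]
    # Look up the special-token ids in the vocabulary, as the fix does.
    sot_token = encoder["<|startoftext|>"]
    eot_token = encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + encode(text) + [eot_token] for text in texts]
    result = []
    for tokens in all_tokens:
        assert len(tokens) <= context_length, "input too long for context_length"
        # Zero-padding mirrors torch.zeros(..., dtype=torch.long) in the real code.
        result.append(tokens + [0] * (context_length - len(tokens)))
    return result


# Toy whitespace vocabulary standing in for CLIP's BPE encoder (an assumption).
vocab = {"<|startoftext|>": 1, "<|endoftext|>": 2, "a": 3, "cat": 4}
encode = lambda text: [vocab[w] for w in text.split()]

rows = tokenize_sketch("a cat", vocab, encode, context_length=8)
print(rows)  # [[1, 3, 4, 2, 0, 0, 0, 0]]
```

The key difference from the buggy version is that `<|startoftext|>` and `<|endoftext|>` are inserted as their reserved vocabulary ids, rather than being passed through `encode()` as literal text, where the BPE would split them into ordinary subword tokens.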