correctly tokenizing SOT/EOT tokens (fixes #8)

Jong Wook Kim 2021-01-09 02:55:09 +09:00
parent c89e0c16de
commit ebd0e35aac
3 changed files with 47 additions and 344 deletions

File diff suppressed because one or more lines are too long


@@ -41,7 +41,7 @@ with torch.no_grad():
     logits_per_image, logits_per_text = model(image, text)
     probs = logits_per_image.softmax(dim=-1).cpu().numpy()
-print("Label probs:", probs) # prints: [[0.9956 0.002144 0.002213]]
+print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]]
 ```
@@ -83,7 +83,7 @@ Given a batch of images and a batch of text tokens, returns two Tensors, contain
 ### Zero-Shot Prediction
-The code below performs of zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset.
+The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset.
 ```python
 import os
@@ -125,11 +125,11 @@ The output will look like the following (the exact numbers may be slightly diffe
 ```
 Top predictions:
-snake: 41.53%
-turtle: 24.04%
-sweet_pepper: 4.18%
-lizard: 3.92%
-leopard: 3.69%
+snake: 65.31%
+turtle: 12.29%
+sweet_pepper: 3.83%
+lizard: 1.88%
+crocodile: 1.75%
 ```
 Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs.
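
For context, those two methods slot into the usual inference flow roughly as sketched below (a minimal sketch, not part of this diff; it assumes a local image file named `image.png`, the publicly released ViT-B/32 weights, and a few made-up candidate captions):

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Preprocess one image and tokenize a few candidate captions
image = preprocess(Image.open("image.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    # encode_image()/encode_text() return one feature vector per input
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

# Cosine similarity between the normalized features ranks the captions
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print("Caption probs:", similarity.cpu().numpy())
```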


@@ -121,7 +121,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77):
     if isinstance(texts, str):
         texts = [texts]
-    all_tokens = [_tokenizer.encode(text + "<|endoftext|>") for text in texts]
+    sot_token = _tokenizer.encoder["<|startoftext|>"]
+    eot_token = _tokenizer.encoder["<|endoftext|>"]
+    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
     for i, tokens in enumerate(all_tokens):
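
A quick way to sanity-check the new behavior from the public API (a minimal sketch, not part of this commit; it only relies on `clip.tokenize()` and the zero padding it produces):

```python
import clip

# tokenize() pads every row with zeros out to context_length (77 by default)
tokens = clip.tokenize(["a photo of a cat", "a diagram"])

for row in tokens:
    ids = row[row != 0].tolist()  # keep only the non-padding token ids
    # With this fix every sequence is wrapped as [SOT] ... [EOT], so the first
    # and last non-padding ids are the same pair for every input
    # (49406 and 49407 with the bundled BPE vocabulary)
    print(ids[0], ids[-1])
```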