correctly tokenizing SOT/EOT tokens (fixes #8)

File diff suppressed because one or more lines are too long.

README.md (14 changed lines)

@@ -41,7 +41,7 @@ with torch.no_grad():
     logits_per_image, logits_per_text = model(image, text)
     probs = logits_per_image.softmax(dim=-1).cpu().numpy()
 
-print("Label probs:", probs)  # prints: [[0.9956   0.002144 0.002213]]
+print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]
 ```
 
 
@@ -83,7 +83,7 @@ Given a batch of images and a batch of text tokens, returns two Tensors, contain
 
 ### Zero-Shot Prediction
 
-The code below performs of zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset.
+The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset.
 
 ```python
 import os
@@ -125,11 +125,11 @@ The output will look like the following (the exact numbers may be slightly diffe
 ```
 Top predictions:
 
-           snake: 41.53%
-          turtle: 24.04%
-    sweet_pepper: 4.18%
-          lizard: 3.92%
-         leopard: 3.69%
+           snake: 65.31%
+          turtle: 12.29%
+    sweet_pepper: 3.83%
+          lizard: 1.88%
+       crocodile: 1.75%
 ```
 
 Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs.
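
For readers following the README hunks above: the zero-shot example they refer to compares image and text features obtained from `encode_image()` and `encode_text()`. Below is a minimal sketch of that flow. The image path, label list, and prompt template are illustrative assumptions; the actual README example draws its image and label set from CIFAR-100.

```python
import torch
import clip  # assumes clip.py from this repository is importable
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Hypothetical inputs: one image and a handful of candidate labels.
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
labels = ["snake", "turtle", "sweet_pepper", "lizard", "crocodile"]
text = clip.tokenize([f"a photo of a {label}" for label in labels]).to(device)

with torch.no_grad():
    # encode_image()/encode_text() return the feature vectors for each input.
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

# Cosine similarity via normalized dot products, softmaxed over the labels.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

values, indices = similarity[0].topk(len(labels))
for value, index in zip(values, indices):
    print(f"{labels[index]:>16s}: {100 * value.item():.2f}%")
```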

clip.py (4 changed lines)

@@ -121,7 +121,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77):
     if isinstance(texts, str):
         texts = [texts]
 
-    all_tokens = [_tokenizer.encode(text + "<|endoftext|>") for text in texts]
+    sot_token = _tokenizer.encoder["<|startoftext|>"]
+    eot_token = _tokenizer.encoder["<|endoftext|>"]
+    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
 
     for i, tokens in enumerate(all_tokens):
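
A quick sanity check of what the clip.py fix changes: every tokenized row should now begin with the `<|startoftext|>` id and end with the `<|endoftext|>` id before the zero padding. This is a hedged sketch; it assumes clip.py from this repository is importable and reaches into the module-level `_tokenizer` that the diff itself uses.

```python
import clip  # assumes clip.py from this repository is importable

# Token ids for the special tokens, looked up the same way the fix does.
sot_token = clip._tokenizer.encoder["<|startoftext|>"]
eot_token = clip._tokenizer.encoder["<|endoftext|>"]

tokens = clip.tokenize(["a diagram", "a dog", "a cat"])  # LongTensor of shape (3, 77)

for row in tokens:
    content = row[row != 0]                # drop the zero padding at the end
    assert content[0].item() == sot_token  # row starts with <|startoftext|>
    assert content[-1].item() == eot_token # row ends with <|endoftext|>

print("every sequence is bracketed by SOT/EOT")
```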