extendable vocab
This commit is contained in:
parent
fd6c1443c2
commit
0848ed71e5
12
clip/clip.py
12
clip/clip.py
|
@@ -10,10 +10,10 @@ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normal
|
|||
from tqdm import tqdm
|
||||
|
||||
from .model import build_model
|
||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
from .simple_tokenizer import SimpleTokenizer
|
||||
|
||||
__all__ = ["available_models", "load", "tokenize"]
|
||||
_tokenizer = _Tokenizer()
|
||||
_tokenizer = SimpleTokenizer()
|
||||
|
||||
_MODELS = {
|
||||
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
||||
|
@@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
return model, _transform(model.input_resolution.item())
|
||||
|
||||
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77, tokenizer: SimpleTokenizer = _tokenizer) -> torch.LongTensor:
|
||||
"""
|
||||
Returns the tokenized representation of given input string(s)
|
||||
|
||||
|
@@ -180,9 +180,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo
|
|||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
||||
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
||||
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
||||
sot_token = tokenizer.encoder["<|startoftext|>"]
|
||||
eot_token = tokenizer.encoder["<|endoftext|>"]
|
||||
all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts]
|
||||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
||||
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
|
|
|
@@ -71,12 +71,20 @@ class SimpleTokenizer(object):
|
|||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.vocab = vocab
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
||||
|
||||
def extend(self, tokens):
|
||||
self.vocab.extend(tokens)
|
||||
self.encoder = dict(zip(self.vocab, range(len(self.vocab))))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
for token in tokens:
|
||||
self.cache[token] = token
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
|
|
Loading…
Reference in New Issue