extendable vocab

This commit is contained in:
Nick Doiron 2021-03-07 01:30:35 -07:00
parent fd6c1443c2
commit 0848ed71e5
2 changed files with 14 additions and 6 deletions

View File

@@ -10,10 +10,10 @@ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normal
from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from .simple_tokenizer import SimpleTokenizer
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
_tokenizer = SimpleTokenizer()
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
@@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
return model, _transform(model.input_resolution.item())
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
def tokenize(texts: Union[str, List[str]], context_length: int = 77, tokenizer: SimpleTokenizer = _tokenizer) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
@@ -180,9 +180,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
sot_token = tokenizer.encoder["<|startoftext|>"]
eot_token = tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):

View File

@@ -71,12 +71,20 @@ class SimpleTokenizer(object):
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.vocab = vocab
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
def extend(self, tokens):
self.vocab.extend(tokens)
self.encoder = dict(zip(self.vocab, range(len(self.vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
for token in tokens:
self.cache[token] = token
def bpe(self, token):
if token in self.cache:
return self.cache[token]