From 0848ed71e5bb5b4eaaa4048364001877630345dc Mon Sep 17 00:00:00 2001
From: Nick Doiron
Date: Sun, 7 Mar 2021 01:30:35 -0700
Subject: [PATCH] extendable vocab

---
 clip/clip.py             | 12 ++++++------
 clip/simple_tokenizer.py | 19 +++++++++++++++++++
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/clip/clip.py b/clip/clip.py
index 76f241b..bb85b63 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -10,10 +10,10 @@ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normal
 from tqdm import tqdm
 
 from .model import build_model
-from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+from .simple_tokenizer import SimpleTokenizer
 
 __all__ = ["available_models", "load", "tokenize"]
-_tokenizer = _Tokenizer()
+_tokenizer = SimpleTokenizer()
 
 _MODELS = {
     "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
@@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
     return model, _transform(model.input_resolution.item())
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, tokenizer: SimpleTokenizer = _tokenizer) -> torch.LongTensor:
     """
     Returns the tokenized representation of given input string(s)
 
@@ -180,9 +180,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo
     if isinstance(texts, str):
         texts = [texts]
 
-    sot_token = _tokenizer.encoder["<|startoftext|>"]
-    eot_token = _tokenizer.encoder["<|endoftext|>"]
-    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+    sot_token = tokenizer.encoder["<|startoftext|>"]
+    eot_token = tokenizer.encoder["<|endoftext|>"]
+    all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts]
     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
 
     for i, tokens in enumerate(all_tokens):
diff --git a/clip/simple_tokenizer.py b/clip/simple_tokenizer.py
index 0a66286..d047c1c 100644
--- a/clip/simple_tokenizer.py
+++ b/clip/simple_tokenizer.py
@@ -71,12 +71,31 @@ class SimpleTokenizer(object):
         for merge in merges:
             vocab.append(''.join(merge))
         vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.vocab = vocab
         self.encoder = dict(zip(vocab, range(len(vocab))))
         self.decoder = {v: k for k, v in self.encoder.items()}
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
         self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
 
+    def extend(self, tokens):
+        """Add whole tokens to the vocabulary, keeping existing ids stable.
+
+        ``tokens`` is a single string or any iterable of strings. Unknown
+        tokens get the next free ids; known tokens keep the id they have.
+        """
+        # A bare string is one token, not an iterable of characters; list()
+        # also lets one-shot iterators be walked safely.
+        tokens = [tokens] if isinstance(tokens, str) else list(tokens)
+        for token in tokens:
+            if token not in self.encoder:
+                index = len(self.vocab)
+                self.vocab.append(token)
+                self.encoder[token] = index
+                self.decoder[index] = token
+            # bpe() returns cached entries unchanged, so the token stays whole.
+            self.cache[token] = token
+
     def bpe(self, token):
         if token in self.cache:
             return self.cache[token]