From a2f6647a4edb63dbc1693b901d4fbacecd00554c Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Thu, 8 Jul 2021 17:34:16 +0200 Subject: [PATCH] Add truncate_text option to tokenize This makes it possible to run tokenize on texts that are longer than the number of tokens that fit the context length without having to try to guess how to cut in number of characters beforehand --- clip/clip.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/clip/clip.py b/clip/clip.py index 76f241b..cf32c93 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a return model, _transform(model.input_resolution.item()) -def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_text = False) -> torch.LongTensor: """ Returns the tokenized representation of given input string(s) @@ -187,7 +187,10 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: - raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + if truncate_text: + tokens = tokens[:context_length] + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") result[i, :len(tokens)] = torch.tensor(tokens) return result