From a2737ac2644f46bb0a9785e4ddd3ad61aec3d468 Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Mon, 19 Jul 2021 05:17:40 +0200 Subject: [PATCH] Add truncate option to tokenize (#126) * Add truncate_text option to tokenize This makes it possible to run tokenize on texts that are longer than the number of tokens that fit the context length without having to try to guess how to cut in number of characters beforehand * add doc, rename to just "truncate", use eot_token Co-authored-by: Jong Wook Kim --- clip/clip.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/clip/clip.py b/clip/clip.py index 55e1433..974ef06 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -180,7 +180,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a return model, _transform(model.input_resolution.item()) -def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: """ Returns the tokenized representation of given input string(s) @@ -192,6 +192,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo context_length : int The context length to use; all CLIP models use 77 as the context length + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + Returns ------- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] @@ -206,7 +209,11 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: - raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") result[i, :len(tokens)] = torch.tensor(tokens) return result