From 624f90e1321bf2ecd4585ec4471865529a289981 Mon Sep 17 00:00:00 2001
From: Jong Wook Kim
Date: Sun, 18 Jul 2021 20:14:47 -0700
Subject: [PATCH] add doc, rename to just "truncate", use eot_token

---
 clip/clip.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/clip/clip.py b/clip/clip.py
index cf32c93..cd6f244 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
     return model, _transform(model.input_resolution.item())
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_text = False) -> torch.LongTensor:
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
     """
     Returns the tokenized representation of given input string(s)
 
@@ -173,6 +173,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_te
     context_length : int
         The context length to use; all CLIP models use 77 as the context length
 
+    truncate: bool
+        Whether to truncate the text in case its encoding is longer than the context length
+
     Returns
     -------
     A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
@@ -187,8 +190,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_te
 
     for i, tokens in enumerate(all_tokens):
         if len(tokens) > context_length:
-            if truncate_text:
+            if truncate:
                 tokens = tokens[:context_length]
+                tokens[-1] = eot_token
             else:
                 raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
         result[i, :len(tokens)] = torch.tensor(tokens)