diff --git a/clip/clip.py b/clip/clip.py
index 76f241b..cf32c93 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
     return model, _transform(model.input_resolution.item())
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_text: bool = False) -> torch.LongTensor:
     """
     Returns the tokenized representation of given input string(s)
 
@@ -187,7 +187,13 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo
 
     for i, tokens in enumerate(all_tokens):
         if len(tokens) > context_length:
-            raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+            if truncate_text:
+                # Truncate to fit, but keep the sequence's final token in the
+                # last slot (presumably the end-of-text marker appended by the
+                # tokenizer — the text encoder pools at that position, so
+                # dropping it would corrupt truncated embeddings; confirm
+                # against the tokenizer setup earlier in this function).
+                tokens = tokens[:context_length - 1] + tokens[-1:]
+            else:
+                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
         result[i, :len(tokens)] = torch.tensor(tokens)
 
     return result