Add truncate_text option to tokenize

This makes it possible to run tokenize on texts that encode to more tokens than fit in the
context length, without having to guess beforehand where to cut the string at the
character level.
Romain Beaumont 2021-07-08 17:34:16 +02:00 committed by GitHub
parent cfcffb90e6
commit a2f6647a4e
1 changed file with 5 additions and 2 deletions


@@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
     return model, _transform(model.input_resolution.item())
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_text = False) -> torch.LongTensor:
     """
     Returns the tokenized representation of given input string(s)
 
@@ -187,7 +187,10 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo
     for i, tokens in enumerate(all_tokens):
         if len(tokens) > context_length:
-            raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+            if truncate_text:
+                tokens = tokens[:context_length]
+            else:
+                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
         result[i, :len(tokens)] = torch.tensor(tokens)
     return result
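
For illustration, here is a minimal usage sketch of the new flag. It assumes the standard
clip entry points (clip.load, clip.tokenize) and the "ViT-B/32" checkpoint; the overlong
sample prompt is hypothetical:

    import torch
    import clip

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    # A prompt that encodes to far more than 77 BPE tokens.
    long_text = "a photo of " + "a very, very distant " * 30 + "mountain"

    # Previously this raised RuntimeError; with truncate_text=True the token
    # sequence is instead cut to the first context_length (77) tokens.
    tokens = clip.tokenize(long_text, truncate_text=True).to(device)
    print(tokens.shape)  # torch.Size([1, 77])

    with torch.no_grad():
        text_features = model.encode_text(tokens)

Note that truncation happens at the token level, after BPE encoding, so no character-count
guesswork is needed on the caller's side.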