Add truncate_text option to tokenize
This makes it possible to run tokenize on texts whose token count exceeds the context length, without having to guess beforehand where to cut them by character count.
This commit is contained in:
parent
cfcffb90e6
commit
a2f6647a4e
|
@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
return model, _transform(model.input_resolution.item())
|
||||
|
||||
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_text = False) -> torch.LongTensor:
|
||||
"""
|
||||
Returns the tokenized representation of given input string(s)
|
||||
|
||||
|
@ -187,7 +187,10 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo
|
|||
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
if len(tokens) > context_length:
|
||||
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
||||
if truncate_text:
|
||||
tokens = tokens[:context_length]
|
||||
else:
|
||||
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||
|
||||
return result
|
||||
|
|
Loading…
Reference in New Issue