add doc, rename to just "truncate", use eot_token

This commit is contained in:
Jong Wook Kim 2021-07-18 20:14:47 -07:00
parent a2f6647a4e
commit 624f90e132
1 changed file with 6 additions and 2 deletions

View File

@@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
return model, _transform(model.input_resolution.item())
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_text = False) -> torch.LongTensor:
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
@@ -173,6 +173,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_te
context_length : int
The context length to use; all CLIP models use 77 as the context length
truncate: bool
Whether to truncate the text in case its encoding is longer than the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
@@ -187,8 +190,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_te
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate_text:
if truncate:
tokens = tokens[:context_length]
tokens[-1] = eot_token
else:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, :len(tokens)] = torch.tensor(tokens)