add doc, rename to just "truncate", use eot_token
This commit is contained in:
parent
a2f6647a4e
commit
624f90e132
|
@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
return model, _transform(model.input_resolution.item())
|
||||
|
||||
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_text = False) -> torch.LongTensor:
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
|
||||
"""
|
||||
Returns the tokenized representation of given input string(s)
|
||||
|
||||
|
@ -173,6 +173,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_te
|
|||
context_length : int
|
||||
The context length to use; all CLIP models use 77 as the context length
|
||||
|
||||
truncate: bool
|
||||
Whether to truncate the text in case its encoding is longer than the context length
|
||||
|
||||
Returns
|
||||
-------
|
||||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
||||
|
@ -187,8 +190,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_te
|
|||
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
if len(tokens) > context_length:
|
||||
if truncate_text:
|
||||
if truncate:
|
||||
tokens = tokens[:context_length]
|
||||
tokens[-1] = eot_token
|
||||
else:
|
||||
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||
|
|
Loading…
Reference in New Issue