add doc, rename to just "truncate", use eot_token
This commit is contained in:
parent
a2f6647a4e
commit
624f90e132
@ -161,7 +161,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||||||
return model, _transform(model.input_resolution.item())
|
return model, _transform(model.input_resolution.item())
|
||||||
|
|
||||||
|
|
||||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_text = False) -> torch.LongTensor:
|
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
|
||||||
"""
|
"""
|
||||||
Returns the tokenized representation of given input string(s)
|
Returns the tokenized representation of given input string(s)
|
||||||
|
|
||||||
@ -173,6 +173,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_te
|
|||||||
context_length : int
|
context_length : int
|
||||||
The context length to use; all CLIP models use 77 as the context length
|
The context length to use; all CLIP models use 77 as the context length
|
||||||
|
|
||||||
|
truncate: bool
|
||||||
|
Whether to truncate the text in case its encoding is longer than the context length
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
||||||
@ -187,8 +190,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate_te
|
|||||||
|
|
||||||
for i, tokens in enumerate(all_tokens):
|
for i, tokens in enumerate(all_tokens):
|
||||||
if len(tokens) > context_length:
|
if len(tokens) > context_length:
|
||||||
if truncate_text:
|
if truncate:
|
||||||
tokens = tokens[:context_length]
|
tokens = tokens[:context_length]
|
||||||
|
tokens[-1] = eot_token
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
||||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||||
|
Loading…
Reference in New Issue
Block a user