2020-12-17 17:55:12 +01:00
import hashlib
import os
import urllib
import warnings
2021-08-09 08:43:22 +02:00
from typing import Any , Union , List
2021-11-09 07:57:26 +01:00
from pkg_resources import packaging
2020-12-17 17:55:12 +01:00
import torch
from PIL import Image
from torchvision . transforms import Compose , Resize , CenterCrop , ToTensor , Normalize
from tqdm import tqdm
2021-01-29 19:05:01 +01:00
from . model import build_model
from . simple_tokenizer import SimpleTokenizer as _Tokenizer
2020-12-17 17:55:12 +01:00
2021-07-19 03:45:21 +02:00
# Bicubic interpolation flag passed to torchvision's Resize: use the
# InterpolationMode enum when this torchvision provides it, otherwise fall back
# to the PIL constant.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

# Older PyTorch is not rejected, only warned about.
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
# Module-wide tokenizer instance shared by every tokenize() call.
_tokenizer = _Tokenizer()

# Model name -> checkpoint URL. The second-to-last path component of each URL is
# the checkpoint's expected SHA256 hex digest; _download() extracts it from there.
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}
2021-08-09 08:43:22 +02:00
def _sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of the file at `path`.

    The file is read in `chunk_size`-byte pieces so a large checkpoint is never
    held in memory all at once, and the `with` block guarantees the handle is
    closed.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, root: str):
    """Download `url` into directory `root` and verify its SHA256 checksum.

    The expected digest is the second-to-last path component of `url`. If a
    regular file with the URL's basename already exists in `root` and its digest
    matches, it is reused without downloading; otherwise it is (re)downloaded.

    Parameters
    ----------
    url : str
        Source URL whose second-to-last path segment is the expected SHA256 hex digest.
    root : str
        Directory to store the file in; created if it does not exist.

    Returns
    -------
    str
        Path of the verified file (`root` joined with the URL's basename).

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file, or if the
        downloaded file's digest does not match the expected one.
    """
    # `import urllib` alone does not guarantee the `urllib.request` submodule is
    # loaded; import it explicitly so `urlopen` is always available.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _sha256_of_file(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # The server may omit Content-Length; pass total=None (unknown size) to
        # the progress bar instead of crashing on int(None).
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
2021-09-24 03:42:20 +02:00
def _convert_image_to_rgb(image):
    """Return `image` converted to the "RGB" mode.

    Kept as a named module-level function (rather than an inline lambda) so it
    can be used as a step of the `_transform` pipeline.
    """
    rgb_image = image.convert("RGB")
    return rgb_image
2021-02-16 12:19:42 +01:00
def _transform(n_px):
    """Build the preprocessing pipeline for a model whose input size is `n_px`.

    Steps, in order: bicubic `Resize(n_px)`, `CenterCrop(n_px)`, conversion to
    RGB, `ToTensor()`, then `Normalize` with the fixed per-channel mean/std below.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    pipeline_steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(pipeline_steps)
2021-02-16 12:19:42 +01:00
def available_models() -> List[str]:
    """Return the names of the CLIP models that `load` can fetch by name."""
    model_names = [model_name for model_name in _MODELS]
    return model_names
2021-08-09 08:43:22 +02:00
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: Union[str, None] = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input

    Raises
    ------
    RuntimeError
        If `name` is neither a known model name nor the path of an existing file,
        or if downloading a known model fails its checksum verification.
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Rebuild a regular nn.Module from the weights (taken from the JIT
        # archive when that load succeeded) and move it to the target device.
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    # Trace a trivial function on the requested device to obtain a
    # prim::Constant node holding that device; its attributes are copied over
    # every constant in the loaded graphs whose value starts with "cuda".
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def _node_get(node, key):
        """Return attribute `key` of a TorchScript graph node.

        `node.kindOf(key)` names the attribute's kind (e.g. "s", "i"), which is
        also the name of the matching typed accessor on the node. Dispatching
        through it avoids `node[key]` subscripting, which newer PyTorch
        releases no longer provide.
        """
        sel = node.kindOf(key)
        return getattr(node, sel)(key)

    def patch_device(module):
        # Some modules raise RuntimeError when their `graph` is accessed; treat
        # those as having no graph to patch.
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        # The dtype constant feeding aten::to in a traced `.float()` call is the
        # replacement value copied into the loaded graphs below.
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        # NOTE(review): 5 is presumably the ScalarType code for
                        # float16 — confirm; matching constants become float32.
                        if _node_get(inputs[i].node(), "value") == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())
2021-07-19 05:17:40 +02:00
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Convert one string or a list of strings into a padded tensor of token ids.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single input string, or a list of input strings.

    context_length : int
        Number of token slots per row; all CLIP models use 77.

    truncate : bool
        When True, an encoding longer than `context_length` is cut to fit and
        its last slot is overwritten with the end-of-text token. When False,
        such an input raises instead.

    Returns
    -------
    A two-dimensional tensor of shape [number of input strings, context_length].
    Each row holds the start-of-text id, the encoded text, the end-of-text id,
    and then zero padding.

    Raises
    ------
    RuntimeError
        If an encoding exceeds `context_length` and `truncate` is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    start_id = _tokenizer.encoder["<|startoftext|>"]
    end_id = _tokenizer.encoder["<|endoftext|>"]
    # Encode every input first, wrapping each one in start/end markers.
    encoded = [[start_id] + _tokenizer.encode(text) + [end_id] for text in texts]

    result = torch.zeros(len(encoded), context_length, dtype=torch.long)
    for row, ids in enumerate(encoded):
        if len(ids) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[row]} is too long for context length {context_length}")
            # Keep the row terminated even after cutting it short.
            ids = ids[:context_length]
            ids[-1] = end_id
        result[row, :len(ids)] = torch.tensor(ids)
    return result