# 2020-12-17 17:55:12 +01:00
import hashlib
import os
import urllib
import urllib.request
import warnings
from typing import Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from simple_tokenizer import SimpleTokenizer as _Tokenizer
# Public API of this module.
__all__ = ["available_models", "load", "tokenize"]

# Single tokenizer instance shared by every call to tokenize().
_tokenizer = _Tokenizer()

# Model name -> checkpoint URL. _download() reads the second-to-last path
# segment of the URL as the expected SHA256 digest of the file.
_MODELS = {
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
}
def _sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of the file at *path*, hashing it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    """Fetch *url* into *root* unless a verified copy is already there.

    The local filename is the basename of the URL, and the expected SHA256
    digest is the second-to-last path segment of the URL.

    Args:
        url: Location of the file to download.
        root: Directory used as the download cache; created if missing.

    Returns:
        Path of the verified local file.

    Raises:
        RuntimeError: If the target path exists but is not a regular file,
            or if the downloaded file's SHA256 digest does not match.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _sha256_of_file(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Content-Length is optional; tqdm accepts total=None and then shows
        # a running byte count instead of a percentage.
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    # Verify what actually landed on disk, not just what was received.
    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
def available_models():
    """Return the keys of the module-level ``_MODELS`` registry as a list."""
    return list(_MODELS)
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu"):
    """Load a TorchScript model by name and build its image preprocessing transform.

    The checkpoint is fetched through ``_download`` and loaded with
    ``torch.jit.load``. Constants in the scripted graphs are then patched so
    the model runs on the requested device, and in float32 when that device
    is a CPU.

    Args:
        name: A key of ``_MODELS`` (see ``available_models()``).
        device: Device to load the model onto, as a string or ``torch.device``.

    Returns:
        A ``(model, transform)`` tuple: the loaded model in eval mode and a
        torchvision ``Compose`` that resizes, center-crops, converts to RGB,
        converts to a tensor and normalizes an image.

    Raises:
        RuntimeError: If *name* is not a known model.
    """
    if name not in _MODELS:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    model_path = _download(_MODELS[name])
    model = torch.jit.load(model_path, map_location=device).eval()
    n_px = model.input_resolution.item()

    def graphs_of(module):
        # Collect the TorchScript graphs attached to a module or method:
        # its own ``graph`` and, when present, the graph of ``forward1``.
        graphs = [module.graph] if hasattr(module, "graph") else []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)
        return graphs

    # patch the device names
    # Trace a trivial function that moves a tensor to the target device, then
    # take the last constant node whose repr mentions "Device" as a template.
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        # Overwrite every constant whose value starts with "cuda" with the
        # attributes of the template device node.
        for graph in graphs_of(module):
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    # Compare the device *type* so that "cpu", "cpu:0" and torch.device("cpu")
    # are all recognised; a plain `device == "cpu"` misses the latter forms.
    if torch.device(device).type == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            for graph in graphs_of(module):
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        # 5 is the half-precision code in torch's ScalarType enum.
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    transform = Compose([
        Resize(n_px, interpolation=Image.BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

    return model, transform
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False):
    """Encode one or more strings into a zero-padded tensor of token ids.

    Each text is encoded with the module-level tokenizer and wrapped in the
    ``<|startoftext|>`` and ``<|endoftext|>`` tokens.

    Args:
        texts: A single string or a list of strings.
        context_length: Number of token positions per row of the result.
        truncate: If True, an encoding longer than *context_length* is cut to
            fit and its last kept position is set to the end-of-text token.
            If False (the default), such input raises instead.

    Returns:
        A ``torch.long`` tensor of shape ``(len(texts), context_length)``;
        positions past the end of each encoding are zero.

    Raises:
        RuntimeError: If an encoding exceeds *context_length* and *truncate*
            is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]

    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            tokens = tokens[:context_length]
            if tokens:
                # Keep the sequence terminated after cutting it short.
                tokens[-1] = eot_token
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result