Make the repo installable as a package (#26)
This commit is contained in:
parent
578a1d3e2e
commit
3bee28119e
|
@ -0,0 +1,10 @@
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.egg-info
|
||||||
|
.pytest_cache
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
thumbs.db
|
||||||
|
.DS_Store
|
||||||
|
.idea
|
|
@ -0,0 +1 @@
|
||||||
|
include clip/bpe_simple_vocab_16e6.txt.gz
|
|
@ -0,0 +1 @@
|
||||||
|
from .clip import *
|
|
@ -9,8 +9,8 @@ from PIL import Image
|
||||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from model import build_model
|
from .model import build_model
|
||||||
from simple_tokenizer import SimpleTokenizer as _Tokenizer
|
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||||
|
|
||||||
__all__ = ["available_models", "load", "tokenize"]
|
__all__ = ["available_models", "load", "tokenize"]
|
||||||
_tokenizer = _Tokenizer()
|
_tokenizer = _Tokenizer()
|
||||||
|
@ -24,7 +24,7 @@ _MODELS = {
|
||||||
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
||||||
os.makedirs(root, exist_ok=True)
|
os.makedirs(root, exist_ok=True)
|
||||||
filename = os.path.basename(url)
|
filename = os.path.basename(url)
|
||||||
|
|
||||||
expected_sha256 = url.split("/")[-2]
|
expected_sha256 = url.split("/")[-2]
|
||||||
download_target = os.path.join(root, filename)
|
download_target = os.path.join(root, filename)
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
||||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||||
|
|
||||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
|
with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
|
||||||
while True:
|
while True:
|
||||||
buffer = source.read(8192)
|
buffer = source.read(8192)
|
||||||
if not buffer:
|
if not buffer:
|
||||||
|
@ -75,6 +75,8 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||||
|
|
||||||
if not jit:
|
if not jit:
|
||||||
model = build_model(model.state_dict()).to(device)
|
model = build_model(model.state_dict()).to(device)
|
||||||
|
if str(device) == "cpu":
|
||||||
|
model.float()
|
||||||
return model, transform
|
return model, transform
|
||||||
|
|
||||||
# patch the device names
|
# patch the device names
|
||||||
|
@ -96,7 +98,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||||
patch_device(model.encode_text)
|
patch_device(model.encode_text)
|
||||||
|
|
||||||
# patch dtype to float32 on CPU
|
# patch dtype to float32 on CPU
|
||||||
if device == "cpu":
|
if str(device) == "cpu":
|
||||||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||||
float_node = float_input.node()
|
float_node = float_input.node()
|
|
@ -0,0 +1,5 @@
|
||||||
|
ftfy
|
||||||
|
regex
|
||||||
|
tqdm
|
||||||
|
torch>=1.7.1,<1.7.2
|
||||||
|
torchvision==0.8.2
|
|
@ -0,0 +1,21 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pkg_resources
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
# Collect the runtime dependencies from the requirements.txt that lives in the
# same directory as this file. Reading inside a ``with`` block guarantees the
# file handle is closed once the requirement strings have been built, instead
# of leaving an open handle to be reclaimed by the garbage collector.
with open(os.path.join(os.path.dirname(__file__), "requirements.txt")) as requirements_file:
    install_requires = [
        str(requirement)
        for requirement in pkg_resources.parse_requirements(requirements_file)
    ]

setup(
    name="clip",
    # NOTE(review): py_modules refers to a top-level ``clip.py`` module, while
    # ``clip`` appears to be a package that find_packages() already collects.
    # This entry looks redundant -- confirm before removing it.
    py_modules=["clip"],
    version="1.0",
    description="",
    author="OpenAI",
    # Ship every package found next to this file except the test suite.
    packages=find_packages(exclude=["tests*"]),
    install_requires=install_requires,
    # Include non-Python files declared in MANIFEST.in in the distribution.
    include_package_data=True,
    # ``pip install .[dev]`` additionally installs the test runner.
    extras_require={'dev': ['pytest']},
)
|
|
@ -0,0 +1,25 @@
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import clip
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('model_name', clip.available_models())
def test_consistency(model_name):
    """Compare the default-loaded model against the ``jit=False`` model on CPU.

    Both variants are run on the same image and the same three text prompts;
    the softmax of their per-image logits must agree within a loose tolerance.
    """
    device = "cpu"

    def image_probabilities(model, image_batch, token_batch):
        # The model returns a pair; only the first element (per-image logits)
        # is used. Gradients are disabled for the whole forward pass.
        with torch.no_grad():
            per_image_logits, _ = model(image_batch, token_batch)
            return per_image_logits.softmax(dim=-1).cpu().numpy()

    # Load both variants first; the preprocessing transform is taken from the
    # default load and the one returned by the jit=False load is discarded.
    jit_model, transform = clip.load(model_name, device=device)
    py_model, _ = clip.load(model_name, device=device, jit=False)

    # "CLIP.png" is a relative path, so it resolves against the working
    # directory the tests are run from.
    picture = Image.open("CLIP.png")
    image = transform(picture).unsqueeze(0).to(device)
    text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

    jit_probs = image_probabilities(jit_model, image, text)
    py_probs = image_probabilities(py_model, image, text)

    assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)
|
Loading…
Reference in New Issue