Make the repo installable as a package (#26)

This commit is contained in:
boba_and_beer 2021-01-30 05:05:01 +11:00 committed by GitHub
parent 578a1d3e2e
commit 3bee28119e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 70 additions and 5 deletions

10
.gitignore vendored Normal file
View File

@@ -0,0 +1,10 @@
__pycache__/
*.py[cod]
*$py.class
*.egg-info
.pytest_cache
.ipynb_checkpoints
thumbs.db
.DS_Store
.idea

1
MANIFEST.in Normal file
View File

@@ -0,0 +1 @@
include clip/bpe_simple_vocab_16e6.txt.gz

1
clip/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .clip import *

View File

@@ -9,8 +9,8 @@ from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm from tqdm import tqdm
from model import build_model from .model import build_model
from simple_tokenizer import SimpleTokenizer as _Tokenizer from .simple_tokenizer import SimpleTokenizer as _Tokenizer
__all__ = ["available_models", "load", "tokenize"] __all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer() _tokenizer = _Tokenizer()
@@ -24,7 +24,7 @@ _MODELS = {
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
filename = os.path.basename(url) filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2] expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename) download_target = os.path.join(root, filename)
@@ -38,7 +38,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop: with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
while True: while True:
buffer = source.read(8192) buffer = source.read(8192)
if not buffer: if not buffer:
@@ -75,6 +75,8 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
if not jit: if not jit:
model = build_model(model.state_dict()).to(device) model = build_model(model.state_dict()).to(device)
if str(device) == "cpu":
model.float()
return model, transform return model, transform
# patch the device names # patch the device names
@@ -96,7 +98,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
patch_device(model.encode_text) patch_device(model.encode_text)
# patch dtype to float32 on CPU # patch dtype to float32 on CPU
if device == "cpu": if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node() float_node = float_input.node()

5
requirements.txt Normal file
View File

@@ -0,0 +1,5 @@
ftfy
regex
tqdm
torch>=1.7.1,<1.7.2
torchvision==0.8.2

21
setup.py Normal file
View File

@@ -0,0 +1,21 @@
"""Packaging script: installs the ``clip`` package with its pinned requirements."""
import os

import pkg_resources
from setuptools import setup, find_packages


def _read_requirements():
    """Return the requirement specifiers listed in ``requirements.txt``.

    The file is looked up next to this script and each parsed requirement is
    converted back to its string form, as expected by ``install_requires``.
    """
    requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
    # parse_requirements() is lazy, so the list must be fully built while the
    # file is still open; the ``with`` block then closes the handle instead of
    # leaving it to the garbage collector.
    with open(requirements_path) as requirements_file:
        return [str(r) for r in pkg_resources.parse_requirements(requirements_file)]


setup(
    name="clip",
    py_modules=["clip"],
    version="1.0",
    description="",
    author="OpenAI",
    packages=find_packages(exclude=["tests*"]),
    install_requires=_read_requirements(),
    # Ship non-Python files declared in MANIFEST.in along with the package.
    include_package_data=True,
    extras_require={'dev': ['pytest']},
)

25
tests/test_consistency.py Normal file
View File

@@ -0,0 +1,25 @@
import numpy as np
import pytest
import torch
from PIL import Image
import clip
@pytest.mark.parametrize('model_name', clip.available_models())
def test_consistency(model_name):
    """Check that both load modes of a model give matching probabilities.

    The model is loaded twice on CPU, once with the default ``clip.load``
    arguments and once with ``jit=False``. Both are run on the same image
    and text prompts, and their softmaxed image logits are compared within
    a loose tolerance.
    """
    device = "cpu"
    jit_model, transform = clip.load(model_name, device=device)
    py_model, _ = clip.load(model_name, device=device, jit=False)

    # Build one batch of inputs that is fed to both models.
    picture = Image.open("CLIP.png")
    image = transform(picture).unsqueeze(0).to(device)
    prompts = ["a diagram", "a dog", "a cat"]
    text = clip.tokenize(prompts).to(device)

    def image_probabilities(model):
        # The model returns a pair; only the first element (per-image logits)
        # is used here.
        logits_per_image, _ = model(image, text)
        return logits_per_image.softmax(dim=-1).cpu().numpy()

    with torch.no_grad():
        jit_probs = image_probabilities(jit_model)
        py_probs = image_probabilities(py_model)

    assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)