Make the repo installable as a package (#26)

boba_and_beer 2021-01-30 05:05:01 +11:00 committed by GitHub
parent 578a1d3e2e
commit 3bee28119e
10 changed files with 70 additions and 5 deletions

.gitignore (vendored, new file, +10 lines)

@@ -0,0 +1,10 @@
__pycache__/
*.py[cod]
*$py.class
*.egg-info
.pytest_cache
.ipynb_checkpoints
thumbs.db
.DS_Store
.idea

MANIFEST.in (new file, +1 line)

@@ -0,0 +1 @@
include clip/bpe_simple_vocab_16e6.txt.gz
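MANIFEST.in, together with include_package_data=True in setup.py below, is what ships the gzipped BPE vocabulary inside the installed package. A minimal sketch of how the tokenizer can then find the bundled file at runtime, assuming it is resolved relative to the module's own location (the helper name default_bpe is an assumption, not shown in this diff):

import gzip
import os

def default_bpe():
    # the vocab file sits next to the tokenizer module inside the installed package
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")

# illustrative read of the merge rules; the real tokenizer does further processing
merges = gzip.open(default_bpe()).read().decode("utf-8").split("\n")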

clip/__init__.py (new file, +1 line)

@@ -0,0 +1 @@
from .clip import *
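With clip/__init__.py re-exporting everything from clip.clip, typical post-install usage would look roughly like the sketch below (model name, device, and image file are illustrative choices, not part of this diff):

import torch
import clip
from PIL import Image

print(clip.available_models())                      # e.g. ['RN50', 'ViT-B/32']
model, preprocess = clip.load("ViT-B/32", device="cpu", jit=False)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0)
text = clip.tokenize(["a diagram", "a dog", "a cat"])
with torch.no_grad():
    logits_per_image, _ = model(image, text)
    probs = logits_per_image.softmax(dim=-1).numpy()   # probabilities over the three captions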

clip/clip.py (modified)

@@ -9,8 +9,8 @@ from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

-from model import build_model
-from simple_tokenizer import SimpleTokenizer as _Tokenizer
+from .model import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
@@ -24,7 +24,7 @@ _MODELS = {
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)
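The expected SHA256 is taken from the download URL itself, where the hash appears as the second-to-last path component. The hunk does not show the verification step, but a typical check against the downloaded file would look like this sketch (hashlib only; the exact control flow inside _download is assumed, not shown here):

import hashlib

def sha256_of(path: str) -> str:
    # hash the file in chunks so large checkpoints need not fit in memory at once
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest()

# e.g. if sha256_of(download_target) != expected_sha256, re-download or raise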
@@ -38,7 +38,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
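The hunk is cut off inside the download loop; the general pattern in use is a chunked read with a tqdm progress bar, roughly as in the sketch below (not the verbatim remainder of _download):

while True:
    buffer = source.read(8192)
    if not buffer:
        break
    output.write(buffer)        # append the chunk to the partially downloaded file
    loop.update(len(buffer))    # advance the progress bar by the bytes just written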
@@ -75,6 +75,8 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
    if not jit:
        model = build_model(model.state_dict()).to(device)
+        if str(device) == "cpu":
+            model.float()
        return model, transform

    # patch the device names
@@ -96,7 +98,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
-    if device == "cpu":
+    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()
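The switch to str(device) matters because load() accepts either a string or a torch.device instance; with a torch.device, the plain comparison device == "cpu" may evaluate to False (depending on the PyTorch version) and the float32 conversion would be skipped. A small sketch of the intended behaviour on a CPU-only machine (the model name is illustrative):

import torch
import clip

device = torch.device("cpu")     # a torch.device, not the string "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# str(torch.device("cpu")) == "cpu" always holds, so the checkpoint loaded in
# half precision is converted to float32 for CPU inference
assert next(model.parameters()).dtype == torch.float32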

requirements.txt (new file, +5 lines)

@@ -0,0 +1,5 @@
ftfy
regex
tqdm
torch>=1.7.1,<1.7.2
torchvision==0.8.2

setup.py (new file, +21 lines)

@@ -0,0 +1,21 @@
import os

import pkg_resources
from setuptools import setup, find_packages

setup(
    name="clip",
    py_modules=["clip"],
    version="1.0",
    description="",
    author="OpenAI",
    packages=find_packages(exclude=["tests*"]),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ],
    include_package_data=True,
    extras_require={'dev': ['pytest']},
)
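Because install_requires is built by parsing requirements.txt at build time, the resolved dependency list can be previewed with the same pkg_resources call (a sketch, run from the repository root):

import pkg_resources

with open("requirements.txt") as f:
    requirements = [str(r) for r in pkg_resources.parse_requirements(f)]
print(requirements)   # ['ftfy', 'regex', 'tqdm', 'torch>=1.7.1,<1.7.2', 'torchvision==0.8.2']

In practice the package would then be installed with pip install . from the repository root, or pip install -e .[dev] to also pull in pytest via the dev extra.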

tests/test_consistency.py (new file, +25 lines)

@@ -0,0 +1,25 @@
import numpy as np
import pytest
import torch
from PIL import Image

import clip


@pytest.mark.parametrize('model_name', clip.available_models())
def test_consistency(model_name):
    device = "cpu"
    jit_model, transform = clip.load(model_name, device=device)
    py_model, _ = clip.load(model_name, device=device, jit=False)

    image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
    text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

    with torch.no_grad():
        logits_per_image, _ = jit_model(image, text)
        jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy()

        logits_per_image, _ = py_model(image, text)
        py_probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)
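The test loads every released model twice, once as TorchScript (the default jit=True) and once as a plain PyTorch module, and asserts that their softmax probabilities agree within tolerance. With the dev extra installed it would normally be run with pytest from the repository root, where CLIP.png lives; a programmatic equivalent is sketched below (note that each model is downloaded on first use, so network access is required):

import pytest

# same as running `pytest -q tests/test_consistency.py` from the repo root
raise SystemExit(pytest.main(["-q", "tests/test_consistency.py"]))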