Make the repo installable as a package (#26)
This commit is contained in:
parent
578a1d3e2e
commit
3bee28119e
|
@ -0,0 +1,10 @@
|
|||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.egg-info
|
||||
.pytest_cache
|
||||
.ipynb_checkpoints
|
||||
|
||||
thumbs.db
|
||||
.DS_Store
|
||||
.idea
|
|
@ -0,0 +1 @@
|
|||
include clip/bpe_simple_vocab_16e6.txt.gz
|
|
@ -0,0 +1 @@
|
|||
from .clip import *
|
|
@ -9,8 +9,8 @@ from PIL import Image
|
|||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||
from tqdm import tqdm
|
||||
|
||||
from model import build_model
|
||||
from simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
from .model import build_model
|
||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
|
||||
__all__ = ["available_models", "load", "tokenize"]
|
||||
_tokenizer = _Tokenizer()
|
||||
|
@ -24,7 +24,7 @@ _MODELS = {
|
|||
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
||||
os.makedirs(root, exist_ok=True)
|
||||
filename = os.path.basename(url)
|
||||
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, filename)
|
||||
|
||||
|
@ -38,7 +38,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
|||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
|
@ -75,6 +75,8 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
|
||||
if not jit:
|
||||
model = build_model(model.state_dict()).to(device)
|
||||
if str(device) == "cpu":
|
||||
model.float()
|
||||
return model, transform
|
||||
|
||||
# patch the device names
|
||||
|
@ -96,7 +98,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
patch_device(model.encode_text)
|
||||
|
||||
# patch dtype to float32 on CPU
|
||||
if device == "cpu":
|
||||
if str(device) == "cpu":
|
||||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||
float_node = float_input.node()
|
|
@ -0,0 +1,5 @@
|
|||
ftfy
|
||||
regex
|
||||
tqdm
|
||||
torch>=1.7.1,<1.7.2
|
||||
torchvision==0.8.2
|
|
@ -0,0 +1,21 @@
|
|||
import os
|
||||
|
||||
import pkg_resources
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="clip",
|
||||
py_modules=["clip"],
|
||||
version="1.0",
|
||||
description="",
|
||||
author="OpenAI",
|
||||
packages=find_packages(exclude=["tests*"]),
|
||||
install_requires=[
|
||||
str(r)
|
||||
for r in pkg_resources.parse_requirements(
|
||||
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
||||
)
|
||||
],
|
||||
include_package_data=True,
|
||||
extras_require={'dev': ['pytest']},
|
||||
)
|
|
@ -0,0 +1,25 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import clip
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_name', clip.available_models())
|
||||
def test_consistency(model_name):
|
||||
device = "cpu"
|
||||
jit_model, transform = clip.load(model_name, device=device)
|
||||
py_model, _ = clip.load(model_name, device=device, jit=False)
|
||||
|
||||
image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
|
||||
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_per_image, _ = jit_model(image, text)
|
||||
jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
|
||||
logits_per_image, _ = py_model(image, text)
|
||||
py_probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
|
||||
assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)
|
Loading…
Reference in New Issue