From 50d5659d73b60bbfe64d692dbe70a9d9cc1725eb Mon Sep 17 00:00:00 2001 From: Jacky Date: Tue, 26 Jan 2021 17:38:33 +1100 Subject: [PATCH] Fix the installation requirements --- clip/clip.py | 8 ++++---- setup.py | 5 +++-- tests/test_encode.py | 16 ++++++++++++++-- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/clip/clip.py b/clip/clip.py index 87a70c9..7459ac4 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -9,8 +9,8 @@ from PIL import Image from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize from tqdm import tqdm -from model import build_model -from simple_tokenizer import SimpleTokenizer as _Tokenizer +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer __all__ = ["available_models", "load", "tokenize"] _tokenizer = _Tokenizer() @@ -24,7 +24,7 @@ _MODELS = { def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): os.makedirs(root, exist_ok=True) filename = os.path.basename(url) - + expected_sha256 = url.split("/")[-2] download_target = os.path.join(root, filename) @@ -38,7 +38,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop: while True: buffer = source.read(8192) if not buffer: diff --git a/setup.py b/setup.py index f71ca81..8cd66ce 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,8 @@ import os from setuptools import setup, find_packages core_req = ['ftfy', 'regex', 'tqdm', 'torch==1.7.1', 'torchvision'] -extra_requires={'cuda': ['cudatoolkit==11.0']} +extras_require={'cuda': ['cudatoolkit==11.0'], + 'dev': ['pytest']} setup( name='clip_by_openai', @@ -20,7 +21,7 @@ setup( 
packages=find_packages(exclude=["tests*"]), python_requires=">=3", install_requires=core_req, - extra_requires=extra_requires, + extras_require=extras_require, classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", diff --git a/tests/test_encode.py b/tests/test_encode.py index 2aa783c..a6b3af9 100644 --- a/tests/test_encode.py +++ b/tests/test_encode.py @@ -1,4 +1,16 @@ import clip import torch -device = 'cpu' -model, preprocess = clip.load("ViT-B/32", device=device) +import torch +from PIL import Image +def test_simple_cpu(): + device = 'cpu' + model, preprocess = clip.load("ViT-B/32", device=device) + image = preprocess(Image.open('CLIP.png')).unsqueeze(0).to(device) + text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) + with torch.no_grad(): + assert model.encode_image(image), "Encoding an image does not work" + assert model.encode_text(text), "Encoding text does not work" + logits_per_image, logits_per_text = model(image, text) + probs = logits_per_image.softmax(dim=-1).cpu().numpy() + print("Label probs:", probs) +