Fix the instlalation requirements'

This commit is contained in:
Jacky 2021-01-26 17:38:33 +11:00
parent e1b7bac3ec
commit 50d5659d73
3 changed files with 21 additions and 8 deletions

View File

@ -9,8 +9,8 @@ from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm from tqdm import tqdm
from model import build_model from .model import build_model
from simple_tokenizer import SimpleTokenizer as _Tokenizer from .simple_tokenizer import SimpleTokenizer as _Tokenizer
__all__ = ["available_models", "load", "tokenize"] __all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer() _tokenizer = _Tokenizer()
@ -24,7 +24,7 @@ _MODELS = {
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
filename = os.path.basename(url) filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2] expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename) download_target = os.path.join(root, filename)
@ -38,7 +38,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop: with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
while True: while True:
buffer = source.read(8192) buffer = source.read(8192)
if not buffer: if not buffer:

View File

@ -5,7 +5,8 @@ import os
from setuptools import setup, find_packages from setuptools import setup, find_packages
core_req = ['ftfy', 'regex', 'tqdm', 'torch==1.7.1', 'torchvision'] core_req = ['ftfy', 'regex', 'tqdm', 'torch==1.7.1', 'torchvision']
extra_requires={'cuda': ['cudatoolkit==11.0']} extras_require={'cuda': ['cudatoolkit==11.0'],
'dev': ['pytest']}
setup( setup(
name='clip_by_openai', name='clip_by_openai',
@ -20,7 +21,7 @@ setup(
packages=find_packages(exclude=["tests*"]), packages=find_packages(exclude=["tests*"]),
python_requires=">=3", python_requires=">=3",
install_requires=core_req, install_requires=core_req,
extra_requires=extra_requires, extras_require=extras_require,
classifiers=[ classifiers=[
"Development Status :: 5 - Production/Stable", "Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers", "Intended Audience :: Developers",

View File

@ -1,4 +1,16 @@
import clip import clip
import torch import torch
device = 'cpu' import torch
model, preprocess = clip.load("ViT-B/32", device=device) from PIL import Image
def test_simple_cpu():
device = 'cpu'
model, preprocess = clip.load("ViT-B/32", device=device)
image = preprocess(Image.open('CLIP.png')).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
with torhc.no_grad():
assert model.encode_image(image), "Encoding an image does not work"
assert model.encode_text(text), "Encoding text does not work"
logits_per_image, logits_per_text = model(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Label probs:", probs)