diff --git a/demo.py b/demo.py
index 06620d0..772fc91 100644
--- a/demo.py
+++ b/demo.py
@@ -6,14 +6,16 @@ import gradio as gr
 
 # Load the model
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model, transform = clip.load("ViT-B/32", device=device)
+model, preprocess = clip.load("ViT-B/32", device)
 
 # Download the dataset
 cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
 
-def classify(img):
-    image = transform(img).unsqueeze(0).to(device)
-    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
+def classify(img, user_text):
+    image = preprocess(img).unsqueeze(0).to(device)
+    user_texts = [t.strip() for t in user_text.split(",") if t.strip()]  # split on commas, drop blanks
+    text_sources = cifar100.classes + user_texts
+    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in text_sources]).to(device)
 
     # Calculate features
     with torch.no_grad():
@@ -26,14 +28,16 @@ def classify(img):
     similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
     values, indices = similarity[0].topk(5)
 
-    text=""
-    # Print the result
+    result = {}
     for value, index in zip(values, indices):
-        text+=f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%\n"
-    return text
+        result[text_sources[index]] = value.item()
+    return result
 
-inputs = gr.inputs.Image(type='pil', label="Original Image")
-outputs = gr.outputs.Textbox(type="str", label="Text Output")
+inputs = [
+    gr.inputs.Image(type='pil', label="Original Image"),
+    gr.inputs.Textbox(lines=1, label="Custom labels (comma-separated)")
+]
+outputs = gr.outputs.Label(type="confidences", num_top_classes=5)
 
 title = "CLIP"
 description = "CLIP demo"
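For reference, here is a minimal sketch of how the patched pieces wire together. The `gr.Interface` call at the bottom of demo.py sits outside both hunks, so this reconstruction is an assumption based on the standard Gradio 2.x pattern that the `gr.inputs`/`gr.outputs` usage above implies:

interface = gr.Interface(
    fn=classify,        # now takes (img, user_text) per the patch above
    inputs=inputs,      # [Image, Textbox]
    outputs=outputs,    # Label widget renders the returned {label: confidence} dict
    title=title,
    description=description,
)
interface.launch()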
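The returned dict can also be sanity-checked without launching the UI; because `classify` now returns `{label: confidence}` instead of a formatted string, it is easy to call directly. `sample.jpg` and the label strings below are hypothetical test values, not part of the patch:

from PIL import Image

img = Image.open("sample.jpg")  # hypothetical test image
scores = classify(img, "golden retriever, tabby cat")
# The top entries should include the custom labels when they fit the image
print(sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:5])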