{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Interacting with CLIP.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "YPHN7PJgKOzb"
},
"source": [
"# Interacting with CLIP\n",
"\n",
"This is a self-contained notebook that shows how to download and run CLIP models, calculate the similarity between arbitrary image and text inputs, and perform zero-shot image classifications."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "53N4k0pj_9qL"
},
"source": [
"# Preparation for Colab\n",
"\n",
"Make sure you're running a GPU runtime; if not, select \"GPU\" as the hardware accelerator in Runtime > Change Runtime Type in the menu. The next cells will print the CUDA version of the runtime if it has a GPU, and install PyTorch 1.7.1."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0BpdJkdBssk9",
"outputId": "f26d1899-1e21-427a-f730-9451be9a2572"
},
"source": [
"import subprocess\n",
"\n",
"CUDA_version = [s for s in subprocess.check_output([\"nvcc\", \"--version\"]).decode(\"UTF-8\").split(\", \") if s.startswith(\"release\")][0].split(\" \")[-1]\n",
"print(\"CUDA version:\", CUDA_version)\n",
"\n",
"if CUDA_version == \"10.0\":\n",
" torch_version_suffix = \"+cu100\"\n",
"elif CUDA_version == \"10.1\":\n",
" torch_version_suffix = \"+cu101\"\n",
"elif CUDA_version == \"10.2\":\n",
" torch_version_suffix = \"\"\n",
"else:\n",
" torch_version_suffix = \"+cu110\""
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"CUDA version: 10.1\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RBVr18E5tse8",
"outputId": "d4647553-46fa-4cc5-8a41-b152abc0b4d2"
},
"source": [
"! pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Looking in links: https://download.pytorch.org/whl/torch_stable.html\n",
"Requirement already satisfied: torch==1.7.1+cu101 in /usr/local/lib/python3.6/dist-packages (1.7.1+cu101)\n",
"Requirement already satisfied: torchvision==0.8.2+cu101 in /usr/local/lib/python3.6/dist-packages (0.8.2+cu101)\n",
"Requirement already satisfied: ftfy in /usr/local/lib/python3.6/dist-packages (5.8)\n",
"Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (2019.12.20)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (1.19.4)\n",
"Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (0.8)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (3.7.4.3)\n",
"Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.8.2+cu101) (7.0.0)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from ftfy) (0.2.5)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "C1hkDT38hSaP",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2fb746de-c8f1-4e5f-d611-e151934c0994"
},
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"print(\"Torch version:\", torch.__version__)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Torch version: 1.7.1+cu101\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eFxgLV5HAEEw"
},
"source": [
"# Downloading the model\n",
"\n",
"CLIP models are distributed as TorchScript modules."
]
},
{
"cell_type": "code",
"metadata": {
"id": "uLFS29hnhlY4"
},
"source": [
"MODELS = {\n",
" \"RN50\": \"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\",\n",
" \"RN101\": \"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\",\n",
" \"RN50x4\": \"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt\",\n",
" \"ViT-B/32\": \"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\", \n",
"}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "cboKZocQlSYX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c3b6fa02-3b17-4a54-d3c4-8970cba3de9d"
},
"source": [
"! wget {MODELS[\"ViT-B/32\"]} -O model.pt"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"--2021-01-08 17:41:09-- https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\n",
"Resolving openaipublic.azureedge.net (openaipublic.azureedge.net)... 13.107.246.13, 2620:1ec:bdf::13\n",
"Connecting to openaipublic.azureedge.net (openaipublic.azureedge.net)|13.107.246.13|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 353976522 (338M) [application/octet-stream]\n",
"Saving to: model.pt\n",
"\n",
"model.pt 100%[===================>] 337.58M 125MB/s in 2.7s \n",
"\n",
"2021-01-08 17:41:12 (125 MB/s) - model.pt saved [353976522/353976522]\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IBRVTY9lbGm8",
"outputId": "768bce5b-e807-43fd-e559-c4f120f16dc7"
},
"source": [
"model = torch.jit.load(\"model.pt\").cuda().eval()\n",
"input_resolution = model.input_resolution.item()\n",
"context_length = model.context_length.item()\n",
"vocab_size = model.vocab_size.item()\n",
"\n",
"print(\"Model parameters:\", f\"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}\")\n",
"print(\"Input resolution:\", input_resolution)\n",
"print(\"Context length:\", context_length)\n",
"print(\"Vocab size:\", vocab_size)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Model parameters: 151,277,313\n",
"Input resolution: 224\n",
"Context length: 77\n",
"Vocab size: 49408\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "21slhZGCqANb"
},
"source": [
"# Image Preprocessing\n",
"\n",
"We resize the input images and center-crop them to conform with the image resolution that the model expects. Before doing so, we will normalize the pixel intensity using the dataset mean and standard deviation.\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "d6cpiIFHp9N6"
},
"source": [
"from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\n",
"from PIL import Image\n",
"\n",
"preprocess = Compose([\n",
" Resize(input_resolution, interpolation=Image.BICUBIC),\n",
" CenterCrop(input_resolution),\n",
" ToTensor()\n",
"])\n",
"\n",
"image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()\n",
"image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "xwSB5jZki3Cj"
},
"source": [
"# Text Preprocessing\n",
"\n",
"We use a case-insensitive tokenizer. The tokenizer code is hidden in the second cell below"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qGom156-i2kL",
"outputId": "b61049f3-0ce5-4e95-c477-9c62b2fabca4"
},
"source": [
"! pip install ftfy regex\n",
"! wget https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz -O bpe_simple_vocab_16e6.txt.gz"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: ftfy in /usr/local/lib/python3.6/dist-packages (5.8)\n",
"Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (2019.12.20)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from ftfy) (0.2.5)\n",
"--2021-01-08 17:41:19-- https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz\n",
"Resolving openaipublic.azureedge.net (openaipublic.azureedge.net)... 13.107.246.13, 2620:1ec:bdf::13\n",
"Connecting to openaipublic.azureedge.net (openaipublic.azureedge.net)|13.107.246.13|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1356917 (1.3M) [application/octet-stream]\n",
"Saving to: bpe_simple_vocab_16e6.txt.gz\n",
"\n",
"bpe_simple_vocab_16 100%[===================>] 1.29M --.-KB/s in 0.01s \n",
"\n",
"2021-01-08 17:41:19 (93.6 MB/s) - bpe_simple_vocab_16e6.txt.gz saved [1356917/1356917]\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "toGtcd-Ji_MD",
"cellView": "form"
},
"source": [
"#@title\n",
"\n",
"import gzip\n",
"import html\n",
"import os\n",
"from functools import lru_cache\n",
"\n",
"import ftfy\n",
"import regex as re\n",
"\n",
"\n",
"@lru_cache()\n",
"def bytes_to_unicode():\n",
" \"\"\"\n",
" Returns list of utf-8 byte and a corresponding list of unicode strings.\n",
" The reversible bpe codes work on unicode strings.\n",
" This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n",
" When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n",
" This is a signficant percentage of your normal, say, 32K bpe vocab.\n",
" To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n",
" And avoids mapping to whitespace/control characters the bpe code barfs on.\n",
" \"\"\"\n",
" bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\n",
" cs = bs[:]\n",
" n = 0\n",
" for b in range(2**8):\n",
" if b not in bs:\n",
" bs.append(b)\n",
" cs.append(2**8+n)\n",
" n += 1\n",
" cs = [chr(n) for n in cs]\n",
" return dict(zip(bs, cs))\n",
"\n",
"\n",
"def get_pairs(word):\n",
" \"\"\"Return set of symbol pairs in a word.\n",
" Word is represented as tuple of symbols (symbols being variable-length strings).\n",
" \"\"\"\n",
" pairs = set()\n",
" prev_char = word[0]\n",
" for char in word[1:]:\n",
" pairs.add((prev_char, char))\n",
" prev_char = char\n",
" return pairs\n",
"\n",
"\n",
"def basic_clean(text):\n",
" text = ftfy.fix_text(text)\n",
" text = html.unescape(html.unescape(text))\n",
" return text.strip()\n",
"\n",
"\n",
"def whitespace_clean(text):\n",
" text = re.sub(r'\\s+', ' ', text)\n",
" text = text.strip()\n",
" return text\n",
"\n",
"\n",
"class SimpleTokenizer(object):\n",
" def __init__(self, bpe_path: str = \"bpe_simple_vocab_16e6.txt.gz\"):\n",
" self.byte_encoder = bytes_to_unicode()\n",
" self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n",
" merges = gzip.open(bpe_path).read().decode(\"utf-8\").split('\\n')\n",
" merges = merges[1:49152-256-2+1]\n",
" merges = [tuple(merge.split()) for merge in merges]\n",
" vocab = list(bytes_to_unicode().values())\n",
" vocab = vocab + [v+'</w>' for v in vocab]\n",
" for merge in merges:\n",
" vocab.append(''.join(merge))\n",
" vocab.extend(['<|startoftext|>', '<|endoftext|>'])\n",
" self.encoder = dict(zip(vocab, range(len(vocab))))\n",
" self.decoder = {v: k for k, v in self.encoder.items()}\n",
" self.bpe_ranks = dict(zip(merges, range(len(merges))))\n",
" self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}\n",
" self.pat = re.compile(r\"\"\"<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\", re.IGNORECASE)\n",
"\n",
" def bpe(self, token):\n",
" if token in self.cache:\n",
" return self.cache[token]\n",
" word = tuple(token[:-1]) + ( token[-1] + '</w>',)\n",
" pairs = get_pairs(word)\n",
"\n",
" if not pairs:\n",
" return token+'</w>'\n",
"\n",
" while True:\n",
" bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))\n",
" if bigram not in self.bpe_ranks:\n",
" break\n",
" first, second = bigram\n",
" new_word = []\n",
" i = 0\n",
" while i < len(word):\n",
" try:\n",
" j = word.index(first, i)\n",
" new_word.extend(word[i:j])\n",
" i = j\n",
" except:\n",
" new_word.extend(word[i:])\n",
" break\n",
"\n",
" if word[i] == first and i < len(word)-1 and word[i+1] == second:\n",
" new_word.append(first+second)\n",
" i += 2\n",
" else:\n",
" new_word.append(word[i])\n",
" i += 1\n",
" new_word = tuple(new_word)\n",
" word = new_word\n",
" if len(word) == 1:\n",
" break\n",
" else:\n",
" pairs = get_pairs(word)\n",
" word = ' '.join(word)\n",
" self.cache[token] = word\n",
" return word\n",
"\n",
" def encode(self, text):\n",
" bpe_tokens = []\n",
" text = whitespace_clean(basic_clean(text)).lower()\n",
" for token in re.findall(self.pat, text):\n",
" token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n",
" bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\n",
" return bpe_tokens\n",
"\n",
" def decode(self, tokens):\n",
" text = ''.join([self.decoder[token] for token in tokens])\n",
" text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('</w>', ' ')\n",
" return text\n"
],
"execution_count": null,
"outputs": []
},
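{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, we can round-trip a sample string through the tokenizer. This cell is a minimal sketch added for illustration; the exact token IDs depend on the BPE vocabulary file downloaded above."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Encode a sample string into BPE token IDs and decode it back.\n",
"# Note that the decoded text is lowercased: the tokenizer is case-insensitive.\n",
"tokenizer = SimpleTokenizer()\n",
"sample_tokens = tokenizer.encode(\"Hello, CLIP!\")\n",
"print(\"token ids:\", sample_tokens)\n",
"print(\"decoded:\", tokenizer.decode(sample_tokens))"
],
"execution_count": null,
"outputs": []
},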
{
"cell_type": "markdown",
"metadata": {
"id": "4W8ARJVqBJXs"
},
"source": [
"# Setting up input images and texts\n",
"\n",
"We are going to feed 8 example images and their textual descriptions to the model, and compare the similarity between the corresponding features.\n",
"\n",
"The tokenizer is case-insensitive, and we can freely give any suitable textual descriptions."
]
},
{
"cell_type": "code",
"metadata": {
"id": "tMc1AXzBlhzm"
},
"source": [
"import os\n",
"import skimage\n",
"import IPython.display\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"from collections import OrderedDict\n",
"import torch\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina'\n",
"\n",
"# images in skimage to use and their textual descriptions\n",
"descriptions = {\n",
" \"page\": \"a page of text about segmentation\",\n",
" \"chelsea\": \"a facial photo of a tabby cat\",\n",
" \"astronaut\": \"a portrait of an astronaut with the American flag\",\n",
" \"rocket\": \"a rocket standing on a launchpad\",\n",
" \"motorcycle_right\": \"a red motorcycle standing in a garage\",\n",
" \"camera\": \"a person looking at a camera on a tripod\",\n",
" \"horse\": \"a black-and-white silhouette of a horse\", \n",
" \"coffee\": \"a cup of coffee on a saucer\"\n",
"}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NSSrLY185jSf",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 368
},
"outputId": "f4094535-8b5e-469b-c2b7-e0d02f2d4e9c"
},
"source": [
"images = []\n",
"texts = []\n",
"plt.figure(figsize=(16, 5))\n",
"\n",
"for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(\".png\") or filename.endswith(\".jpg\")]:\n",
" name = os.path.splitext(filename)[0]\n",
" if name not in descriptions:\n",
" continue\n",
"\n",
" image = preprocess(Image.open(os.path.join(skimage.data_dir, filename)).convert(\"RGB\"))\n",
" images.append(image)\n",
" texts.append(descriptions[name])\n",
"\n",
" plt.subplot(2, 4, len(images))\n",
" plt.imshow(image.permute(1, 2, 0))\n",
" plt.title(f\"{filename}\\n{descriptions[name]}\")\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
"\n",
"plt.tight_layout()\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAACKYAAAK/CAYAAABZDVdIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOydedxuU/n/35/DcczDMVaGIxGRlCQVToo0/ShJojxNGr4NKg1K9UhfUUlp+BbiNJIkjajwGFOUIUTEKRFy5uMMDmf9/ljXttfZz973sO99P/dzzrner9d+3ffea7r22muvda1rrb2WQgg4juM4juM4juM4juM4juM4juM4juM4juM4TtNMGLQAjuM4juM4juM4juM4juM4juM4juM4juM4zoqJT0xxHMdxHMdxHMdxHMdxHMdxHMdxHMdxHMdx+oJPTHEcx3Ecx3Ecx3Ecx3Ecx3Ecx3Ecx3Ecx3H6gk9McRzHcRzHcRzHcRzHcRzHcRzHcRzHcRzHcfqCT0xxHMdxHMdxHMdxHMdxHMdxHMdxHMdxHMdx+oJPTHEcx3Ecx3Ecx3Ecx3Ecx3Ecx3Ecx3Ecx3H6gk9McRzHcRzHcRzHcRzHcRzHcRzHcRzHcRzHcfqCT0xxHMdxHMdxHMdxHMdxHMdxHMdxHMdxHMdx+oJPTHEcx3Ecx3Ecx3Ecx3Ecx3Ecx3Ecx3Ecx3H6gk9McRzHcRzHcRzHcRzHcRzHcRzHcRzHcRzHcfqCT0xxHMdxHMdxHMdxHMdxHMdxHMdxHMdxHMdx+oJPTHEcx3Ecx3Ecx3Ecx3Ecx3Ecx3Ecx3Ecx3H6gk9McRzHcRzHcRzHcRzHcRzHcRzHcRzHcRzHcfqCT0xxHMdxHMdxHMdxHMdxHMdxHMdxHMdx+oKkIUlB0sigZWmHpCkma2g43uUmDxynH/jElJUUSdOs8hteGdJ1HMdxHMdZ0ZE0bHrWtEHL0inLo8yO4ziO4ziO4ziO4zgZNtlgWNIug5bFWXGRtL6Vs+FBy+I4dVl10AI4juM4juM4jrP8IWkImAJcEEK4cbDSOI7jOI6zvCJpKjAVuDGEcMFgpVm+kHQUsD4wLYQwfcDiOI7jOM7KyhCwNzAdcPvIisES4I5BC1FgfeAz9n94gHI4Tm18YorjOI7jOI7jOHUYonfDy8PEjv5/mhHJcRzHcZzlkKlEI/t3AZ+Y0h1HAVsBI0SdzHEcx3Ecx+mREMJ9wPaDlsNxVjR8YorjOI7jOI7jOAMhhPB14OuDlsNxHMdxHMdxHMdxHMdxHMfpHxMGLYDjOI7jOI7jOI7jOI7jOI7jOI7jOE4rJE2XFCRNlfQkSd+SdK+khZL+JumDkiYk/g+WdKWk2ZLmSvq1pJ1axP9sST+wOBdLeljSxZIOKvE7JCkQV5MFOMtky47pJWE2lXSypNslLZA0R9KfJH1Y0qQKmaZZfMOSJkn6pKSbJc2z6+snfiXpELvPB+we7pN0heXNhuZvLwu7OLtWkfZTJS01v08vcd/C7ucWk2eepNskfUfSi6vibZHeTpLOlHSPpEX23K6W9C5JE7uNryT+KdnzsfPnSzpP0n8kPS7pK2X+KuJ6laTL7BnOlXStpCPMbcTCD7WR59UWx2xJ8y2OQ0v8jQD3JOehcAx3kQdZmCmW3+dYWVlk5fJTLcriE/claQ0rk3fY+/eQxbVtm/R7zjdn+cUnpqxgSNpBsSH+uzVqsyX9VdKpknatCLOKpKMk3WRhZkr6laTntklrbUmfkHSdVSCLJN1paW1RQ/ZNJH3RGrBHLL57JV0j6bOStqoIt7Gkz9t9zrewt0j6X0mTK8JsJOk9kn5uFe08C3ebpC9LenK38lu8Q1Zpjtj5EVapzrU8ukTS/hVhiw1iVYOwWov004ZgkTWm51hcbRtSx3Ecx3E6R8saQ54i6ZuS7lbs1N+Y+NtG0rfNbZGkWYoGgbdLWqVNGo118CUdY/IuknRAwa1jfUo1DC8tZBq2MNNK3Gp3lJM4jpD0R7ufmYod31eZ2xPPr1N5HcdxHGdloq7tpFP7TmanIG7jA3CERhvZp6R+E5tJ6SBCIsOLJZ1vusOj9vszSfu0uN9U99hS0umS/m263T2SviRp3Ybzamo7/UkFW5NdG7a8yGxllxXybaQsror40wGGDSSdolxv/bek0yQ9qSLsMrpconvNU7SFXSZp3zbpP0PSjxUHMxZa/h0nafVWuqLjOI6z0rM18BfgncC6wETi1itfBr4KIOlE4FxgD+J46DrAK4ArVTJ4LulI4HrgMGBzYAGwPrAfcJ6k72tZO85C4EFgiZ3PtfPs+G8h/ucBtwEfAp4OPAasBuwGfAn4o6RNWtzz6sAVwOfsXh8vxL8e8FvgHLvPTYBHgMnAnpY3rwYIIVwB/N3Sf2OLNN8CCLg6hHBHIb2DiNszfwjYkbhLxxKT7a3AWS3iHYWk9wI3WZpTLK61gRcA/wf8VtKa3cTZJr1DgCuBg4A1KORnm7DHAr8kbkm5joXdDZgm6ZQO4/gU8AtgL7u0FrA78CNJRxW8zyRuh53xYOGY36nsCS8ArgUOId6/iOXys8CIpLVbhF0XuJqox28FBGBji+taSduUBWoi35zlnBCCHyvIAbyP2JAFO+YDs5LzkcTvNLv2OeAi+/8oMC/xvxDYoyKtHYh712Z+l1h62flM4IUl4bJ0hwvXtwLuT8I/ZnEsTa69qyS+FwEzEj+LTe7s/F/A00vCfakg+4xC3j0E7FzjGQxleQ2cYv8ft+eQ3svRJWGnJO77EZWeAMy2ODK3CyrSXo+oNKV5MScpC2/M3AZdVv3www8//PBjRTgSXehIorEhEDv884Ebzc+rCrrJbNO5svPfAWtVxH9Qog9kulmqH00v+B+269NK4jop0QleUnDrSp8idjIfSO5jjp1nx3Vd5GErmbP030iuZ84x+TK3PwBrV8R9euKvqI99IHl+Uwddlvzwww8//PBjPB7UsJ3QhX0H2MJ0h6ydX1jQKR4AtjC/U5Lwh5g8qW71lUSGzyV+lzLaJvP5ivvN3A9IdKO5SVoBuA6Y2EReWbiplOh1BT9DjLbrHW35k9mLZhby7fwunvOIxfFh4C77v4Bl7XwPATuUhB0292nAGckzn5OEfRw4qCLtl7Ks3pnqen8APp/FP+j3wQ8//PDDj/FxkPflZwPXZO0rsCZwbNL+f8J0hA9gdhdgJ+B283NuId4XJO3qT4DN7frawCcTXeLYEpmytnSohdwbkOtINwO72fVVgNdZWx6A35WEnWZu84h6zSHAaua2FaabAL9K2vH3A+vbdRHH9I4DDkji/aj5/0uFzBOINqEAvLUkvzId6VLi5AKZ2zrAgcCZhTBDFHSaxO1Act3rI8BGdn014GXESTQB+HaP5WdKonfMA84Dppjbqsn/J/yVxLFPEseZwCZ2fT3gePLyOapMJHkwm6gzHZs8p02t7GV68eQq2XvMg5DI8CfgmUleD5HbIk9rUdZnEVdweRmxDE8gTn66l5L3q9d882PFOQYug
B8NPUg4OHmhf0LSWSTOhjwMODm5Ni2pPGYArydvyHYG/mrufypJaz2rcAJxtunOwCrm9lTgh+b2QFahlqQ7XLh+pl2/0yqvCXZ9ElFZOB44sBBmK/KJN98EnmaV3wQLc7G53ZrJl4R9P3AM8ExgVbu2CrAr+USdW7CGtIvnMETekQ7AicB65vYk4AfkitGLCmGnJM9wFvBj8kZwLeDj5MrPK0rS/i75gNPh5MrIjsTO/BOTlAZdXv3www8//PBjRTjIjSHziEaFFyRuTwO2ITfoj2CTO0y/ORJYZG5nlMRdp4M/TMFwb3rRtxL9Yo9CmF70qRF67CyWyZy49dJRfksS/oREH9uEOGjyKHESUcAnpvjhhx9++OFH6UEN2wn17DuV+kDiZ0rStrcaRHhD4u9r5IMaGwKnJm6Hl6SR2mQuAXZKZH8rue72nibyyvxMpcbElMRteq/6TKLTzSZ+cfuq5LntDdydyD+xEDZ7drOIAyjvAtY0t62By839/ixfkrAbEb/8DcAfk/yeS
2020-12-17 17:55:12 +01:00
"text/plain": [
"<Figure size 1152x360 with 8 Axes>"
]
},
"metadata": {
"tags": [],
"image/png": {
"width": 1107,
"height": 351
}
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WEVKsji6WOIX"
},
"source": [
"## Building features\n",
"\n",
"We normalize the images, tokenize each text input, and run the forward pass of the model to get the image and text features."
]
},
{
"cell_type": "code",
"metadata": {
"id": "QwkkczUPBRMh"
},
"source": [
"image_input = torch.tensor(np.stack(images)).cuda()\n",
"image_input -= image_mean[:, None, None]\n",
"image_input /= image_std[:, None, None]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "HBgCanxi8JKw"
},
"source": [
"tokenizer = SimpleTokenizer()\n",
"text_tokens = [tokenizer.encode(\"This is \" + desc) for desc in texts]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "w1l_muuhZ_Nk"
},
"source": [
"text_input = torch.zeros(len(text_tokens), model.context_length, dtype=torch.long)\n",
"sot_token = tokenizer.encoder['<|startoftext|>']\n",
"eot_token = tokenizer.encoder['<|endoftext|>']\n",
"\n",
"for i, tokens in enumerate(text_tokens):\n",
" tokens = [sot_token] + tokens + [eot_token]\n",
2020-12-17 17:55:12 +01:00
" text_input[i, :len(tokens)] = torch.tensor(tokens)\n",
"\n",
"text_input = text_input.cuda()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ZN9I0nIBZ_vW"
},
"source": [
"with torch.no_grad():\n",
" image_features = model.encode_image(image_input).float()\n",
" text_features = model.encode_text(text_input).float()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "cuxm2Gt4Wvzt"
},
"source": [
"## Calculating cosine similarity\n",
"\n",
"We normalize the features and calculate the dot product of each pair."
]
},
{
"cell_type": "code",
"metadata": {
"id": "yKAxkQR7bf3A"
},
"source": [
"image_features /= image_features.norm(dim=-1, keepdim=True)\n",
"text_features /= text_features.norm(dim=-1, keepdim=True)\n",
"similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T"
],
"execution_count": null,
"outputs": []
},
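{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before visualizing the full matrix, we can read off the best-matching description for each image. This is a minimal sketch added for illustration: rows of `similarity` index the texts and columns index the images, so with well-matched pairs the argmax over each column should recover the identity mapping."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# For each image (a column of `similarity`), pick the text with the highest cosine similarity.\n",
"best_text = similarity.argmax(axis=0)\n",
"for i, j in enumerate(best_text):\n",
"    print(f\"image {i} ({texts[i]!r}) -> best match: {texts[j]!r}\")"
],
"execution_count": null,
"outputs": []
},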
{
"cell_type": "code",
"metadata": {
"id": "C5zvMxh8cU6m",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 830
},
"outputId": "7acf50ee-dead-4e8e-f77d-f3e8a596bf45"
},
"source": [
"count = len(descriptions)\n",
"\n",
"plt.figure(figsize=(20, 14))\n",
"plt.imshow(similarity, vmin=0.1, vmax=0.3)\n",
2020-12-17 17:55:12 +01:00
"# plt.colorbar()\n",
"plt.yticks(range(count), texts, fontsize=18)\n",
"plt.xticks([])\n",
"for i, image in enumerate(images):\n",
" plt.imshow(image.permute(1, 2, 0), extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin=\"lower\")\n",
"for x in range(similarity.shape[1]):\n",
" for y in range(similarity.shape[0]):\n",
" plt.text(x, y, f\"{similarity[y, x]:.2f}\", ha=\"center\", va=\"center\", size=12)\n",
"\n",
"for side in [\"left\", \"top\", \"right\", \"bottom\"]:\n",
" plt.gca().spines[side].set_visible(False)\n",
"\n",
"plt.xlim([-0.5, count - 0.5])\n",
"plt.ylim([count + 0.5, -2])\n",
"\n",
"plt.title(\"Cosine similarity between text and image features\", size=20)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Cosine similarity between text and image features')"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAACB0AAAY5CAYAAAATvfRQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd7xsV13w/893Zs45t+XeVJKQdkMJEFoeQwchCQpIERABlWIoKiLiIyD4E+QJWBFUFBClBhvSRKRIDYEghF5CCSUQQjrJTbn1nDMz6/fH2nPPnn2mn3pvPu/Xaycz+66999pt7Tl7fddakVJCkiRJkiRJkiRJkiRpXLW1zoAkSZIkSZIkSZIkSTowGXQgSZIkSZIkSZIkSZImYtCBJEmSJEmSJEmSJEmaiEEHkiRJkiRJkiRJkiRpIgYdSJIkSZIkSZIkSZKkiRh0IEmSJEmSJEmSJEmSJmLQgSRJkiRJkiRJkiRJmohBB5IkSZIkSZIkSZIkaSIGHUiSJEmSJEmSJEmSpIkYdCBJkiRJkiRJkiRJkiZi0IEkSZIkSZIkSZIkSZqIQQeSJEmSJEmSJEmSJGkiBh1IkiRJkiRJkiRJkqSJGHQgSZIkSdIKiYitEXFGRDw9Iv4gIl4SEc+NiKdExP0iYuNa53GlRUQqTeeudX5uKYrrrnzsz17rPE1inP2IiO2VtOesXk5Xxi35/omIS0v7fv5a50eSDnQRcW75ubIC67fc1sgiYqr4m+i9EfHjiNhV+d3z6rXOo6TxNNY6A5IkSZIkHUwiYjNwNvAk4N4MDvhvRsSXgH8F3p5S2rHyOZQkSZKktRERJwH/DdxtrfMiafnY04EkSZIkScskIp4OXAq8Frgvw//ubgD3KdJfHhF/FhGHrGgmJR2UbGG6Pq10y2Id2CLi/NL1cela52c1RcQ5lVbN29c6T5JWXkRMAf/FLSzg4GDphUwaxJ4OJEmSJElaomKYhH8GfrnHP7eBbwFXA9cBW4FjgdsD5QCDjcAfAfcAHrqS+ZUkSZKkNfA44LTS94uAPyv+v6c0/+bVzJSkpTPoQJIkSZKkJYiIGeBDwBmVf7oY+AvgQyml63osNw2cCTwBeCoLf6PPrFhmJUmSJGnt/FLp8yzw8ymla9YqM5KWj0EHkiRJkiQtzavoDjhI5B4LXpVSavZbKKU0B3wE+EhEvAL4K+DRK5jPNZFSirXOwy1RSul84IA/9gfLfkzK+0eSdKBIKW1f6zzogHB66fPnDDiQDh7DxpaUJEmSJEl9RMSjgeeUZiXgaSmlvxwUcFCVUvpeSukxwAuAkZeTJEmSpAPIrUqfr1yzXEhadvZ0IEmSJEnSBCKiBvxNZfbrUkpvm3SdKaW/joj3Li1nkiRJkrQubSl9nl+zXEhadgYdSJIkSZI0mccCtyl9vxL4w6WuNKX0w3HSR8TdgbuQWw1NA9cCPwI+WwzhMJGIOBL4GeC2wDagDuwurf8bKaXdk65/wjzdA7gjcBywF7gCOD+ldP0yrPtQ4P7ArYEjyft6Nbnb158sdf1j5uVk4DTgeOAQoF3k5wrgEuBb4/SksQz52Qw8EDiBfGx2ABemlL42ZLkNwAOAU8n7cR3wjWLZtKKZHlNENIA7FdOx5Bfie8j7+k3gayml1jJv8zjgnsX2DgeuB96eUrppObdzsIqIewGnkO/Z3cBlwCdTSruWYd3TwH2B7eSytU0u+76RUvr6Utd/MIiIOvn6vT1wFPk987XAxcAXUkrtZdjGqpXLK/l8WW0REcDdyWXvrYAN5HNzCfm3gZV8rFm5fyxwH/LzdAPwU/L98q0lrneGPNTXycChwFXAD8nne1n3YbWU7slbk3sB+w75ntw7ZLm7ksumo4F9wKXAx1NKOyfMRwB3IF8nnd9ls+Tr5HvAF5fye7u0naPJv7WOI5enl5N/71201HX32NYdyL8zb0W+7q8DfgxcMOz4rjfFb9T7k8/NUeRzcy3wpZTS95a47tuQy9GTgK3k63AH+e+gC1NKe5ay/gNRUdY8gFyGHQPMke/LrwxZbkWe5+vx79VbvJSSk5OTk5OTk5OTk5OTk5PTmBNwPnk4hc70slXc9kbgj4CfVPJQnnYCbwWOH3PdDwQ+CrQGrDuRX7x9gRxo0RiwvvIy5w7Zds+0wFOAb/fJRwt4O3DChMfyIcW5bA7Y168Cv7jC5zSApwNfH3LcE/ll2oeBJwxY3xmVZc4eNy35peDrgZv75OMLwM/0uT7/FLipz3I/BB4+4nEZZz+2V9KeM2Td24CnAe8fsI+d6Sbgb4Fbj3FOzy2vozT//sAn6H2PnTbO/QOcM8L10mvaXiz/Z5X5vzDBtfvcyjqeskz3xKWldZ5fzKsBv02u6Om1X3uANwCHTbjNOwD/Ti4/+x27y4HnAdNjXLejTmeX1vHV0vzPj5D36rncMyiPxTKPrSzzkBG2cxzwj+QgmX778VNyGXDIhOdh2cvlfvcSK/h8GZCXsye8Ps4Yst7DgFeQgyD7reNm4HXAUUPW9bbKcm8YY//+orLs+4Eo/m37hPt+zjId+7Uq9+8AvI/cqrvXtr4DPGKC/dlYnPMb+6z3cvLvxelB+VvGa/vS0vrPHzct+bfQM4vj0Wt/dgC/22d9jwQu6rPcvuK63DDGcX0C8E5yhfyg62QP8Bbg9hMeszsAH6J/efdV4JcmOcY99un/I/8G67cve4F/o/iNsIzXxdlDjmGv6dwh67wn8IHi3PZbx/eKbddGzOcU8Ahy+XfFkPzNAe+hx+/gIdf6qNOic7uEc7+9su5zBqQ9o5L27GL+ocA/ADf0yOurB6xvRf7OYpn/XnVavmnNM+Dk5OTk5OTk5OTk5OTkdKBNwCYWvzjevkrbPpXccmPUl1Z7gCeNuO4/neClWAIOHbDOcV4gdqUl99zwzyPm4Urg1DGO4yHkCoBx9vOdwMwKXU8fmeC4f23AOs+opD17nLTkFn2Dglo6027grNK6jmG0wIk28OsjHJtx9mN7Je05Q9b96QmO+fXl/R2y/nPLyxbz/oDBL15XO+hgO90vbN8zwfVbPt83ABuX6b64tLTe88nlwX+NuH9XAXcfY1sBvJz+FYK9pm/SpzK6x3U76nR2aR2vKs1vAtuG7MOFPdb3oCHLvLaUdhbYNCT9s8jPlFH35yfAXcc4DytWLleWOZcVfL6MkJezJ7w+zhiwzsfQuzKo33QjcOaA9W0m91pRXuaJI+zbQ8jle2eZy4EjSv++fcJ9P2eZjv1alPu/zOBApvL0+2Psy4n0D8CqTv9LrjRclL/lnFhC0AG5df9/jLg/r6+s6xUjLvc/DAnGKtY3atlQnvYAvzrm8fpVBleal6dXjXuMS9u5H8Mr0cvTvnH3Zcj2z57geJ7bZ11T5ODCcdZ1PgP+Vimt++UT5LMJPH+Ma33kPC/l/qost72y7nMGpD2jkvZscq8Yg/ZhUdABK/s8X/a/V52Wb3J4BUmSJEmSxncfuocs/HFK6dKV3mgxlMJ55K7Yy35ErgDbRx7y4
WfIlWiQWzb9S0RsTim9YcC6nwm8uDJ7FvgaueJoH7kL1qOBO5O7GV1pryO3QoVcGfhFcgXGFHBX4HaltMcC74qI/5OGdHNbdMX5MfJLtLLrga+QW7VtLrZxcunfHw9si4hfSMvQdXjJG8gVNWU3kCtzryG/0NxK7t71juTKspV0BLknheOL71cBXyZXmpwI3JuF638T8M6iq9695OCJu5X24Qvklom3Irfw31D8WwBviIjPpSV2f7sEtcr3a8gtnneQr/dDyUE+5WvgcOBDEXHvNGY3+xHxROCvSrMuKba3h9zd7L3Gyv0ySCldGhEfAx5azHpURByVUvrpKMsXQxzcrTTr39PKdc38OuDRxedEvld/CMwUedheSnsM8LGIuH9K6fuDVlp0nf02Fsqajr3FNq4svt+OXGZ0ytY7A5+NiHumlK6eZIeGOA94fvG5DjwI+O9eCSNiK3CPHv/0YOBTA7bx4NLngV1FR8SfsvgZMU9uofgTcjl1UpGPTvlwPHBBRDwgpfTNAflYi
2020-12-17 17:55:12 +01:00
"text/plain": [
"<Figure size 1440x1008 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"image/png": {
"width": 1038,
"height": 796
},
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "alePijoXy6AH"
},
"source": [
"# Zero-Shot Image Classification\n",
"\n",
"You can classify images using the cosine similarity (times 100) as the logits to the softmax operation."
]
},
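{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concretely, for an L2-normalized image feature $v$ and normalized text features $t_1, \\ldots, t_K$ (one per candidate label), the cells below compute\n",
"\n",
"$$p_k = \\frac{\\exp(100 \\, v^\\top t_k)}{\\sum_{j=1}^{K} \\exp(100 \\, v^\\top t_j)},$$\n",
"\n",
"where the factor of 100 acts as an inverse softmax temperature."
]
},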
{
"cell_type": "code",
"metadata": {
"id": "Nqu4GlfPfr-p",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c5cd7575-79a3-49a2-cd51-36f7ee2da1b7"
},
"source": [
"from torchvision.datasets import CIFAR100\n",
"\n",
"cifar100 = CIFAR100(os.path.expanduser(\"~/.cache\"), transform=preprocess, download=True)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Files already downloaded and verified\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "C4S__zCGy2MT",
"outputId": "6e0c62da-e7ee-44f0-8e20-6d5d0020437d"
},
"source": [
"text_descriptions = [f\"This is a photo of a {label}\" for label in cifar100.classes]\n",
"text_tokens = [[sot_token] + tokenizer.encode(desc) + [eot_token] for desc in text_descriptions]\n",
"text_input = torch.zeros(len(text_tokens), model.context_length, dtype=torch.long)\n",
"\n",
"for i, tokens in enumerate(text_tokens):\n",
" text_input[i, :len(tokens)] = torch.tensor(tokens)\n",
"\n",
"text_input = text_input.cuda()\n",
"text_input.shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([100, 77])"
]
},
"metadata": {
"tags": []
},
"execution_count": 19
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "c4z1fm9vCpSR"
},
"source": [
"with torch.no_grad():\n",
" text_features = model.encode_text(text_input).float()\n",
" text_features /= text_features.norm(dim=-1, keepdim=True)\n",
"\n",
"text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)\n",
"top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 931
},
"id": "T6Ju_6IBE2Iz",
"outputId": "d604f8d8-ac64-4cbc-fb84-f4b6c2643686"
},
"source": [
"plt.figure(figsize=(16, 16))\n",
"\n",
"for i, image in enumerate(images):\n",
" plt.subplot(4, 4, 2 * i + 1)\n",
" plt.imshow(image.permute(1, 2, 0))\n",
" plt.axis(\"off\")\n",
"\n",
" plt.subplot(4, 4, 2 * i + 2)\n",
" y = np.arange(top_probs.shape[-1])\n",
" plt.grid()\n",
" plt.barh(y, top_probs[i])\n",
" plt.gca().invert_yaxis()\n",
" plt.gca().set_axisbelow(True)\n",
" plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])\n",
" plt.xlabel(\"probability\")\n",
"\n",
"plt.subplots_adjust(wspace=0.5)\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAByEAAAckCAYAAADrtXjRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdS6wsW37n9e9/PSLysfc55966rio3xthIFkgeWj1uSwyRmCMGPUdiwohBSyDBoGdMGKPuCcyRkJi0cDO3kWBgJmCDu+12Vd3HOXvvzIxYjz+DtSJf+3Eedcq3XPf/sU7dvTPjsSIycjtX/uK/lqgqxhhjjDHGGGOMMcYYY4wxxhjzubjvuwHGGGOMMcYYY4wxxhhjjDHGmN8sFkIaY4wxxhhjjDHGGGOMMcYYYz4rCyGNMcYYY4wxxhhjjDHGGGOMMZ+VhZDGGGOMMcYYY4wxxhhjjDHGmM/KQkhjjDHGGGOMMcYYY4wxxhhjzGdlIaQxxhhjjDHGGGOMMcYYY4wx5rOyENIYY4wxxhhjjDHGGGOMMcYY81lZCGmMMcYYY4wxxhhjjDHGGGOM+awshDTGGGOMMcYYY4wxxhhjjDHGfFYWQhpjjDHGGGOMMcYYY4wxxhhjPisLIY0xxhhjjDHGGGOMMcYYY4wxn5WFkMYYY4wxxhhjjDHGGGOMMcaYzyp83w0wxhhjzC9PRP4CeAX85ffcFGOM+XXwe8A7Vf3977shxhhjjDHm14P1m40x5sLv8XfQb7YQ0hhjjPnN8Gocxy//8A//8MvvuyHmsbu7OwBub2+/55aY59hr9OvtY1+fP//zP2e/3/8qm2SMMcYYY/7+sX7zJ7C+0qex8/Zp7Lx9mk85b39X/WYLIY0xxpjfDH/5u7/7u1/+6Z/+6ffdDvOEP/mTPwHgj//4j7/Xdpjn2Wv06+1jX58/+qM/4s/+7M/+8lfWIGOMMcYY8/eR9Zs/gfWVPo2dt09j5+3TfMp5+7vqN9uckMYYY4wxxhhjjDHGGGOMMcaYz8pCSGOMMcYYY4wxxhhjjDHGGGPMZ2UhpDHGGGOMMcYYY4wxxhhjjDHms7IQ0hhjjDHGGGOMMcYYY4wxxhjzWVkIaYwxxhhjjDHGGGOMMcYYY4z5rCyENMYYY4wxxhhjjDHGGGOMMcZ8VhZCGmOMMcYYY4wxxhhjjDHGGGM+KwshjTHGGGOMMcYYY4wxxhhjjDGflYWQxhhjjDHGGGOMMcYYY4wxxpjPykJIY4wxxhhjjDHGGGOMMcYYY8xnZSGkMcYYY4wxxhhjjDHGGGOMMeazshDSGGOMMcYYY4wxxhhjjDHGGPNZWQhpjDHGGGOMMcYYY4wxxhhjjPmsLIQ0xhhjjDHGGGOMMcYYY4wxxnxWFkIaY4wxxhhjjDHGGGOMMcYYYz6r8H03wBhjjDGfx//7rvJ7/8X//H03w7zkf7HX59eevUa/Vv7yn/6H33cTjDHGGGPMbxDrN/8SrK/0aey8fRo7bx/s173fbJWQxhhjjDHGGGOMMcYYY4wxxpjPykJIY4wxxhhjjDHGGGOMMcYYY8xnZSGkMcYYY4wxxhhjjDHGGGOMMeazshDSGGOMMcYYY4wxxhhjjDHGGPNZWQhpjDHGGGOMMcYYY4wxxhhjjPmsLIQ0xhhjjDHGGGOMMcYYY4wxxnxWFkIaY4wxxhhjjDHGGGOMMcYYYz4rCyGNMcYYY4wxxhhjjDHGGGOMMZ+VhZDGGGOMMcYYY4wxxhhjjDHGmM8qfOqK/+Q/+IeqAAgxDoCgIhSBL3/8Y9Y3N6xvb4jjSPCRECIiimpBtZJVqFVRBScwDgHvPc55cqmUXMgpoxRQRdC2DypVC6UUnASca/9yyuSSyTkRo2eInnEMpDTxcH/Pw7t7pmliHEbGYWCzXeGjw3lPDBERh6LUWsl5JqXEPM/EOCD9mL0XRForqAreoSIoShCPc4CHnGcEQXDc3e94uH/gcJhYjZHVesQ5Yb8/INUhIjgvHA4H5nlm2s/c73ZUVVSVaZ5Zb1bc3Gz5nd/9He7v7yglMQQBFXa7iZ9/c8+//tufM6eCKtxsVtxuV9xuRm42a3Ip1NperaJKyoXDNPOjL9/gvYNaGcYIKFoLOSUOUyKlQirKze0NwzDgRPjmu7eUUojR8dUXb3DOUUVZrzeUUklz4uuvv+Hd/YGiwk//wb/Nl19+AcDPf/E1f/uzn7Fajfz0x1/xxZvXOCc4EW5u31BKppSCd0LOGa0V0cq7t9+x3+25290RfWCaEncPe/7ib9+ymzJTqohT/sFXr/jJj17zB7/371BzoapSpPBbX/6I1RAJzpFVOcwTu92et2/f4vp5fvv2gW/uD7x7OPBvvr3ju7sduRRC8ASveAeDCD/98hWvtis2Y+Tt3YFpztSqvLpZ8fp2zTBExAnfvttxOEwcDgeCA++E6B2rGBjHyDgOjJs16/WaEAdKqRz2e1Iu5FLwIeKc4L3j9XaNCCiKjwNDjIQQEFy7Xktmmu5BKwAKOBVUlVpKf1wRKk4gxkgIHsQBgojgQ0Dcsk5lmjOlKlUB52n3Kwha5fiaee8Ah9LWQ9o2xTsEjyCAgIJIu9a9CzjnEZHjvkXa+7A97nDiwLXnkLZtXf7wqOL6Mj5GnI+oKqUUtChKez01V2opaMmUPFNrRft7YPmvKKACAuIU51zfp0NrbcekSq2g5P5zX9d7nA8471n+QPwn//U/Xf5UGGOMMcYYA8Bv/cE/0vcv9fGkf/JUvX5cAL14XOTTPqZ+6nrvtzROzn5uv+v1AUHvV5yvvSzjAOV9zVTV3jN5esHz4zzfv4g8+n1Z5ngEerqzW+XyaF5qz3IkcraCIpw3sXVX9KLlov1x+qLLtoRj3+kDGrBs7fTYM+fw5e09sR307CToe87H6dn3NlufbeLjRR9trJ+tvgHpGzt7GfrjglAv29X/K9pf++UFe7G9V0+qgNSzB3o/mVMDnj7PyzqPaweee13keGGcbb0vqxeP9vNw3qTr5l/t47m/B3r2el+q/Vo9b5Dy7q//T+s3G2OMMT8wnxxCeu8ppVJ7wOCcQ4E5zRwOeyR4wjjgvOtf7kdiDCAeVcUDOdcWDpTawwhBnCBVWjIprgWQLZM4fVCsUM7WCd4hEnBZqKWFDaVUcqmEGNlutwTvub+7b5+JtVK14voHv1orzskxhFSt1FrIOeG8a3GiCLXSgkYBenilCrU3ynkhOE/PLhAUrUsHsIV/sVZUHCkXtFScCAFPVVBtx5xzaSe5dySc9y3s4NTR0N7zyKWyPxyYU2FOhaLKqtLC1SFQaiHnFk6uh4G73Y45Z6q2j4RewHnHerWm5ExKM0LB9VB1cI4xDozDACjeOURgvRrxThAvOOfZbLZM09TC41pQKs5HtjdrxnFknmbu7nbMc2EYhBhXx2BKRHEOlLbt4BzBufYazBOlFFLJFK2sfSB7RbUFY6qunf9aW3dBoOTCvJ8oWpDBX7xOVUt/rSo32wGnQs3K/f2OGD0xBpyTFu7VClXAK
VTI/fHWiXVE78lSUS2UoqSccR4GFxm8owbf1qFdb+11KIiAd46hFGrJFBFyVXKppFyYUkan1EI7L0Qn7Vw7iALBgToF51u3Tlqw3y4XwTuHF9+uY+faW4kK1LZv79v7VTlenzUXEO3vgRZeLtdx60fW1nnQFhRqDwhPl6lrAaRz4BwO199rSye9AkLVCtoC+rb68iXJaWeKIlV6B68Hi/SuY9XWuffS29/e71q1Bfq9TVVbR70KVC2Iaju+ctYpFunfFPT3MksHX5eEsi3qQHD9de8hpHN478G1ILjW8ql/So0xxhhjjAEeh14ftg48jmeuw72n9/X9WdK6YyrUPApUl2Uv23oWyXE61ifCy34+5dhvOYWI58tcr/OBR9ByrU+ImE+vsyz3RD6x8eXH0zk67kuWwOxynQ9uymd77Z/ZzvHye9zGxXL8S6vPr/3zsPe8ycfXrwed18d7vCoehcpPRNDPnCx9Jqo+XYJybM+T618d1/MLPx3wXb7/n3+dnvtboS+8tHL2v5z1f
2020-12-17 17:55:12 +01:00
"text/plain": [
"<Figure size 1152x1152 with 16 Axes>"
]
},
"metadata": {
"tags": [],
"image/png": {
"width": 912,
"height": 914
},
"needs_background": "light"
}
}
]
},
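{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional extension, we can estimate zero-shot top-1 accuracy on a random subset of CIFAR-100, reusing the normalized `text_features` computed above. This is a minimal sketch added for illustration; the subset size of 512 is arbitrary, and the measured accuracy will vary with the sampled images."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from torch.utils.data import DataLoader, Subset\n",
"\n",
"# Evaluate on a small random subset so the cell runs quickly.\n",
"indices = np.random.choice(len(cifar100), 512, replace=False)\n",
"loader = DataLoader(Subset(cifar100, indices), batch_size=64)\n",
"\n",
"correct = total = 0\n",
"with torch.no_grad():\n",
"    for batch_images, batch_labels in loader:\n",
"        # Apply the same normalization used for the example images above.\n",
"        batch_images = batch_images.cuda()\n",
"        batch_images -= image_mean[:, None, None]\n",
"        batch_images /= image_std[:, None, None]\n",
"        features = model.encode_image(batch_images).float()\n",
"        features /= features.norm(dim=-1, keepdim=True)\n",
"        predictions = (features @ text_features.T).argmax(dim=-1).cpu()\n",
"        correct += (predictions == batch_labels).sum().item()\n",
"        total += len(batch_labels)\n",
"\n",
"print(f\"Zero-shot top-1 accuracy on {total} images: {correct / total:.1%}\")"
],
"execution_count": null,
"outputs": []
},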
{
"cell_type": "code",
"metadata": {
"id": "0OENu-DQLzQY"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}