remove unused/repeated imports

This commit is contained in:
Aleksey Morozov 2021-01-24 12:03:04 +02:00
parent 6bc0bd8873
commit 90c1b12d63
1 changed file with 32 additions and 106 deletions

View File

@ -36,73 +36,6 @@
"Make sure you're running a GPU runtime; if not, select \"GPU\" as the hardware accelerator in Runtime > Change Runtime Type in the menu. The next cells will print the CUDA version of the runtime if it has a GPU, and install PyTorch 1.7.1."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0BpdJkdBssk9",
"outputId": "f26d1899-1e21-427a-f730-9451be9a2572"
},
"source": [
"import subprocess\n",
"\n",
"# Ask the CUDA compiler for its version banner and keep only the\n",
"# comma-separated field(s) starting with \"release\" (e.g. \"release 10.1\").\n",
"nvcc_banner = subprocess.check_output([\"nvcc\", \"--version\"]).decode(\"UTF-8\")\n",
"release_fields = [field for field in nvcc_banner.split(\", \") if field.startswith(\"release\")]\n",
"\n",
"# The version number is the last space-separated word of the first such field.\n",
"CUDA_version = release_fields[0].split(\" \")[-1]\n",
"print(\"CUDA version:\", CUDA_version)\n",
"\n",
"# Local-version suffix appended to the torch/torchvision pins in the install\n",
"# cell; any CUDA release not listed here falls back to \"+cu110\".\n",
"known_suffixes = {\"10.0\": \"+cu100\", \"10.1\": \"+cu101\", \"10.2\": \"\"}\n",
"torch_version_suffix = known_suffixes.get(CUDA_version, \"+cu110\")"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"CUDA version: 10.1\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RBVr18E5tse8",
"outputId": "d4647553-46fa-4cc5-8a41-b152abc0b4d2"
},
"source": [
"! pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Looking in links: https://download.pytorch.org/whl/torch_stable.html\n",
"Requirement already satisfied: torch==1.7.1+cu101 in /usr/local/lib/python3.6/dist-packages (1.7.1+cu101)\n",
"Requirement already satisfied: torchvision==0.8.2+cu101 in /usr/local/lib/python3.6/dist-packages (0.8.2+cu101)\n",
"Requirement already satisfied: ftfy in /usr/local/lib/python3.6/dist-packages (5.8)\n",
"Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (2019.12.20)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (1.19.4)\n",
"Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (0.8)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (3.7.4.3)\n",
"Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.8.2+cu101) (7.0.0)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from ftfy) (0.2.5)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
@ -238,7 +171,7 @@
"id": "d6cpiIFHp9N6"
},
"source": [
"from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\n",
"from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor\n",
"from PIL import Image\n",
"\n",
"preprocess = Compose([\n",
@ -330,7 +263,7 @@
" To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n",
" And avoids mapping to whitespace/control characters the bpe code barfs on.\n",
" \"\"\"\n",
" bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\n",
" bs = list(range(ord(\"!\"), ord(\"~\")+1)) + list(range(ord(\"¡\"), ord(\"¬\")+1)) + list(range(ord(\"®\"), ord(\"ÿ\")+1))\n",
" cs = bs[:]\n",
" n = 0\n",
" for b in range(2**8):\n",
@ -343,8 +276,9 @@
"\n",
"\n",
"def get_pairs(word):\n",
" \"\"\"Return set of symbol pairs in a word.\n",
" Word is represented as tuple of symbols (symbols being variable-length strings).\n",
" \"\"\"\n",
" Return set of symbol pairs in a word. Word is represented as tuple of\n",
" symbols (symbols being variable-length strings).\n",
" \"\"\"\n",
" pairs = set()\n",
" prev_char = word[0]\n",
@ -381,20 +315,21 @@
" self.encoder = dict(zip(vocab, range(len(vocab))))\n",
" self.decoder = {v: k for k, v in self.encoder.items()}\n",
" self.bpe_ranks = dict(zip(merges, range(len(merges))))\n",
" self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}\n",
" self.cache = {'<|startoftext|>': '<|startoftext|>',\n",
" '<|endoftext|>': '<|endoftext|>'}\n",
" self.pat = re.compile(r\"\"\"<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\", re.IGNORECASE)\n",
"\n",
" def bpe(self, token):\n",
" if token in self.cache:\n",
" return self.cache[token]\n",
" word = tuple(token[:-1]) + ( token[-1] + '</w>',)\n",
" word = tuple(token[:-1]) + (token[-1] + '</w>',)\n",
" pairs = get_pairs(word)\n",
"\n",
" if not pairs:\n",
" return token+'</w>'\n",
" return token + '</w>'\n",
"\n",
" while True:\n",
" bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))\n",
" bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))\n",
" if bigram not in self.bpe_ranks:\n",
" break\n",
" first, second = bigram\n",
@ -405,7 +340,7 @@
" j = word.index(first, i)\n",
" new_word.extend(word[i:j])\n",
" i = j\n",
" except:\n",
" except ValueError:\n",
" new_word.extend(word[i:])\n",
" break\n",
"\n",
@ -429,14 +364,16 @@
" bpe_tokens = []\n",
" text = whitespace_clean(basic_clean(text)).lower()\n",
" for token in re.findall(self.pat, text):\n",
" token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n",
" bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\n",
" token = ''.join(self.byte_encoder[b]\n",
" for b in token.encode('utf-8'))\n",
" bpe_tokens.extend(self.encoder[bpe_token]\n",
" for bpe_token in self.bpe(token).split(' '))\n",
" return bpe_tokens\n",
"\n",
" def decode(self, tokens):\n",
" text = ''.join([self.decoder[token] for token in tokens])\n",
" text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('</w>', ' ')\n",
" return text\n"
" return text"
],
"execution_count": 9,
"outputs": []
@ -460,15 +397,8 @@
"id": "tMc1AXzBlhzm"
},
"source": [
"import os\n",
"import skimage\n",
"import IPython.display\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"from collections import OrderedDict\n",
"import torch\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina'\n",
@ -481,7 +411,7 @@
" \"rocket\": \"a rocket standing on a launchpad\",\n",
" \"motorcycle_right\": \"a red motorcycle standing in a garage\",\n",
" \"camera\": \"a person looking at a camera on a tripod\",\n",
" \"horse\": \"a black-and-white silhouette of a horse\", \n",
" \"horse\": \"a black-and-white silhouette of a horse\",\n",
" \"coffee\": \"a cup of coffee on a saucer\"\n",
"}"
],
@ -518,7 +448,7 @@
" plt.xticks([])\n",
" plt.yticks([])\n",
"\n",
"plt.tight_layout()\n"
"plt.tight_layout()"
],
"execution_count": 11,
"outputs": [
@ -582,7 +512,8 @@
"id": "w1l_muuhZ_Nk"
},
"source": [
"text_input = torch.zeros(len(text_tokens), model.context_length, dtype=torch.long)\n",
"text_input = torch.zeros(len(text_tokens), model.context_length,\n",
" dtype=torch.long)\n",
"sot_token = tokenizer.encoder['<|startoftext|>']\n",
"eot_token = tokenizer.encoder['<|endoftext|>']\n",
"\n",
@ -651,10 +582,12 @@
"plt.yticks(range(count), texts, fontsize=18)\n",
"plt.xticks([])\n",
"for i, image in enumerate(images):\n",
" plt.imshow(image.permute(1, 2, 0), extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin=\"lower\")\n",
" plt.imshow(image.permute(1, 2, 0), extent=(i - 0.5, i + 0.5, -1.6, -0.6),\n",
" origin=\"lower\")\n",
"for x in range(similarity.shape[1]):\n",
" for y in range(similarity.shape[0]):\n",
" plt.text(x, y, f\"{similarity[y, x]:.2f}\", ha=\"center\", va=\"center\", size=12)\n",
" plt.text(x, y, f\"{similarity[y, x]:.2f}\", ha=\"center\", va=\"center\",\n",
" size=12)\n",
"\n",
"for side in [\"left\", \"top\", \"right\", \"bottom\"]:\n",
" plt.gca().spines[side].set_visible(False)\n",
@ -720,7 +653,8 @@
"source": [
"from torchvision.datasets import CIFAR100\n",
"\n",
"cifar100 = CIFAR100(os.path.expanduser(\"~/.cache\"), transform=preprocess, download=True)"
"cifar100 = CIFAR100(os.path.expanduser(\"~/.cache\"), transform=preprocess,\n",
" download=True)"
],
"execution_count": 18,
"outputs": [
@ -743,9 +677,12 @@
"outputId": "6e0c62da-e7ee-44f0-8e20-6d5d0020437d"
},
"source": [
"text_descriptions = [f\"This is a photo of a {label}\" for label in cifar100.classes]\n",
"text_tokens = [[sot_token] + tokenizer.encode(desc) + [eot_token] for desc in text_descriptions]\n",
"text_input = torch.zeros(len(text_tokens), model.context_length, dtype=torch.long)\n",
"text_descriptions = [f\"This is a photo of a {label}\"\n",
" for label in cifar100.classes]\n",
"text_tokens = [[sot_token] + tokenizer.encode(desc) + [eot_token]\n",
" for desc in text_descriptions]\n",
"text_input = torch.zeros(len(text_tokens), model.context_length,\n",
" dtype=torch.long)\n",
"\n",
"for i, tokens in enumerate(text_tokens):\n",
" text_input[i, :len(tokens)] = torch.tensor(tokens)\n",
@ -835,17 +772,6 @@
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "0OENu-DQLzQY"
},
"source": [
""
],
"execution_count": 21,
"outputs": []
}
]
}