remove unused/repeated imports

This commit is contained in:
Aleksey Morozov 2021-01-24 12:03:04 +02:00
parent 6bc0bd8873
commit 90c1b12d63

View File

@ -36,73 +36,6 @@
"Make sure you're running a GPU runtime; if not, select \"GPU\" as the hardware accelerator in Runtime > Change Runtime Type in the menu. The next cells will print the CUDA version of the runtime if it has a GPU, and install PyTorch 1.7.1." "Make sure you're running a GPU runtime; if not, select \"GPU\" as the hardware accelerator in Runtime > Change Runtime Type in the menu. The next cells will print the CUDA version of the runtime if it has a GPU, and install PyTorch 1.7.1."
] ]
}, },
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0BpdJkdBssk9",
"outputId": "f26d1899-1e21-427a-f730-9451be9a2572"
},
"source": [
"import subprocess\n",
"\n",
"CUDA_version = [s for s in subprocess.check_output([\"nvcc\", \"--version\"]).decode(\"UTF-8\").split(\", \") if s.startswith(\"release\")][0].split(\" \")[-1]\n",
"print(\"CUDA version:\", CUDA_version)\n",
"\n",
"if CUDA_version == \"10.0\":\n",
" torch_version_suffix = \"+cu100\"\n",
"elif CUDA_version == \"10.1\":\n",
" torch_version_suffix = \"+cu101\"\n",
"elif CUDA_version == \"10.2\":\n",
" torch_version_suffix = \"\"\n",
"else:\n",
" torch_version_suffix = \"+cu110\""
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"CUDA version: 10.1\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RBVr18E5tse8",
"outputId": "d4647553-46fa-4cc5-8a41-b152abc0b4d2"
},
"source": [
"! pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Looking in links: https://download.pytorch.org/whl/torch_stable.html\n",
"Requirement already satisfied: torch==1.7.1+cu101 in /usr/local/lib/python3.6/dist-packages (1.7.1+cu101)\n",
"Requirement already satisfied: torchvision==0.8.2+cu101 in /usr/local/lib/python3.6/dist-packages (0.8.2+cu101)\n",
"Requirement already satisfied: ftfy in /usr/local/lib/python3.6/dist-packages (5.8)\n",
"Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (2019.12.20)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (1.19.4)\n",
"Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (0.8)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (3.7.4.3)\n",
"Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.8.2+cu101) (7.0.0)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from ftfy) (0.2.5)\n"
],
"name": "stdout"
}
]
},
{ {
"cell_type": "code", "cell_type": "code",
"metadata": { "metadata": {
@ -238,7 +171,7 @@
"id": "d6cpiIFHp9N6" "id": "d6cpiIFHp9N6"
}, },
"source": [ "source": [
"from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\n", "from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor\n",
"from PIL import Image\n", "from PIL import Image\n",
"\n", "\n",
"preprocess = Compose([\n", "preprocess = Compose([\n",
@ -343,8 +276,9 @@
"\n", "\n",
"\n", "\n",
"def get_pairs(word):\n", "def get_pairs(word):\n",
" \"\"\"Return set of symbol pairs in a word.\n", " \"\"\"\n",
" Word is represented as tuple of symbols (symbols being variable-length strings).\n", " Return set of symbol pairs in a word. Word is represented as tuple of\n",
" symbols (symbols being variable-length strings).\n",
" \"\"\"\n", " \"\"\"\n",
" pairs = set()\n", " pairs = set()\n",
" prev_char = word[0]\n", " prev_char = word[0]\n",
@ -381,7 +315,8 @@
" self.encoder = dict(zip(vocab, range(len(vocab))))\n", " self.encoder = dict(zip(vocab, range(len(vocab))))\n",
" self.decoder = {v: k for k, v in self.encoder.items()}\n", " self.decoder = {v: k for k, v in self.encoder.items()}\n",
" self.bpe_ranks = dict(zip(merges, range(len(merges))))\n", " self.bpe_ranks = dict(zip(merges, range(len(merges))))\n",
" self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}\n", " self.cache = {'<|startoftext|>': '<|startoftext|>',\n",
" '<|endoftext|>': '<|endoftext|>'}\n",
" self.pat = re.compile(r\"\"\"<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\", re.IGNORECASE)\n", " self.pat = re.compile(r\"\"\"<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\", re.IGNORECASE)\n",
"\n", "\n",
" def bpe(self, token):\n", " def bpe(self, token):\n",
@ -405,7 +340,7 @@
" j = word.index(first, i)\n", " j = word.index(first, i)\n",
" new_word.extend(word[i:j])\n", " new_word.extend(word[i:j])\n",
" i = j\n", " i = j\n",
" except:\n", " except ValueError:\n",
" new_word.extend(word[i:])\n", " new_word.extend(word[i:])\n",
" break\n", " break\n",
"\n", "\n",
@ -429,14 +364,16 @@
" bpe_tokens = []\n", " bpe_tokens = []\n",
" text = whitespace_clean(basic_clean(text)).lower()\n", " text = whitespace_clean(basic_clean(text)).lower()\n",
" for token in re.findall(self.pat, text):\n", " for token in re.findall(self.pat, text):\n",
" token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n", " token = ''.join(self.byte_encoder[b]\n",
" bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\n", " for b in token.encode('utf-8'))\n",
" bpe_tokens.extend(self.encoder[bpe_token]\n",
" for bpe_token in self.bpe(token).split(' '))\n",
" return bpe_tokens\n", " return bpe_tokens\n",
"\n", "\n",
" def decode(self, tokens):\n", " def decode(self, tokens):\n",
" text = ''.join([self.decoder[token] for token in tokens])\n", " text = ''.join([self.decoder[token] for token in tokens])\n",
" text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('</w>', ' ')\n", " text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('</w>', ' ')\n",
" return text\n" " return text"
], ],
"execution_count": 9, "execution_count": 9,
"outputs": [] "outputs": []
@ -460,15 +397,8 @@
"id": "tMc1AXzBlhzm" "id": "tMc1AXzBlhzm"
}, },
"source": [ "source": [
"import os\n",
"import skimage\n", "import skimage\n",
"import IPython.display\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"from collections import OrderedDict\n",
"import torch\n",
"\n", "\n",
"%matplotlib inline\n", "%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina'\n", "%config InlineBackend.figure_format = 'retina'\n",
@ -518,7 +448,7 @@
" plt.xticks([])\n", " plt.xticks([])\n",
" plt.yticks([])\n", " plt.yticks([])\n",
"\n", "\n",
"plt.tight_layout()\n" "plt.tight_layout()"
], ],
"execution_count": 11, "execution_count": 11,
"outputs": [ "outputs": [
@ -582,7 +512,8 @@
"id": "w1l_muuhZ_Nk" "id": "w1l_muuhZ_Nk"
}, },
"source": [ "source": [
"text_input = torch.zeros(len(text_tokens), model.context_length, dtype=torch.long)\n", "text_input = torch.zeros(len(text_tokens), model.context_length,\n",
" dtype=torch.long)\n",
"sot_token = tokenizer.encoder['<|startoftext|>']\n", "sot_token = tokenizer.encoder['<|startoftext|>']\n",
"eot_token = tokenizer.encoder['<|endoftext|>']\n", "eot_token = tokenizer.encoder['<|endoftext|>']\n",
"\n", "\n",
@ -651,10 +582,12 @@
"plt.yticks(range(count), texts, fontsize=18)\n", "plt.yticks(range(count), texts, fontsize=18)\n",
"plt.xticks([])\n", "plt.xticks([])\n",
"for i, image in enumerate(images):\n", "for i, image in enumerate(images):\n",
" plt.imshow(image.permute(1, 2, 0), extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin=\"lower\")\n", " plt.imshow(image.permute(1, 2, 0), extent=(i - 0.5, i + 0.5, -1.6, -0.6),\n",
" origin=\"lower\")\n",
"for x in range(similarity.shape[1]):\n", "for x in range(similarity.shape[1]):\n",
" for y in range(similarity.shape[0]):\n", " for y in range(similarity.shape[0]):\n",
" plt.text(x, y, f\"{similarity[y, x]:.2f}\", ha=\"center\", va=\"center\", size=12)\n", " plt.text(x, y, f\"{similarity[y, x]:.2f}\", ha=\"center\", va=\"center\",\n",
" size=12)\n",
"\n", "\n",
"for side in [\"left\", \"top\", \"right\", \"bottom\"]:\n", "for side in [\"left\", \"top\", \"right\", \"bottom\"]:\n",
" plt.gca().spines[side].set_visible(False)\n", " plt.gca().spines[side].set_visible(False)\n",
@ -720,7 +653,8 @@
"source": [ "source": [
"from torchvision.datasets import CIFAR100\n", "from torchvision.datasets import CIFAR100\n",
"\n", "\n",
"cifar100 = CIFAR100(os.path.expanduser(\"~/.cache\"), transform=preprocess, download=True)" "cifar100 = CIFAR100(os.path.expanduser(\"~/.cache\"), transform=preprocess,\n",
" download=True)"
], ],
"execution_count": 18, "execution_count": 18,
"outputs": [ "outputs": [
@ -743,9 +677,12 @@
"outputId": "6e0c62da-e7ee-44f0-8e20-6d5d0020437d" "outputId": "6e0c62da-e7ee-44f0-8e20-6d5d0020437d"
}, },
"source": [ "source": [
"text_descriptions = [f\"This is a photo of a {label}\" for label in cifar100.classes]\n", "text_descriptions = [f\"This is a photo of a {label}\"\n",
"text_tokens = [[sot_token] + tokenizer.encode(desc) + [eot_token] for desc in text_descriptions]\n", " for label in cifar100.classes]\n",
"text_input = torch.zeros(len(text_tokens), model.context_length, dtype=torch.long)\n", "text_tokens = [[sot_token] + tokenizer.encode(desc) + [eot_token]\n",
" for desc in text_descriptions]\n",
"text_input = torch.zeros(len(text_tokens), model.context_length,\n",
" dtype=torch.long)\n",
"\n", "\n",
"for i, tokens in enumerate(text_tokens):\n", "for i, tokens in enumerate(text_tokens):\n",
" text_input[i, :len(tokens)] = torch.tensor(tokens)\n", " text_input[i, :len(tokens)] = torch.tensor(tokens)\n",
@ -835,17 +772,6 @@
} }
} }
] ]
},
{
"cell_type": "code",
"metadata": {
"id": "0OENu-DQLzQY"
},
"source": [
""
],
"execution_count": 21,
"outputs": []
} }
] ]
} }