Add end to end steering + clip pipeline with styleGAN2
This commit is contained in:
parent
8a665a683d
commit
fbf0019f73
|
@ -0,0 +1,7 @@
|
|||
0. Create a Python virtual environment and install the dependencies needed to run CLIP and stylegan2-ada-pytorch in it.
|
||||
1. Install deps from stylegan2-ada-pytorch repo. Major ones are pytorch >= 1.7 and CUDA >= 11.0
|
||||
2. Download ffhq-pretrained stylegan2 model from the above repo.
|
||||
3. Using the virtual environment from above, run through generation_demo.ipynb - this notebook samples images from
|
||||
a styleGAN2 network and scores them using CLIP
|
||||
4. ganalyze_with_clip.py is the main code that runs the steering pipeline with a generative model and CLIP. Change output
|
||||
paths and model paths from within the code.
|
|
@ -0,0 +1,127 @@
|
|||
import gzip
|
||||
import html
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
|
||||
|
||||
@lru_cache()
def bytes_to_unicode():
    """
    Build the lookup table from byte values (0-255) to unicode characters.

    The reversible BPE codes operate on unicode strings, so every possible
    byte needs a stand-in character. Bytes that already correspond to a
    visible character ("!".."~", "¡".."¬", "®".."ÿ") map to themselves; every
    other byte is mapped to a code point starting at 256, which keeps
    whitespace/control characters (which the BPE code cannot handle) out of
    the table. The result is cached, so repeated calls return the same dict.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    # Bytes without a printable stand-in, in ascending order.
    remaining = [byte for byte in range(2 ** 8) if byte not in printable]
    byte_values = printable + remaining
    # The k-th remaining byte is assigned code point 256 + k.
    code_points = printable + [2 ** 8 + offset for offset in range(len(remaining))]
    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}
|
||||
|
||||
|
||||
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings); any sliceable sequence works. A word with fewer than two
    symbols, including an empty one, yields an empty set.
    """
    # Pairing the sequence with itself shifted by one gives every adjacent
    # pair, and naturally produces nothing for empty or one-symbol words.
    return set(zip(word, word[1:]))
|
||||
|
||||
|
||||
def basic_clean(text):
    """Run ftfy's text fixing, unescape HTML entities twice, and strip the ends."""
    fixed = ftfy.fix_text(text)
    # Unescaping is applied twice, so doubly-escaped entities are also decoded.
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
    """Collapse every run of whitespace to one space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
    """Byte-pair-encoding tokenizer driven by a gzipped merges file.

    Text is cleaned, lower-cased, split with a regular expression, mapped
    byte-by-byte to printable unicode characters, and then merged into BPE
    tokens whose integer ids come from ``self.encoder``.
    """

    def __init__(self, bpe_path: str = "../pretrained/bpe_simple_vocab_16e6.txt.gz"):
        """Load the merge list from ``bpe_path`` and build the lookup tables.

        Args:
            bpe_path: path to a gzip-compressed, UTF-8 text file holding one
                merge rule per line.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Context manager so the gzip handle is closed once the file is read.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # The first line of the file is skipped and a fixed number of merge
        # rules is kept (presumably vocabulary size minus the 256 byte symbols
        # and the 2 special tokens — TODO confirm).
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary order defines the token ids: byte symbols, byte symbols
        # with the end-of-word marker, merged symbols, then special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        vocab.extend(''.join(merge) for merge in merges)
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # The special tokens are pre-seeded so bpe() returns them unsplit.
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the learned merges to one token.

        Args:
            token: non-empty string of byte-encoded characters.

        Returns:
            The merged symbols joined by single spaces; the last symbol
            carries the ``</w>`` end-of-word marker. Results are memoised in
            ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Pick the adjacent pair with the lowest rank; unknown pairs rank
            # as infinity.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # tuple.index raises ValueError when `first` does not
                    # occur again: copy the tail unchanged and stop scanning.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean, lower-case and tokenize ``text`` into a list of token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Re-express the token's UTF-8 bytes with the printable stand-ins.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; ``</w>`` markers become spaces.

        Byte sequences that are not valid UTF-8 are decoded with
        ``errors="replace"``.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
|
|
@ -0,0 +1,73 @@
|
|||
import numpy as np
|
||||
from scipy.stats import truncnorm
|
||||
import PIL.ImageDraw
|
||||
import PIL.ImageFont
|
||||
|
||||
|
||||
def truncated_z_sample(batch_size, dim_z, truncation=1):
    """Sample a (batch_size, dim_z) array from a truncated normal.

    Values are drawn with ``truncnorm.rvs`` using bounds -2 and 2 and then
    multiplied by ``truncation``.
    """
    shape = (batch_size, dim_z)
    samples = truncnorm.rvs(-2, 2, size=shape)
    return truncation * samples
|
||||
|
||||
def imgrid(imarray, cols=5, pad=1):
    """Tile a batch of images into one grid image.

    Args:
        imarray: array unpacked as (N, H, W, C); cast to uint8 if it has
            another dtype.
        cols: number of images per grid row (must be >= 1).
        pad: width in pixels of the white (255) gap between tiles (>= 0).

    Returns:
        A uint8 array holding the tiles row by row; missing tiles in the last
        row are filled with 255.
    """
    images = imarray if imarray.dtype == np.uint8 else np.uint8(imarray)
    pad = int(pad)
    assert pad >= 0
    cols = int(cols)
    assert cols >= 1

    count, height, width, channels = images.shape
    rows = int(np.ceil(count / float(cols)))
    missing = rows * cols - count
    assert missing >= 0

    # Pad only at the end of each axis: blank tiles on the batch axis and a
    # `pad`-pixel border below and to the right of every image.
    padding = [[0, missing], [0, pad], [0, pad], [0, 0]]
    images = np.pad(images, padding, 'constant', constant_values=255)

    cell_h = height + pad
    cell_w = width + pad
    tiles = images.reshape(rows, cols, cell_h, cell_w, channels)
    # Bring each tile's rows next to the grid row index so the final reshape
    # lays tiles side by side.
    tiles = tiles.transpose(0, 2, 1, 3, 4)
    grid = tiles.reshape(rows * cell_h, cols * cell_w, channels)

    # Drop the trailing border on the bottom and right edges.
    return grid[:-pad, :-pad] if pad else grid
|
||||
|
||||
def annotate_outscore(array, outscore, font_path="/data/scratch/swamiviv/projects/stylegan2-ada-pytorch/clip_steering/arial.ttf"):
    """Draw each image's score in its top-left corner, modifying ``array`` in place.

    Args:
        array: 4-D image batch indexed as ``array[i, :, :, :]``; every image
            is overwritten with its annotated version.
        outscore: one score per image; singleton dimensions are squeezed away.
        font_path: TrueType font file used for the label. Defaults to the
            path that was previously hard-coded.

    Returns:
        The same ``array`` object, annotated.
    """
    # atleast_1d keeps a batch of one indexable: squeezing a single score
    # yields a 0-d array, which cannot be indexed with [i].
    scores = np.atleast_1d(np.squeeze(outscore))
    # The font depends only on the image size, so load it once rather than
    # on every loop iteration.
    font = PIL.ImageFont.truetype(font_path, int(array.shape[1]/8.5))
    for i in range(array.shape[0]):
        image = PIL.Image.fromarray(np.uint8(array[i, :, :, :]))
        draw = PIL.ImageDraw.Draw(image)
        message = str(round(scores[i], 2))
        x, y = (0, 0)
        if hasattr(font, "getsize"):
            w, h = font.getsize(message)
        else:
            # Newer Pillow releases no longer provide getsize; use the
            # right/bottom edges of the text bounding box instead.
            _, _, w, h = font.getbbox(message)
        # White background box so the black text stays readable on any image.
        draw.rectangle((x, y, x + w, y + h), fill='white')
        draw.text((x, y), message, fill="black", font=font)
        array[i, :, :, :] = np.array(image)
    return array
|
||||
|
||||
|
||||
class AverageMeter:
    """Keep the most recent value of a metric together with its running mean."""

    def __init__(self, name, fmt=':f'):
        # `fmt` is a format-spec suffix (e.g. ':f') applied to both the
        # current value and the average when the meter is printed.
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the current value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))
|
|
@ -0,0 +1,58 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import math
|
||||
import scipy.io as sio
|
||||
|
||||
class OneDirection(nn.Module):
    """Steering model that moves every latent along one learned direction."""

    def __init__(self, dim_z, vocab_size=1000, **kwargs):
        """
        Args:
            dim_z: dimensionality of the latent vectors.
            vocab_size: stored on the instance; not used by this model.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        print("\napproach: ", "one_direction\n")
        self.dim_z = dim_z
        self.vocab_size = vocab_size
        # One learnable direction shared by all samples.
        self.w = nn.Parameter(torch.randn(1, self.dim_z))
        self.criterion = nn.MSELoss()

    def transform(self, z, y, step_sizes, **kwargs):
        """Shift ``z`` by ``step_sizes * w`` and rescale to the norm of ``z``.

        Args:
            z: latent batch; its first dimension must match ``len(y)`` when
                ``y`` is given.
            y: optional per-sample labels; only their count is checked.
            step_sizes: tensor that broadcasts against ``w``.

        Returns:
            The shifted latents, rescaled so the norm of the whole result
            equals the norm of the whole input ``z``.
        """
        if y is not None:
            assert(len(y) == z.shape[0])

        interim = step_sizes * self.w

        z_transformed = z + interim
        # Both norms are taken over the entire tensor, not per sample.
        z_transformed = z.norm() * z_transformed / z_transformed.norm()

        return z_transformed

    def compute_loss(self, current, target, batch_start, lossfile):
        """Compute the MSE loss and append it to the CSV-style ``lossfile``.

        Two rows are appended: ``<batch_start>,mse_loss,<value>`` and
        ``<batch_start>,overall_loss,<value>``.

        Returns:
            The loss tensor, still attached to the autograd graph.
        """
        loss = self.criterion(current, target)
        # Log the plain number: str(tensor) yields a "tensor(...)" repr that
        # can itself contain a comma and break the comma-separated rows.
        loss_value = loss.item()
        with open(lossfile, 'a') as file:
            file.write(f"{batch_start},mse_loss,{loss_value}\n")
            file.write(f"{batch_start},overall_loss,{loss_value}\n")
        return loss
|
||||
|
||||
class ClassDependent(nn.Module):
    """Steering model whose direction is a linear function of the label ``y``."""

    def __init__(self, dim_z, vocab_size=1000, **kwargs):
        """
        Args:
            dim_z: dimensionality of the latent vectors.
            vocab_size: size of the label vectors fed to the linear layer.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        print("\napproach: ", "class_dependent\n")
        self.dim_z = dim_z
        self.vocab_size = vocab_size
        # Maps a label vector of length vocab_size to a latent direction.
        self.NN_output = nn.Linear(self.vocab_size, self.dim_z)
        self.criterion = nn.MSELoss()

    def transform(self, z, y, step_sizes, **kwargs):
        """Shift ``z`` by ``step_sizes * NN_output(y)`` and rescale to the norm of ``z``.

        Args:
            z: latent batch.
            y: label tensor accepted by the linear layer; must not be None.
            step_sizes: tensor that broadcasts against the predicted direction.

        Returns:
            The shifted latents, rescaled so the norm of the whole result
            equals the norm of the whole input ``z``.
        """
        assert (y is not None)
        interim = step_sizes * self.NN_output(y)
        z_transformed = z + interim
        # Both norms are taken over the entire tensor, not per sample.
        z_transformed = z.norm() * z_transformed / z_transformed.norm()
        return z_transformed

    def compute_loss(self, current, target, batch_start, lossfile):
        """Compute the MSE loss and append it to the CSV-style ``lossfile``.

        Two rows are appended: ``<batch_start>,mse_loss,<value>`` and
        ``<batch_start>,overall_loss,<value>``.

        Returns:
            The loss tensor, still attached to the autograd graph.
        """
        loss = self.criterion(current, target)
        # Log the plain number: str(tensor) yields a "tensor(...)" repr that
        # can itself contain a comma and break the comma-separated rows.
        loss_value = loss.item()
        with open(lossfile, 'a') as file:
            file.write(f"{batch_start},mse_loss,{loss_value}\n")
            file.write(f"{batch_start},overall_loss,{loss_value}\n")
        return loss
|
|
@ -0,0 +1,202 @@
|
|||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as torch_transforms
|
||||
import ganalyze_transformations as transformations
|
||||
import ganalyze_common_utils as common
|
||||
import pickle
|
||||
import os
|
||||
import pathlib
|
||||
sys.path.append(os.path.abspath(os.getcwd()))
|
||||
sys.path.append('/data/scratch/swamiviv/projects/stylegan2-ada-pytorch/')
|
||||
from clip_classifier_utils import SimpleTokenizer
|
||||
import logging
|
||||
# Timestamped, INFO-level logging for the whole script.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)-8s %(message)s',
)

logger = logging.getLogger(__name__)
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
def gan_output_transform(imgs):
    """Rescale generator output to float pixel values in [0, 255].

    Each value is mapped through ``x * 127.5 + 128`` and clamped to
    [0, 255]. The tensor keeps its shape and layout; only the values and the
    dtype (cast to float) change.
    NOTE(review): presumably ``imgs`` is the NCHW generator output in roughly
    [-1, 1] — confirm against the generator.
    """
    scaled = imgs * 127.5 + 128
    return scaled.clamp(0, 255).float()
|
||||
|
||||
|
||||
def clip_input_transform(images, max_pixel_value=255.0):
    """Resize, crop and normalise an image batch for the CLIP image encoder.

    Args:
        images: float tensor laid out as (..., C, H, W) with 3 channels, as
            required by the torchvision tensor transforms used here, holding
            pixel values in [0, ``max_pixel_value``].
        max_pixel_value: upper end of the input pixel range. The default
            matches the [0, 255] output of ``gan_output_transform``; pass 1.0
            for images that are already in [0, 1].

    Returns:
        Tensor of shape (..., 3, 224, 224), normalised per channel.
    """
    # These statistics are expressed for pixels in [0, 1], so the input is
    # rescaled to that range before normalising; feeding [0, 255] values
    # straight into Normalize produces inputs far outside the expected range.
    image_mean = (0.48145466, 0.4578275, 0.40821073)
    image_std = (0.26862954, 0.26130258, 0.27577711)

    transform = torch.nn.Sequential(
        torch_transforms.Resize((256, 256)),
        torch_transforms.CenterCrop((224, 224)),
        torch_transforms.Normalize(image_mean, image_std),
    )

    return transform(images / max_pixel_value)
|
||||
|
||||
def get_clip_scores(image_inputs, encoded_text, model, class_index=0):
    """Return the image/text similarity for one text prompt.

    Images are preprocessed, encoded by ``model`` and L2-normalised; the
    result is the dot product with row ``class_index`` of ``encoded_text``,
    one value per image.
    """
    #TODO: clarify class index
    clip_ready = clip_input_transform(image_inputs).to(device)
    image_feats = F.normalize(model.encode_image(clip_ready).float(), p=2, dim=-1)

    # Rows are images, columns are text prompts.
    all_scores = torch.matmul(image_feats, torch.transpose(encoded_text, 0, 1))
    all_scores = all_scores.to(device)
    selected = all_scores.narrow(dim=-1, start=class_index, length=1)
    return selected.squeeze(dim=-1)
|
||||
|
||||
def get_clip_probs(image_inputs, encoded_text, model, class_index=0):
    """Return the softmax probability of one text prompt for each image.

    Images are preprocessed, encoded by ``model`` and L2-normalised; the
    similarities to all rows of ``encoded_text`` are scaled by 100 and passed
    through a softmax over the prompts, and column ``class_index`` is
    returned.
    """
    clip_ready = clip_input_transform(image_inputs).to(device)
    image_feats = F.normalize(model.encode_image(clip_ready).float(), p=2, dim=-1)

    # Rows are images, columns are text prompts.
    logits = 100.0 * torch.matmul(image_feats, torch.transpose(encoded_text, 0, 1))
    clip_probs = logits.softmax(dim=-1).to(device)

    selected = clip_probs.narrow(dim=-1, start=class_index, length=1)
    return selected.squeeze(dim=-1)
|
||||
|
||||
# Set up GAN
gan_model_path = '../pretrained/ffhq.pkl'
# Load the generator stored under the 'G_ema' key and put it in eval mode.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load checkpoints obtained from a trusted source.
with open(gan_model_path, 'rb') as f:
    G = pickle.load(f)['G_ema']
G = G.eval().to(device)
latent_space_dim = G.z_dim

# Set up clip classifier (a TorchScript archive)
clip_model_path = '../pretrained/clip_ViT-B-32.pt'
clip_model = torch.jit.load(clip_model_path).eval().to(device)
input_resolution = clip_model.input_resolution.item()
context_length = clip_model.context_length.item()
vocab_size = clip_model.vocab_size.item()

# Report the size and key hyper-parameters of the loaded CLIP model.
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in clip_model.parameters()]):,}")
for label, value in (
    ("Input resolution:", input_resolution),
    ("Context length:", context_length),
    ("Vocab size:", vocab_size),
):
    print(label, value)
|
||||
|
||||
# Extract text features for clip
attributes = ["an evil face", "a radiant face", "a criminal face", "a beautiful face", "a handsome face", "a smart face"]
class_index = 2 #which attribute do we want to maximize
tokenizer = SimpleTokenizer()
sot_token = tokenizer.encoder['<|startoftext|>']
eot_token = tokenizer.encoder['<|endoftext|>']
text_descriptions = [f"This is a photo of {label}" for label in attributes]
# Every prompt is wrapped in the start/end-of-text tokens.
text_tokens = [[sot_token] + tokenizer.encode(desc) + [eot_token] for desc in text_descriptions]
# Zero-padded id matrix: one row per prompt, context_length columns.
text_inputs = torch.zeros(len(text_tokens), clip_model.context_length, dtype=torch.long)

for i, tokens in enumerate(text_tokens):
    text_inputs[i, :len(tokens)] = torch.tensor(tokens)

# These are held constant through the optimization, akin to labels
text_inputs = text_inputs.to(device)
with torch.no_grad():
    text_features = clip_model.encode_text(text_inputs).float()
    text_features = F.normalize(text_features, p=2, dim=-1)
# Tensor.to returns a new tensor rather than moving in place, so the result
# must be assigned; the original statement discarded it.
text_features = text_features.to(device)
|
||||
|
||||
# Setting up Transformer
# --------------------------------------------------------------------------------------------------------------
# First entry: name of a class in ganalyze_transformations. Second entry:
# either the string "None" or a "key=value,key=value" list of extra
# constructor arguments (values stay strings).
transformer_params = ['OneDirection', 'None']
transformer, transformer_arguments = transformer_params
if transformer_arguments == "None":
    transformer_arguments = {}
else:
    key_value_pairs = [pair.split("=") for pair in transformer_arguments.split(",")]
    transformer_arguments = {pair[0]: pair[1] for pair in key_value_pairs}

transformation_class = getattr(transformations, transformer)
transformation = transformation_class(latent_space_dim, vocab_size, **transformer_arguments)
transformation = transformation.to(device)

# function that is used to score the (attribute, image) pair
scoring_fun = get_clip_probs
|
||||
|
||||
|
||||
# Training
# --------------------------------------------------------------------------------------------------------------
# optimizer: only the steering transformation is trained; G and CLIP are
# never passed to it.
optimizer = torch.optim.Adam(transformation.parameters(), lr=0.0002)
losses = common.AverageMeter(name='Loss')

# training settings
optim_iter = 0
batch_size = 6
train_alpha_a = -0.5  # Lower limit for step sizes
train_alpha_b = 0.5  # Upper limit for step sizes
num_samples = 400000  # Number of samples to train for

# create training set
#np.random.seed(seed=0)
truncation = 1
zs = common.truncated_z_sample(num_samples, latent_space_dim, truncation)

checkpoint_dir = f'/data/scratch/swamiviv/projects/stylegan2-ada-pytorch/clip_steering/results_maximize_{attributes[class_index]}_probability'
pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

# loop over data batches
for batch_start in range(0, num_samples, batch_size):

    # input batch
    s = slice(batch_start, min(num_samples, batch_start + batch_size))
    z = torch.from_numpy(zs[s]).type(torch.FloatTensor).to(device)
    # The final batch is shorter when num_samples is not a multiple of
    # batch_size, so everything below is sized from the actual batch.
    current_batch_size = z.shape[0]
    y = None
    step_sizes = (train_alpha_b - train_alpha_a) * \
        np.random.random(size=(current_batch_size)) + train_alpha_a  # sample step_sizes
    step_sizes_broadcast = np.repeat(step_sizes, latent_space_dim).reshape([current_batch_size, latent_space_dim])
    step_sizes_broadcast = torch.from_numpy(step_sizes_broadcast).type(torch.FloatTensor).to(device)

    # ganalyze steps
    # The targets do not depend on the trainable direction, so they are
    # computed without building an autograd graph through G and CLIP.
    with torch.no_grad():
        gan_images = G(z, None)
        gan_images = gan_output_transform(gan_images)
        out_scores = scoring_fun(
            image_inputs=gan_images, encoded_text=text_features, model=clip_model, class_index=class_index,
        )
        # TODO: ignore z vectors with less confident clip scores
        target_scores = out_scores + torch.from_numpy(step_sizes).to(device).float()

    # Clear gradients left over from the previous iteration; without this,
    # backward() keeps adding to them and every step uses a running sum.
    optimizer.zero_grad()

    z_transformed = transformation.transform(z, None, step_sizes_broadcast)
    gan_images_transformed = G(z_transformed, None)
    gan_images_transformed = gan_output_transform(gan_images_transformed).to(device)
    out_scores_transformed = scoring_fun(
        image_inputs=gan_images_transformed, encoded_text=text_features, model=clip_model, class_index=class_index,
    ).to(device).float()

    # compute loss
    loss = transformation.criterion(out_scores_transformed, target_scores)

    # backwards
    loss.backward()
    optimizer.step()

    # print loss
    losses.update(loss.item(), current_batch_size)
    if optim_iter % 100 == 0:
        logger.info(f'[Maximizing score for {attributes[class_index]}] Progress: [{batch_start}/{num_samples}] {losses}')

    if optim_iter % 500 == 0:
        logger.info(f"saving checkpoint at iteration {optim_iter}")
        torch.save(transformation.state_dict(), os.path.join(checkpoint_dir, "pytorch_model_{}.pth".format(batch_start)))
    optim_iter = optim_iter + 1
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue