import numpy as np
from scipy.stats import truncnorm
import PIL.ImageDraw
import PIL.ImageFont
def truncated_z_sample(batch_size, dim_z, truncation=1, seed=None):
    """Draw latent vectors from a standard normal truncated to [-2, 2], then scale them.

    Args:
        batch_size: number of latent vectors (rows) to draw.
        dim_z: dimensionality of each latent vector (columns).
        truncation: scalar multiplier applied after sampling, so the returned
            values lie in [-2 * truncation, 2 * truncation].
        seed: optional integer seed for reproducible samples. ``None`` (the
            default) keeps the previous behaviour of drawing from NumPy's
            global random state.

    Returns:
        A float ndarray of shape ``(batch_size, dim_z)``.
    """
    # random_state=None is scipy's default (global NumPy RNG), so omitting
    # `seed` behaves exactly like the original call.
    random_state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=random_state)
    return truncation * values
def imgrid(imarray, cols=5, pad=1):
    """Tile a batch of images into a single grid image.

    Args:
        imarray: array of shape (N, H, W, C), or (N, H, W) for single-channel
            images. Non-uint8 input is clipped to [0, 255] and then cast to
            uint8 (fractional parts are truncated).
        cols: number of tiles per grid row.
        pad: width in pixels of the white (255) gutter between tiles.

    Returns:
        uint8 array of shape (rows*(H+pad)-pad, cols*(W+pad)-pad, C), where
        rows = ceil(N / cols). For (N, H, W) input the channel axis is dropped
        again. Unused tiles at the end of the last row are filled with white.

    Raises:
        ValueError: if pad < 0, cols < 1, or imarray is not 3-D or 4-D.
    """
    imarray = np.asarray(imarray)
    if imarray.dtype != np.uint8:
        # Clip before casting: converting out-of-range values straight to
        # uint8 wraps around (ints) or is platform-dependent (floats).
        # Going through float64 keeps the clip bounds representable for any
        # input dtype, and values in [0, 255] survive the round trip exactly.
        imarray = np.clip(imarray.astype(np.float64), 0, 255).astype(np.uint8)

    pad = int(pad)
    if pad < 0:
        raise ValueError(f"imgrid: pad must be >= 0, got {pad}")
    cols = int(cols)
    if cols < 1:
        raise ValueError(f"imgrid: cols must be >= 1, got {cols}")

    squeeze_channel = imarray.ndim == 3
    if squeeze_channel:
        imarray = imarray[..., np.newaxis]
    if imarray.ndim != 4:
        raise ValueError(
            f"imgrid: expected a (N, H, W[, C]) array, got shape {imarray.shape}"
        )

    n, h, w, c = imarray.shape
    rows = (n + cols - 1) // cols  # ceil(n / cols) in integer arithmetic
    batch_pad = rows * cols - n

    # Pad only on the trailing side of each axis: blank tiles on the batch
    # axis and a `pad`-pixel gutter below / to the right of every tile.
    pad_width = [(0, batch_pad), (0, pad), (0, pad), (0, 0)]
    imarray = np.pad(imarray, pad_width, mode='constant', constant_values=255)

    tile_h, tile_w = h + pad, w + pad
    grid = (imarray
            .reshape(rows, cols, tile_h, tile_w, c)
            .transpose(0, 2, 1, 3, 4)
            .reshape(rows * tile_h, cols * tile_w, c))
    if pad:
        # Drop the gutter trailing the last tile row and column.
        grid = grid[:-pad, :-pad]
    if squeeze_channel:
        grid = grid[..., 0]
    return grid
def annotate_outscore(
    array,
    outscore,
    font_path="/data/scratch/swamiviv/projects/stylegan2-ada-pytorch/clip_steering/arial.ttf",
):
    """Stamp each image's score, rounded to 2 decimals, in its top-left corner.

    For image ``i`` the text ``str(round(score_i, 2))`` is drawn in black on
    a white box anchored at (0, 0). ``array`` is modified in place.

    Args:
        array: batch of images indexed as ``array[i, :, :, :]``; presumably
            (N, H, W, 3) RGB with values in [0, 255] — TODO confirm with
            callers. Must be writable, since it is updated in place.
        outscore: one score per image; any shape that flattens to length N
            (e.g. (N,), (N, 1), or a scalar when N == 1).
        font_path: TrueType font file to render with. If it cannot be opened,
            a warning is issued and Pillow's built-in default font is used.

    Returns:
        ``array`` itself (also modified in place), for convenience.
    """
    import warnings

    import PIL.Image
    import PIL.ImageDraw
    import PIL.ImageFont

    # Font size scales with array.shape[1] (the image height for an NHWC
    # batch); clamp to >= 1 so tiny images don't request a zero-size font.
    font_size = max(1, int(array.shape[1] / 8.5))
    try:
        font = PIL.ImageFont.truetype(font_path, font_size)
    except OSError:
        warnings.warn(
            f"annotate_outscore: could not load font {font_path!r}; "
            "falling back to Pillow's default font"
        )
        font = PIL.ImageFont.load_default()

    # ravel (not squeeze) so a batch of one still gives an indexable 1-D array.
    scores = np.ravel(outscore)
    x, y = 0, 0
    for i in range(array.shape[0]):
        image = PIL.Image.fromarray(np.uint8(array[i, :, :, :]))
        draw = PIL.ImageDraw.Draw(image)
        message = str(round(scores[i], 2))
        try:
            # Pillow >= 8: bbox measured from the origin, so right/bottom give
            # the same extent the removed getsize() used to report.
            _, _, w, h = draw.textbbox((x, y), message, font=font)
        except (AttributeError, ValueError):
            # Older Pillow (no textbbox, or bitmap fonts unsupported by it).
            w, h = font.getsize(message)
        draw.rectangle((x, y, x + w, y + h), fill='white')
        draw.text((x, y), message, fill="black", font=font)
        array[i, :, :, :] = np.array(image)
    return array
class AverageMeter(object):
    """Computes and stores the average and current value.

    Tracks the most recent value, a running weighted sum and count, and their
    average. ``str(meter)`` renders ``"<name> <val> (<avg>)"`` using ``fmt``
    (e.g. ``':.3f'``) as the format spec for both numbers.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        # Initialise val/avg/sum/count; update() and __str__ rely on them.
        self.reset()

    def reset(self):
        """Zero the current value and all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples.

        ``val`` is multiplied by ``n`` before being added to the running sum,
        so pass a per-sample mean together with its sample count.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard so update(..., n=0) before any real sample can't divide by zero.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)