73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
import numpy as np
|
|
from scipy.stats import truncnorm
|
|
import PIL.ImageDraw
|
|
import PIL.ImageFont
|
|
|
|
|
|
def truncated_z_sample(batch_size, dim_z, truncation=1, seed=None):
    """Draw latent vectors from a truncated standard normal distribution.

    Samples come from a standard normal clipped to ``[-2, 2]`` and are then
    multiplied by ``truncation``, so every returned value lies in
    ``[-2 * truncation, 2 * truncation]``.

    Args:
        batch_size: Number of latent vectors (rows) to draw.
        dim_z: Dimensionality of each latent vector (columns).
        truncation: Scale factor applied to the clipped samples.
        seed: Optional integer seed. When given, a private
            ``np.random.RandomState(seed)`` drives the sampling so the result
            is reproducible and NumPy's global RNG is left untouched. When
            ``None`` (the default) the global RNG is used.

    Returns:
        A float array of shape ``(batch_size, dim_z)``.
    """
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)
    return truncation * values
|
|
|
|
def imgrid(imarray, cols=5, pad=1):
    """Tile a batch of images into one grid image.

    The input's shape is unpacked as ``(N, H, W, C)``. Images are laid out
    row by row, ``cols`` per row. Cells missing from the last row and the
    ``pad``-pixel gutters between cells are filled with 255.

    Input whose dtype is not uint8 is cast with ``np.uint8`` (a plain cast,
    no rescaling) rather than rejected.

    Args:
        imarray: Array whose shape unpacks into four dimensions.
        cols: Number of images per grid row; must be >= 1 after ``int()``.
        pad: Gutter width in pixels; must be >= 0 after ``int()``.

    Returns:
        A uint8 array with ``C`` channels. Gutters appear only between
        cells: the trailing gutter on the bottom and right edges is removed.
    """
    if imarray.dtype != np.uint8:
        imarray = np.uint8(imarray)

    pad = int(pad)
    assert pad >= 0
    cols = int(cols)
    assert cols >= 1

    count, height, width, channels = imarray.shape

    # Integer ceil-division: rows needed to hold `count` images.
    rows = -(-count // cols)
    missing = rows * cols - count
    assert missing >= 0

    # Pad only at the end of each axis: blank images to complete the last
    # row, then `pad` pixels below and to the right of every image.
    padding = ((0, missing), (0, pad), (0, pad), (0, 0))
    padded = np.pad(imarray, padding, 'constant', constant_values=255)

    cell_h = height + pad
    cell_w = width + pad

    # (rows, cols, h, w, c) -> (rows, h, cols, w, c), so that collapsing the
    # first two and the next two axes yields a 2-D mosaic of cells.
    cells = padded.reshape(rows, cols, cell_h, cell_w, channels)
    cells = np.swapaxes(cells, 1, 2)
    grid = cells.reshape(rows * cell_h, cols * cell_w, channels)

    # Drop the gutter that trails the last row and the last column.
    if pad > 0:
        grid = grid[:-pad, :-pad]
    return grid
|
|
|
|
def annotate_outscore(
    array,
    outscore,
    font_path="/data/scratch/swamiviv/projects/stylegan2-ada-pytorch/clip_steering/arial.ttf",
):
    """Stamp each image of a batch with its score, rounded to two decimals.

    The score is drawn in black on a white box in the top-left corner of the
    image. ``array`` is modified in place and also returned.

    Args:
        array: Batch of images indexed as ``array[i, :, :, :]``; each image
            is cast with ``np.uint8`` before being handed to Pillow.
        outscore: One score per image. Singleton dimensions are squeezed
            away, so shapes such as ``(N,)`` and ``(N, 1)`` both work,
            including ``N == 1``.
        font_path: TrueType font file used for the label. Defaults to the
            path that was previously hard-coded here.

    Returns:
        The same ``array`` object, with labels drawn onto every image.

    Raises:
        OSError: If the font file cannot be opened (raised by Pillow).
    """
    num_images = array.shape[0]
    if num_images == 0:
        # Nothing to draw; also avoids opening the font file needlessly.
        return array

    # np.squeeze alone turns a single score into a 0-d array that cannot be
    # indexed; atleast_1d keeps the batch-of-one case working.
    scores = np.atleast_1d(np.squeeze(outscore))

    # Font size scales with array.shape[1]; clamp so tiny images still get
    # a valid (>= 1) size. Loaded once instead of once per image.
    font_size = max(1, int(array.shape[1] / 8.5))
    font = PIL.ImageFont.truetype(font_path, font_size)

    x, y = (0, 0)
    for i in range(num_images):
        image = PIL.Image.fromarray(np.uint8(array[i, :, :, :]))
        draw = PIL.ImageDraw.Draw(image)
        message = str(round(scores[i], 2))

        # font.getsize() was removed in Pillow 10; textbbox() is its
        # replacement. Fall back to getsize() on older Pillow releases.
        if hasattr(draw, "textbbox"):
            _, _, right, bottom = draw.textbbox((x, y), message, font=font)
        else:
            w, h = font.getsize(message)
            right, bottom = x + w, y + h

        draw.rectangle((x, y, right, bottom), fill='white')
        draw.text((x, y), message, fill="black", font=font)
        array[i, :, :, :] = np.array(image)
    return array
|
|
|
|
|
|
class AverageMeter(object):
    """Computes and stores the average and current value.

    Public attributes, all reset to 0 by ``reset()``:
        val:   the value passed to the most recent ``update`` call
        sum:   running total of ``val * n`` over all updates
        count: running total of ``n`` over all updates
        avg:   ``sum / count`` (0 while ``count`` is 0)
    """

    def __init__(self, name, fmt=':f'):
        """Create a meter.

        Args:
            name: Label printed by ``__str__``.
            fmt: Format spec, including the leading colon (e.g. ``':.3f'``),
                applied to both ``val`` and ``avg`` when printing.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: The new measurement.
            n: Weight of the measurement (e.g. the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n == 0 on a fresh or reset meter leaves count at 0;
        # keep avg at its reset value instead of dividing by zero.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` formatted with ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)