58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
import math
|
|
import scipy.io as sio
|
|
|
|
class OneDirection(nn.Module):
    """Latent transformation that moves every sample along one learned direction.

    A single trainable vector ``self.w`` of shape ``(1, dim_z)`` is scaled by
    ``step_sizes`` and added to the latents. The result is then rescaled to the
    norm of the input.
    """

    def __init__(self, dim_z, vocab_size=1000, **kwargs):
        """Create the learned direction and the MSE criterion.

        Args:
            dim_z: Dimensionality of the latent space (length of ``self.w``).
            vocab_size: Stored on the instance but not used by this approach;
                kept so the constructor signature matches ``ClassDependent``.
            **kwargs: Accepted for interface compatibility; unused.
        """
        super().__init__()
        print("\napproach: ", "one_direction\n")

        self.dim_z = dim_z
        self.vocab_size = vocab_size
        # Single trainable direction, randomly initialised, shape (1, dim_z).
        self.w = nn.Parameter(torch.randn(1, self.dim_z))
        self.criterion = nn.MSELoss()

    def transform(self, z, y, step_sizes, **kwargs):
        """Shift ``z`` along ``self.w`` and rescale to the original norm.

        Args:
            z: Latent batch. Its first dimension is treated as the batch size
                (presumably shape ``(batch, dim_z)`` -- TODO confirm).
            y: Optional labels. They are only used to check that their length
                matches ``z.shape[0]``; otherwise they are ignored.
            step_sizes: Multiplier(s) for ``self.w``. They must broadcast
                against shape ``(1, dim_z)`` (presumably a scalar or a
                ``(batch, 1)`` tensor -- TODO confirm).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``z + step_sizes * self.w``, rescaled so its overall norm equals
            ``z.norm()``.

        Raises:
            AssertionError: If ``y`` is given and its length differs from the
                batch size of ``z``.
        """
        if y is not None:
            assert len(y) == z.shape[0]

        z_transformed = z + step_sizes * self.w
        # NOTE(review): ``Tensor.norm()`` without ``dim`` is the 2-norm over
        # *all* elements. This therefore preserves the norm of the whole batch,
        # not of each sample individually. Confirm that is intended.
        return z.norm() * z_transformed / z_transformed.norm()

    def compute_loss(self, current, target, batch_start, lossfile):
        """Compute the MSE between ``current`` and ``target`` and log it.

        Two comma-separated rows are appended to ``lossfile``, one tagged
        ``mse_loss`` and one tagged ``overall_loss``. Both hold the same value,
        because the MSE is the only loss term computed here.

        Args:
            current: Predicted tensor.
            target: Target tensor, expected to match ``current``'s shape.
            batch_start: Written verbatim as the first column of each row.
            lossfile: Path of the log file; opened in append mode.

        Returns:
            The loss tensor produced by ``self.criterion``.
        """
        loss = self.criterion(current, target)
        # Log a plain float. ``str(loss)`` would write the tensor repr
        # (e.g. "tensor(0.5, grad_fn=<MseLossBackward0>)"), which embeds an
        # extra comma and breaks the three-column row layout.
        value = loss.item()
        with open(lossfile, "a", encoding="utf-8") as log:
            log.write(f"{batch_start},mse_loss,{value}\n")
            log.write(f"{batch_start},overall_loss,{value}\n")
        return loss
|
|
|
|
class ClassDependent(nn.Module):
    """Latent transformation whose direction is predicted from the class input.

    A linear layer maps ``y`` (``vocab_size`` features) to a ``dim_z`` direction.
    That direction is scaled by ``step_sizes`` and added to the latents, and the
    result is then rescaled to the norm of the input.
    """

    def __init__(self, dim_z, vocab_size=1000, **kwargs):
        """Create the class-to-direction linear layer and the MSE criterion.

        Args:
            dim_z: Dimensionality of the latent space (output size of the layer).
            vocab_size: Number of input features of the layer, i.e. the size
                of the last dimension of ``y`` in :meth:`transform`.
            **kwargs: Accepted for interface compatibility; unused.
        """
        super().__init__()
        print("\napproach: ", "class_dependent\n")

        self.dim_z = dim_z
        self.vocab_size = vocab_size
        # Maps a vocab_size-dim class input to a dim_z-dim direction.
        self.NN_output = nn.Linear(self.vocab_size, self.dim_z)
        self.criterion = nn.MSELoss()

    def transform(self, z, y, step_sizes, **kwargs):
        """Shift ``z`` along the class-conditioned direction and rescale.

        Args:
            z: Latent batch (presumably shape ``(batch, dim_z)`` -- TODO confirm).
            y: Required class input fed to ``self.NN_output``. Its last
                dimension must be ``vocab_size`` (presumably a one-hot
                ``(batch, vocab_size)`` float tensor -- TODO confirm).
            step_sizes: Multiplier(s) for the predicted direction. They must
                broadcast against ``self.NN_output(y)`` (presumably a scalar
                or a ``(batch, 1)`` tensor -- TODO confirm).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``z + step_sizes * self.NN_output(y)``, rescaled so its overall
            norm equals ``z.norm()``.

        Raises:
            AssertionError: If ``y`` is ``None``.
        """
        assert y is not None

        z_transformed = z + step_sizes * self.NN_output(y)
        # NOTE(review): ``Tensor.norm()`` without ``dim`` is the 2-norm over
        # *all* elements. This therefore preserves the norm of the whole batch,
        # not of each sample individually. Confirm that is intended.
        return z.norm() * z_transformed / z_transformed.norm()

    def compute_loss(self, current, target, batch_start, lossfile):
        """Compute the MSE between ``current`` and ``target`` and log it.

        Two comma-separated rows are appended to ``lossfile``, one tagged
        ``mse_loss`` and one tagged ``overall_loss``. Both hold the same value,
        because the MSE is the only loss term computed here.

        Args:
            current: Predicted tensor.
            target: Target tensor, expected to match ``current``'s shape.
            batch_start: Written verbatim as the first column of each row.
            lossfile: Path of the log file; opened in append mode.

        Returns:
            The loss tensor produced by ``self.criterion``.
        """
        loss = self.criterion(current, target)
        # Log a plain float. ``str(loss)`` would write the tensor repr
        # (e.g. "tensor(0.5, grad_fn=<MseLossBackward0>)"), which embeds an
        # extra comma and breaks the three-column row layout.
        value = loss.item()
        with open(lossfile, "a", encoding="utf-8") as log:
            log.write(f"{batch_start},mse_loss,{value}\n")
            log.write(f"{batch_start},overall_loss,{value}\n")
        return loss