import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')
from IPython.display import clear_output
import os, sys
import json
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchsummary import summary
from torch.autograd import grad
import torch.autograd as autograd
from src.plotters import plot_low_dim_equal
from src.tools_wo_crop import unfreeze, freeze
from scipy import linalg
from sklearn import datasets, manifold
from mpl_toolkits.mplot3d import Axes3D
# Reproducibility: seed torch AND numpy's global RNG. The synthetic samplers
# call np.random.randn / sklearn's make_s_curve (random_state=None), both of
# which draw from numpy's global state, so seeding torch alone is not enough.
SEED = 9999
torch.manual_seed(SEED)
np.random.seed(SEED)
cuda = torch.cuda.is_available()  # already a bool; no ternary needed
BATCH_SIZE = 400
GPU_DEVICE = 0
dim = 2        # dimensionality of the data points
K_G = 16       # G updates per training iteration
K_psi = 1      # psi updates per training iteration
lam_gp = 1     # weight of the gradient penalty in psi's loss
lr_G = 1e-3    # Adam learning rate for G
lr_psi = 1e-3  # Adam learning rate for psi
if cuda:
    print('Using GPU: ', GPU_DEVICE)
    torch.cuda.set_device(GPU_DEVICE)
class SyntheticDataGenerator:
    """Base class for the synthetic samplers.

    An instance is its own iterator and every ``next()`` returns a fresh
    ``get_batch()``. It never raises ``StopIteration``, so a plain ``for``
    loop over it runs forever; draw batches with ``next(...)`` instead.
    Subclasses must implement ``get_batch``.
    """

    def __iter__(self):
        # Iterating the generator yields the generator itself.
        return self

    def __next__(self):
        return self.get_batch()

    def get_batch(self):
        """Produce one batch. Abstract: concrete samplers override this."""
        raise NotImplementedError()

    def float_tensor(self, batch):
        """Wrap a NumPy array as a CPU float32 tensor (no copy if already float32)."""
        return torch.from_numpy(batch).float()
class StandardGaussianGenerator(SyntheticDataGenerator):
    """Sample 2-D isotropic Gaussian batches, ``N(0, eps_noise**2 * I)``.

    With the default ``eps_noise=1`` this is the standard Gaussian.

    Args:
        batch_size: number of points per batch.
        scale: stored as ``self.scale`` (mirroring the other generators); it
            does not influence sampling.
        eps_noise: standard deviation of each coordinate.
    """

    def __init__(self,
                 batch_size: int = 256,
                 scale: float = 1.,
                 eps_noise: float = 1.):
        self.batch_size = batch_size
        # Keep the argument on the instance; a bare ``scale = scale`` would
        # only rebind the local name and silently drop it.
        self.scale = scale
        self.eps_noise = eps_noise

    def get_batch(self):
        """Return a shuffled ``(batch_size, 2)`` float32 CPU tensor."""
        # One vectorised draw. numpy's legacy global RNG fills the array
        # element by element, so this yields the same values as
        # ``batch_size`` successive ``randn(2)`` calls. It also gives shape
        # (0, 2) instead of (0,) when batch_size == 0.
        points = np.random.randn(self.batch_size, 2) * self.eps_noise
        batch = self.float_tensor(points.astype('float32'))
        # Row shuffle via torch's RNG; kept so torch's random stream is
        # consumed exactly as before.
        return batch[torch.randperm(batch.size(0)), :]
class SCurveGenerator(SyntheticDataGenerator):
    """Sample 2-D batches from scikit-learn's 3-D S-curve.

    Each batch keeps columns 2 and 0 of ``make_s_curve`` output (in that
    order), multiplies them by ``scale`` and shuffles the rows.

    Args:
        batch_size: number of points per batch.
        scale: multiplicative factor applied to the projected points.
        noise: Gaussian noise level passed to ``make_s_curve``.
    """

    def __init__(self,
                 batch_size: int = 256,
                 scale: float = 5,
                 noise: float = 0.01):
        self.batch_size, self.scale, self.noise = batch_size, scale, noise

    def get_batch(self):
        """Return a shuffled ``(batch_size, 2)`` float32 CPU tensor."""
        points_3d, _ = datasets.make_s_curve(self.batch_size, noise=self.noise)
        planar = self.scale * points_3d[:, [2, 0]]
        as_tensor = self.float_tensor(planar)
        order = torch.randperm(as_tensor.size(0))
        return as_tensor[order, :]
# Preview of the raw 3-D S-curve (fixed random_state), coloured by the
# curve parameter returned by make_s_curve, in the first cell of a 2x5 grid.
X, c = datasets.make_s_curve(n_samples=300, random_state=0)
fig = plt.figure(figsize=(15, 8))
ax = fig.add_subplot(2, 5, 1, projection="3d")
ax.scatter(*X.T, c=c)
ax.view_init(elev=4, azim=-72)
# Source sampler (mu: Gaussian) and target sampler (nu: S-curve), plus one
# batch from each for a quick visual check; X is drawn before Y.
X_sampler, Y_sampler = StandardGaussianGenerator(BATCH_SIZE), SCurveGenerator(BATCH_SIZE)
X_samples, Y_samples = next(X_sampler), next(Y_sampler)
X_samples.shape, Y_samples.shape
# Scatter one source batch (mu) against one target batch (nu).
plt.rcParams.update({'font.size': 30})
plt.figure(figsize=(10, 10))
# The fmt strings give only the marker; colour comes solely from ``color=``.
# (A fmt of 'og' would also set green, which matplotlib reports as a
# redundant colour definition.)
plt.plot(X_samples[:, 0].numpy(), X_samples[:, 1].numpy(), 'o',
         label=r'$X \sim \mu$', color='forestgreen')
plt.plot(Y_samples[:, 0].numpy(), Y_samples[:, 1].numpy(), 's',
         label=r'$Y \sim \nu$', color='peru')
plt.grid()
plt.legend(loc='upper left')
class MLP_G(nn.Module):
    """Generator MLP: ``data_dim`` -> features -> features -> features -> ``data_dim``.

    Hidden layers use ``F.leaky_relu``; the output layer is linear.

    Args:
        features: width of the three hidden layers.
        data_dim: input/output dimensionality. Defaults to the module-level
            ``dim`` when omitted, so existing ``MLP_G(features=...)`` calls
            behave exactly as before.
    """

    def __init__(self, features=128, data_dim=None):
        super(MLP_G, self).__init__()
        d = dim if data_dim is None else data_dim
        # Attribute names (and creation order) are kept so state_dict keys
        # and parameter initialisation order are unchanged.
        self.W0b0 = nn.Linear(in_features=d, out_features=features)
        self.W1b1 = nn.Linear(in_features=features, out_features=features)
        self.W2b2 = nn.Linear(in_features=features, out_features=features)
        self.W3b3 = nn.Linear(in_features=features, out_features=d)

    def forward(self, x):
        for layer in (self.W0b0, self.W1b1, self.W2b2):
            x = F.leaky_relu(layer(x))
        return self.W3b3(x)
# Generator G: a single MLP_G wrapped in nn.Sequential (so its parameter
# names carry a "0." prefix). Module.cuda() returns the module itself.
# NOTE(review): moved to CUDA unconditionally, regardless of the `cuda` flag.
G = nn.Sequential(MLP_G(features=128)).cuda()
summary(G, (dim,))
class MLP_D(nn.Module):
    """Potential MLP: ``data_dim`` -> features -> features -> features -> 1.

    Hidden layers use ``F.leaky_relu``; the output layer is linear, so the
    last dimension of the input is mapped to a single value.

    Args:
        features: width of the three hidden layers.
        data_dim: input dimensionality. Defaults to the module-level ``dim``
            when omitted, so existing ``MLP_D(features=...)`` calls behave
            exactly as before.
    """

    def __init__(self, features=128, data_dim=None):
        super(MLP_D, self).__init__()
        d = dim if data_dim is None else data_dim
        # Attribute names (and creation order) are kept so state_dict keys
        # and parameter initialisation order are unchanged.
        self.W0b0 = nn.Linear(in_features=d, out_features=features)
        self.W1b1 = nn.Linear(in_features=features, out_features=features)
        self.W2b2 = nn.Linear(in_features=features, out_features=features)
        self.W3b3 = nn.Linear(in_features=features, out_features=1)

    def forward(self, x):
        for hidden in (self.W0b0, self.W1b1, self.W2b2):
            x = F.leaky_relu(hidden(x))
        return self.W3b3(x)
# Potential psi: a single MLP_D wrapped in nn.Sequential (so its parameter
# names carry a "0." prefix). Module.cuda() returns the module itself.
# NOTE(review): moved to CUDA unconditionally, regardless of the `cuda` flag.
psi = nn.Sequential(MLP_D(features=128)).cuda()
summary(psi, (dim,))
def Q(X):
    """Embedding Q: return ``X`` unchanged in value but detached from autograd.

    Written as a ``def`` rather than an assigned lambda (PEP 8 E731) so it
    has a real name in tracebacks and reprs.
    """
    return X.detach()
def GradientPenalty(psi, G, X, Y):
    """WGAN-GP style penalty ``mean_i (||grad psi(x_hat_i)||_2 - 1)**2``.

    ``x_hat_i = eta_i * Y_i + (1 - eta_i) * G(X)_i`` with one uniform
    ``eta_i`` in [0, 1) per row.

    Args:
        psi: potential network. Assumed to process rows independently (true
            for an MLP such as MLP_D) so per-row gradients are well defined.
        G: map applied to ``X``; treated as a constant here (output detached).
        X: source batch; ``X.shape[0]`` is the batch size.
        Y: target batch. Assumed (batch, d), on the same device/dtype as
            ``G(X)``; ``eta`` is created on ``Y``'s device and dtype.

    Returns:
        Scalar tensor, differentiable w.r.t. psi's parameters
        (``create_graph=True``).
    """
    batch_size = X.shape[0]
    # x_hat is only an input to psi; no gradient should reach G through it.
    G_X = G(X).detach()
    # Sample on the inputs' device instead of hard-wiring CUDA.
    eta = torch.rand(batch_size, 1, device=Y.device, dtype=Y.dtype)
    interpolated = (eta * Y + (1 - eta) * G_X).requires_grad_(True)
    psi_interpolated = psi(interpolated)
    # Differentiate the per-row outputs (not their mean). With grad_outputs
    # of ones, row i of `gradients` is grad psi at x_hat_i. Differentiating
    # the mean would scale every row by 1/batch_size, making the enforced
    # norm ~batch_size instead of 1.
    gradients = autograd.grad(
        outputs=psi_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(psi_interpolated),
        create_graph=True,  # retain_graph defaults to create_graph
    )[0]
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
def Loss(psi, G, Q, X, Y):
    """Adversarial objective shared by both players.

    Returns ``mean_rows(mean_cols(Q(X) * G(X))) - mean(psi(G(X))) + mean(psi(Y))``.
    In this file psi takes steps on this value (plus a gradient penalty) and
    G takes steps on its negation.
    """
    mapped = G(X)
    transport_term = torch.mean(torch.mean(Q(X) * mapped, dim=1))
    return transport_term - torch.mean(psi(mapped)) + torch.mean(psi(Y))
# Fixed evaluation batches (X drawn first, then Y), moved to the GPU and
# reused for every progress plot. Slicing to n = BATCH_SIZE keeps the whole batch.
n = BATCH_SIZE
X_fixed, Y_fixed = next(X_sampler), next(Y_sampler)
X_fixed, Y_fixed = X_fixed[:n].cuda(), Y_fixed[:n].cuda()
# Adam for both networks with betas=(0.5, 0.99).
G_opt = optim.Adam(G.parameters(), lr=lr_G, betas=(0.5, 0.99))
psi_opt = optim.Adam(psi.parameters(), lr=lr_psi, betas=(0.5, 0.99))
# NOTE(review): FID_history is never appended to in the visible code.
FID_history = []
# Plot the untrained G on the fixed batches via the project helper
# src.plotters.plot_low_dim_equal (behaviour not visible here), then close
# the returned figure.
fig, axes = plot_low_dim_equal(G, Q, X_fixed, Y_fixed)
plt.close(fig)
# Sanity checks. Bare expressions are displayed only when they are the last
# expression of a notebook cell; in a plain script they have no effect.
# Tensor types of G's output and of the target batch, for comparison.
G(X_fixed).type(), Y_fixed.type()
# Both loss terms should evaluate on the fixed batches without error.
gp_loss = GradientPenalty(psi, G, X_fixed, Y_fixed)
gp_loss
loss = Loss(psi,G, Q, X_fixed, Y_fixed)
loss
# Alternating training. Each iteration:
#   1. K_psi steps on psi (minimising Loss + lam_gp * gradient penalty) with G frozen;
#   2. K_G steps on G (maximising Loss) with psi frozen.
# Every 50 iterations the losses are printed and the fixed-batch plot is
# redrawn; every 1000 iterations Python/CUDA caches are cleaned up.
for it in range(10000 + 1):
    ##########################################################
    ## Outer minimization loop (potential psi)
    ##########################################################
    freeze(G); unfreeze(psi)
    for _ in range(K_psi):
        X = next(X_sampler).cuda()
        Y = next(Y_sampler).cuda()
        gp_loss = GradientPenalty(psi, G, X, Y)
        psi_loss = Loss(psi, G, Q, X, Y) + lam_gp * gp_loss
        psi_opt.zero_grad(); psi_loss.backward(); psi_opt.step()

    ##########################################################
    ## Inner maximization loop (map G)
    ##########################################################
    freeze(psi); unfreeze(G)
    for _ in range(K_G):
        X = next(X_sampler).cuda()
        Y = next(Y_sampler).cuda()
        G_loss = -Loss(psi, G, Q, X, Y)
        G_opt.zero_grad(); G_loss.backward(); G_opt.step()

    if it % 50 == 0:
        clear_output(wait=True)
        print('Iteration: ', it, '\tG_loss: ', np.round(G_loss.item(),5), '\tGp_loss: ', np.round(gp_loss.item(),5), '\tPsi_loss: ', np.round(psi_loss.item(),5))
        fig, axes = plot_low_dim_equal(G, Q, X_fixed, Y_fixed)
        plt.close(fig)

    # Cleanup runs periodically only. A full gc.collect() plus
    # torch.cuda.empty_cache() after every optimiser step is expensive and
    # does not affect the computed values; rebinding the losses already
    # frees each step's autograd graph.
    if it % 1000 == 0:
        gc.collect()
        torch.cuda.empty_cache()
# Freeze both networks and save the final fixed-batch plot. Two plain
# statements, not a throwaway tuple expression.
freeze(G)
freeze(psi)
fig, axes = plot_low_dim_equal(G, Q, X_fixed, Y_fixed)
# savefig does not create missing directories, so ensure the target exists.
os.makedirs('./output/Toy/SCurve', exist_ok=True)
fig.savefig('./output/Toy/SCurve/otm_scurve.pdf', bbox_inches='tight')