import math
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import spectral_norm
from torch.utils.data import TensorDataset, DataLoader

# Experiment seeds: 2005, 2006, ..., 2024 (20 runs in total).
seed = 2024
print("Random Seed: ",seed)
random.seed(seed)
torch.manual_seed(seed)
# With deterministic algorithms enabled, CUDA matmuls raise unless cuBLAS uses a
# fixed workspace; the variable has to be set before the first cuBLAS call.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.use_deterministic_algorithms(True)
np.random.seed(seed)

N = 1000          # number of training samples
dy = 4            # dimension of y
VE = True         # vanilla cGAN if False
num_epochs = 3000
if VE:
    # floor(100 * N^(1/4)) epochs. Kept as a Python int: np.floor returns a float,
    # which can neither size np.zeros nor drive range(). `**` also avoids np.pow,
    # which only exists in NumPy >= 2.0.
    num_epochs = int(math.floor(100 * N ** 0.25))

batch_size = 512
lr_d = 0.001      # discriminator learning rate
lr_g = 0.0001     # generator learning rate
dx = 1            # label dimension
dz = dy           # latent noise dimension

def weights_init(m):
    """DCGAN-style initializer meant to be passed to ``nn.Module.apply``.

    Modules whose class name contains 'Conv' get weights ~ N(0, 0.02); otherwise,
    modules whose class name contains 'BatchNorm' get weights ~ N(1, 0.02) and a
    zero bias. Every other module (including nn.Linear) is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

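# The generator is shared by both cGAN algorithms (vanilla and VE).
# Spectral normalization is applied to every linear layer so that the generated
# distribution p_g is Lipschitz continuous in the label x.
# Because the label is scaled by sqrt(dy) before entering the network, the network is
# (up to the power-iteration estimate of spectral norm) sqrt(dy)-Lipschitz in the label.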

class Generator(nn.Module):
    """Conditional generator G(z, x) -> y in (0, 1)^dy.

    A spectrally-normalized MLP (6 hidden ReLU layers of width 128, sigmoid output)
    applied to the concatenation [z, sqrt(dy) * x].

    Args:
        ngpu: number of GPUs to split the batch over with data_parallel when the
            input is on CUDA and ngpu > 1.
        dx, dy, dz: label, output and latent-noise dimensions.
    """

    def __init__(self, ngpu=1, dx=dx, dy=dy, dz=dz):
        super(Generator, self).__init__()
        self.ngpu = ngpu  # read by forward() to decide whether to use data_parallel
        self.dx = dx
        self.dy = dy
        self.dz = dz
        self.inner_dim = 128
        self.linear = nn.Sequential(
            spectral_norm(nn.Linear(self.dz+self.dx, self.inner_dim)), nn.ReLU(True),
            spectral_norm(nn.Linear(self.inner_dim, self.inner_dim)), nn.ReLU(True),
            spectral_norm(nn.Linear(self.inner_dim, self.inner_dim)), nn.ReLU(True),
            spectral_norm(nn.Linear(self.inner_dim, self.inner_dim)), nn.ReLU(True),
            spectral_norm(nn.Linear(self.inner_dim, self.inner_dim)), nn.ReLU(True),
            spectral_norm(nn.Linear(self.inner_dim, self.inner_dim)), nn.ReLU(True),
            # nn.Sigmoid takes no constructor arguments (unlike nn.ReLU(inplace)).
            spectral_norm(nn.Linear(self.inner_dim, self.dy)), nn.Sigmoid()
        )

    def forward(self, z, x):
        """Map latent noise z (batch, dz) and labels x (batch, dx) to (batch, dy)."""
        # Scaling the label makes the 1-Lipschitz network sqrt(dy)-Lipschitz in x.
        x = x * np.sqrt(self.dy)
        zx = torch.cat((z, x), 1)
        if zx.is_cuda and self.ngpu > 1:
            return nn.parallel.data_parallel(self.linear, zx, range(self.ngpu))
        return self.linear(zx)
    
  
# The discriminator is a function from a Barron space.
# In this implementation the coefficient function a(w, b, x) is itself modeled by a neural network.

class Barron_a(nn.Module):
    """Coefficient network a(w, b, x) of the Barron-space discriminator.

    A spectrally-normalized MLP (5 hidden ReLU layers of width 128) mapping the
    concatenation [w, b, x] of size dw + 1 + dx to a scalar.

    Args:
        ngpu: number of GPUs to split the batch over with data_parallel when the
            input is on CUDA and ngpu > 1.
        dx: label dimension.
        dy: accepted for signature compatibility; not used by this module.
        dw: dimension of w.
    """

    def __init__(self, ngpu=1, dx=dx, dy=dy, dw=dy):
        super(Barron_a, self).__init__()
        self.ngpu = ngpu  # read by forward() to decide whether to use data_parallel
        self.dx = dx
        self.dy = dy
        self.dw = dw
        self.inner_dim = 128
        self.linear = nn.Sequential(
            spectral_norm(nn.Linear(self.dw+1+self.dx, self.inner_dim)), nn.ReLU(True),
            spectral_norm(nn.Linear(self.inner_dim, self.inner_dim)), nn.ReLU(True),
            spectral_norm(nn.Linear(self.inner_dim, self.inner_dim)), nn.ReLU(True),
            spectral_norm(nn.Linear(self.inner_dim, self.inner_dim)), nn.ReLU(True),
            spectral_norm(nn.Linear(self.inner_dim, self.inner_dim)), nn.ReLU(True),
            spectral_norm(nn.Linear(self.inner_dim, 1))
        )

    def forward(self, wbx):
        """Evaluate a on wbx = [w, b, x] (last dim dw + 1 + dx); output last dim is 1."""
        if wbx.is_cuda and self.ngpu > 1:
            return nn.parallel.data_parallel(self.linear, wbx, range(self.ngpu))
        return self.linear(wbx)
    

class Barron_Discriminator(nn.Module):
    """Conditional discriminator D(y, x) from a Monte Carlo discretized Barron space.

        D(y, x) = sum_k a(w_k, b_k, x) * relu(w_k . y + b_k)

    The rho_sample nodes (w_k, b_k) in R^{dw+1} are redrawn uniformly from the l1 ball
    {|w|_1 + |b| <= 1} on every call, and the coefficient function a is Barron_a.
    The product w_k . y requires dw == dy (the default).
    """

    def __init__(self, ngpu=1, dx=dx, dy=dy, dw=dy):
        super(Barron_Discriminator, self).__init__()
        self.dx = dx
        self.dy = dy
        self.dw = dw
        self.a = Barron_a(ngpu=ngpu, dx=self.dx, dy=self.dy, dw=self.dw)
        self.rho_sample = 2048  # number of Monte Carlo nodes (w, b) drawn per call

    def _sample_wb(self, device):
        """Draw rho_sample points uniformly from the l1 unit ball in R^{dw+1}.

        Returns a (rho_sample, dw+1) tensor on ``device``.
        """
        n, d = self.rho_sample, self.dw + 1
        # The d+1 spacings between 0, d sorted uniforms and 1 are uniform on the
        # simplex {s >= 0, sum s = 1}; dropping the last spacing leaves a uniform point
        # of {s >= 0, sum s <= 1}, and independent random signs spread it over the
        # whole l1 ball.
        cuts = torch.sort(torch.rand((n, d), device=device)).values
        spx = torch.cat((torch.zeros((n, 1), device=device), cuts, torch.ones((n, 1), device=device)), 1)
        diff = spx[:, 1:] - spx[:, :-1]
        sgn = 2 * torch.randint(0, 2, size=(n, d + 1), device=device) - 1
        return torch.mul(diff, sgn)[:, :-1]

    def _coefficients(self, wb, x):
        """Evaluate a(w_k, b_k, x_i) for every sample i and node k -> (batch, n_nodes).

        Uses out-of-place unsqueeze/expand, so the caller's x is not modified.
        """
        batch_size, n_nodes = x.size(0), wb.size(0)
        wbx = torch.cat(
            [wb.unsqueeze(0).expand(batch_size, n_nodes, self.dw + 1),
             x.unsqueeze(1).expand(batch_size, n_nodes, self.dx)],
            dim=2,
        )
        # cat returns a fresh contiguous tensor, so flattening the first two dims is valid
        # (a .view on the expanded inputs themselves would raise).
        a = self.a(wbx.reshape(batch_size * n_nodes, self.dw + 1 + self.dx))
        return a.view(batch_size, n_nodes)

    def forward(self, y, x):
        """Return D(y_i, x_i) for every row i as a (batch,) tensor.

        y: (batch, dy) samples, x: (batch, dx) labels; neither tensor is modified.
        """
        wb = self._sample_wb(y.device)
        a = self._coefficients(wb, x)
        # Append a constant 1 so that [y, 1] . [w, b] = w . y + b, and evaluate all
        # (sample, node) pairs at once: (batch, dy+1) @ (dw+1, rho) -> (batch, rho).
        yo = torch.cat([y, torch.ones((y.size(0), 1), dtype=y.dtype, device=y.device)], dim=1)
        out = torch.relu(yo @ wb.t())
        # Row-wise inner product over the nodes (an unnormalized sum, matching the
        # unnormalized l2 norm used in Barron_norm).
        return (a * out).sum(dim=1)

    def Barron_norm(self, x):
        """l2 norm of a(., ., x_i) over freshly drawn nodes; used as the VE penalty.

        x: (batch, dx) labels (not modified). Returns a (batch,) tensor.
        """
        wb = self._sample_wb(x.device)
        a = self._coefficients(wb, x)
        return torch.linalg.vector_norm(a, dim=1)

# Create Dataset
# x : dx-dimensional label, standard normal distribution.
# y = sigmoid(x)/2 + Uniform([0,1]^dy)/2, with sigmoid(x) (dx == 1) broadcast to all
#     dy coordinates, so y lies in [0,1]^dy. This is the same conditional law that is
#     used as the reference when the generalization error is measured.
# We sample N pairs (X, Y) from this true distribution; each row is one sample, so
# X is (N, dx) and Y is (N, dy), as TensorDataset requires a shared first dimension.
X = np.random.randn(N, dx)
Y = 0.5 / (1.0 + np.exp(-X)) + np.random.rand(N, dy) * 0.5


# Network setting. weights_init only re-initializes Conv/BatchNorm layers, so the
# Linear layers of both networks keep PyTorch's default nn.Linear initialization.
# Use up to two GPUs for data_parallel; requesting more devices than exist fails.
ngpu = min(2, torch.cuda.device_count())

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

netG = Generator(ngpu).to(device)

netG.apply(weights_init)

netD = Barron_Discriminator(ngpu).to(device)

netD.apply(weights_init)

compiled_netG = torch.compile(netG)
compiled_netD = torch.compile(netD)

# Training


X = torch.from_numpy(X).float().to(device)
Y = torch.from_numpy(Y).float().to(device)

dataset = TensorDataset(X,Y)
dataloader = DataLoader(dataset, batch_size=batch_size,shuffle=True)

optimizerD = optim.Adam(compiled_netD.parameters(),lr=lr_d, betas = (0.5,0.999))
optimizerG = optim.Adam(compiled_netG.parameters(),lr=lr_g, betas = (0.5,0.999))

# Optimal lengthscale of O(N^{-1/8}); plain ** avoids np.pow (NumPy >= 2.0 only).
sigma = 0.4 * N ** (-1/8)

dist = np.zeros(num_epochs)  # NOTE(review): currently unused

gerror = np.zeros(num_epochs)  # per-epoch generalization error, filled during training

def calculate_2_wasserstein_dist(X, Y):
    '''
    Calculates the squared 2-Wasserstein distance between Gaussian fits of two samples.

    The general formula is d(P_X, P_Y) = min_{X, Y} E[|X-Y|^2]. For multivariate
    Gaussians z_X ~ MN(mu_X, cov_X) and z_Y ~ MN(mu_Y, cov_Y) it reduces to
        d = |mu_X - mu_Y|^2 + Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2)).
    Tr((cov_X * cov_Y)^(1/2)) is the sum of square roots of the eigenvalues of
    cov_X cov_Y (n x n) or, equivalently, of the b x b matrix from
    https://arxiv.org/pdf/2009.14075.pdf; the smaller of the two is decomposed.

    Input shape: [b, n] (e.g. batch_size x num_features), equal for X and Y.
    Output: scalar float32 tensor.
    Raises ValueError if X and Y have different shapes.
    code adapted from https://gist.github.com/Flunzmas/6e359b118b0730ab403753dcc2a447df
    '''

    if X.shape != Y.shape:
        raise ValueError("Expecting equal shapes for X and Y!")

    # the linear algebra ops need some extra precision -> convert to double
    X, Y = X.transpose(0, 1).double(), Y.transpose(0, 1).double()  # [n, b]
    mu_X = torch.mean(X, dim=1, keepdim=True)  # [n, 1]
    mu_Y = torch.mean(Y, dim=1, keepdim=True)
    n, b = X.shape
    fact = 1.0 if b < 2 else 1.0 / (b - 1)

    # Covariance matrices
    E_X = X - mu_X
    E_Y = Y - mu_Y
    cov_X = torch.matmul(E_X, E_X.t()) * fact  # [n, n]
    cov_Y = torch.matmul(E_Y, E_Y.t()) * fact

    # cov_X cov_Y (n x n) and C_X^T C_Y C_Y^T C_X (b x b), with C = E * sqrt(fact),
    # share their nonzero eigenvalues (which are real). The larger one only adds zero
    # eigenvalues whose round-off would leak into the sum, so decompose the smaller.
    if n <= b:
        M = torch.matmul(cov_X, cov_Y)
    else:
        C_X = E_X * math.sqrt(fact)  # [n, b], "root" of covariance
        C_Y = E_Y * math.sqrt(fact)
        M = torch.matmul(torch.matmul(C_X.t(), C_Y), torch.matmul(C_Y.t(), C_X))
    S = torch.linalg.eigvals(M) + 1e-15  # small constant avoids infinite gradients from sqrt(0)
    sq_tr_cov = S.sqrt().abs().sum()

    # Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2))
    trace_term = torch.trace(cov_X + cov_Y) - 2.0 * sq_tr_cov  # scalar

    # |mu_X - mu_Y|^2
    diff = mu_X - mu_Y  # [n, 1]
    mean_term = torch.sum(torch.mul(diff, diff))  # scalar

    return (trace_term + mean_term).float()

for epoch in range(num_epochs):
    for batch_idx, samples in enumerate(dataloader):
        X_train, Y_train = samples
        N_batch = X_train.size(dim=0)
        if VE:
            # VE-cGAN

            # Train discriminator first, in large timescale.
            compiled_netD.zero_grad()
            z = torch.randn(N_batch,dz,device=device)
            # xp is the label given to D; x = xp + sigma * noise is the label given to G.
            xp = np.random.multivariate_normal([0] * dx, np.identity(dx) ,N_batch)
            x = np.random.multivariate_normal([0] * dx, (sigma ** 2) * np.identity(dx), N_batch) + xp
            xp = torch.from_numpy(xp).to(device).float()
            x = torch.from_numpy(x).to(device).float()
            # loss_d = -mean_i D(G(z_i, x_i), xp_i)
            #          + mean_i [ sum_j q_ij D(Y_j, xp_i) / sum_j q_ij ]
            #          + mean_i ||a(., ., X_i)||^2
            # Each term is backpropagated as soon as it is formed. Gradients accumulate,
            # so this equals one backward of loss_d, but only one real-data graph
            # (N_batch * rho_sample rows through D) is kept alive at a time.
            loss_fake = -torch.mean(compiled_netD(compiled_netG(z,x).detach(),xp))
            loss_fake.backward()
            loss_real = torch.zeros((), device=device)
            for i in range(N_batch):
                ddx = xp[i] - X_train
                # NOTE(review): the kernel bandwidth is fixed at 1 while sigma only perturbs
                # the generator label -- confirm this is intended.
                q = torch.exp(- torch.square(torch.linalg.vector_norm(ddx,dim=1) )/ 2.0)
                # Pair label xp_i with every real sample of this batch (N_batch rows).
                d_real = torch.flatten(compiled_netD(Y_train, xp[i].expand(N_batch, dx)))
                # q-normalized average over j; the 1/N_batch is the mean over i, so the
                # terms are summed (not averaged a second time).
                term = torch.inner(q, d_real) / (N_batch * torch.sum(q))
                term.backward()
                loss_real = loss_real + term.detach()
            loss_pen = torch.mean(torch.square(compiled_netD.Barron_norm(X_train)))
            loss_pen.backward()
            loss_d = loss_fake.detach() + loss_real + loss_pen.detach()
            optimizerD.step()

            # Train generator later, in small timescale.
            compiled_netG.zero_grad()
            z = torch.randn(N_batch,dz,device=device)
            xp = np.random.multivariate_normal([0] * dx, np.identity(dx) ,N_batch)
            x = np.random.multivariate_normal([0] * dx, (sigma ** 2) * np.identity(dx), N_batch) + xp
            xp = torch.from_numpy(xp).to(device).float()
            x = torch.from_numpy(x).to(device).float()
            loss_g = torch.mean(compiled_netD(compiled_netG(z,x),xp))
            loss_g.backward()
            optimizerG.step()

        else:
            # cGAN
            # Train discriminator first, in large timescale.
            compiled_netD.zero_grad()
            z = torch.randn(N_batch,dz,device=device)
            loss_d = torch.mean(compiled_netD(compiled_netG(z,X_train).detach(),X_train))
            loss_xy = torch.mean(compiled_netD(Y_train,X_train))
            loss_d = -loss_d + loss_xy
            loss_d.backward()
            optimizerD.step()

            # Train generator later, in small timescale.
            compiled_netG.zero_grad()
            z = torch.randn(N_batch,dz,device=device)
            loss_g = torch.mean(compiled_netD(compiled_netG(z,X_train),X_train))
            loss_g.backward()
            optimizerG.step()

    # After each epoch, estimate the (projected) generalization error: for test_x random
    # labels, the calculate_2_wasserstein_dist value between test_z generated samples and
    # test_z samples of the true conditional law, averaged over the labels.
    test_x = 1024
    test_z = 1024

    w2d = np.zeros(test_x)
    with torch.no_grad():
        x_test = torch.randn((test_x,dx), device=device)
        for i in range(test_x):
            z = torch.randn(test_z,dz,device=device)
            x_g = x_test[[i],:].expand(test_z,dx)
            dist_g = compiled_netG(z,x_g)
            dist_t = torch.sigmoid(x_g) * 0.5 + torch.rand(test_z,dy,device=device) * 0.5
            w2d[i] = calculate_2_wasserstein_dist(dist_t, dist_g).item()
    gerror[epoch] = np.mean(w2d)


# Persist the per-epoch generalization-error curve to
# gerror_<dy>_<N>_<seed>[_VE].npz under the key "gerror".
suffix = "_VE" if VE else ""
filename = "gerror_{}_{}_{}".format(dy,N,seed) + suffix + ".npz"
np.savez(filename,gerror=gerror)