# basic functions
import os
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import manifold, datasets
from sklearn.model_selection import train_test_split
# torch functions
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
# local functions
from tr import tr, HSIC_in, cor

class Discriminator(nn.Module):
    """Fully connected binary classifier over ``ndim``-dimensional inputs.

    Layer widths are ndim -> 64 -> 128 -> 64 -> 1, with LeakyReLU(0.2) after
    every hidden Linear layer and a final Sigmoid, so each output is a single
    value in [0, 1] suitable for ``nn.BCELoss``.
    """

    def __init__(self, ndim):
        super().__init__()
        widths = [ndim, 64, 128, 64]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        # The attribute must stay named ``model`` (and keep this layer order) so
        # state_dict keys such as ``model.0.weight`` remain checkpoint-compatible.
        self.model = nn.Sequential(*layers)

    def forward(self, X):
        """Return the per-sample probability, shape ``(batch, 1)``."""
        return self.model(X)

def npLoader(Loader, net, device):
    """Run ``net`` over every batch of ``Loader`` and return features/targets as NumPy arrays.

    Each batch is pushed through ``net`` exactly once, with autograd disabled.
    Only the first element of ``net``'s return value is used as the feature
    tensor. Per-batch results are concatenated along axis 0 in iteration order.

    Args:
        Loader: iterable of ``(X, y)`` tensor batches.
        net: model called as ``net(X)``; ``net(X)[0]`` is taken as the features.
        device: device the inputs are moved to before calling ``net``.

    Returns:
        ``(X, y)``: NumPy arrays holding the features and targets of all batches.

    Raises:
        ValueError: if ``Loader`` yields no batches.
    """
    feats, targets = [], []
    # No gradients are needed for feature extraction; this avoids building graphs.
    with torch.no_grad():
        for X_t, y_t in Loader:
            feats.append(net(X_t.to(device))[0].cpu().detach().numpy())
            targets.append(y_t.numpy())
            torch.cuda.empty_cache()
    if not feats:
        raise ValueError('Loader yielded no batches')
    # Single concatenation instead of growing the arrays batch by batch (quadratic copying).
    return np.concatenate(feats), np.concatenate(targets)
    
def train(args, epoch, R_net, D_net, trainLoader, optimizer_R, optimizer_D, trainF, device):
    """Run one adversarial training epoch for the reducer ``R_net`` and discriminator ``D_net``.

    For every mini-batch:
      1. ``R_net`` is updated on ``tr_loss + G_loss``. ``tr_loss`` is the trace
         loss on the latent space. ``G_loss`` is a BCE term that pushes
         ``D_net(latent)`` towards label 0, the label given to the random
         samples ``z``.
      2. ``D_net`` is updated to output 1 for the (detached) latent codes and
         0 for ``z``, where ``z`` is drawn from a standard normal and scaled to
         unit L2 norm per row.

    After the epoch, the diagnostics from the LAST batch are appended to
    ``trainF`` as ``epoch,HSIC,tr_loss_og,tr_loss`` and printed. The diagnostics
    are the conditional HSIC, the trace loss on the original data and the trace
    loss on the latent space.

    Args:
        args: namespace providing ``latent_dim`` (width of ``z``) and ``lr``,
            which is also used here as a multiplier on both adversarial losses.
        epoch: epoch index, used only for logging.
        R_net: reducer, called as ``latent, output = R_net(data)``.
        D_net: discriminator whose output is fed to ``nn.BCELoss``.
        trainLoader: iterable of ``(data, target)`` mini-batches.
        optimizer_R: optimizer for ``R_net``.
        optimizer_D: optimizer for ``D_net``.
        trainF: writable text file receiving one CSV row per epoch.
        device: torch device the batches are moved to.

    Raises:
        ValueError: if ``trainLoader`` yields no batches (nothing to log).
    """
    R_net.train()
    D_net.train()
    adversarial_loss = nn.BCELoss()
    n_batches = 0
    for data, target in trainLoader:
        n_batches += 1
        batch_size = data.shape[0]
        ones_label = torch.ones(batch_size, 1).to(device)
        zeros_label = torch.zeros(batch_size, 1).to(device)
        # Sampled on CPU then moved, as before, so the RNG stream is unchanged.
        z = torch.randn(batch_size, args.latent_dim).to(device)
        z = z / z.norm(dim=1, keepdim=True)  # project each row onto the unit sphere
        data = data.to(device)
        flat = data.view(batch_size, -1)
        # Labels as an (N, 1) column; this is what tr/HSIC_in receive (not one-hot).
        target_col = target.view(-1, 1).to(device)

        # update Reducer
        optimizer_R.zero_grad()
        latent, _ = R_net(data)
        G_loss = adversarial_loss(D_net(latent), zeros_label) * args.lr
        tr_loss = tr(flat, latent, target_col, batch_size, device)  # Trace loss on latent space
        # Logging-only diagnostics: they do not enter R_loss, so skip autograd bookkeeping.
        with torch.no_grad():
            tr_loss_og = tr(flat, flat, target_col, batch_size, device)  # Trace loss on original data
            HSIC = HSIC_in(flat, latent, target_col, batch_size, device)  # Conditional HSIC
        R_loss = tr_loss + G_loss
        R_loss.backward()
        optimizer_R.step()

        # update Discriminator
        optimizer_D.zero_grad()  # also clears grads G_loss.backward() left on D_net
        D_real = D_net(latent.detach())
        D_fake = D_net(z)
        D_loss_real = adversarial_loss(D_real, ones_label)
        D_loss_fake = adversarial_loss(D_fake, zeros_label)
        D_loss = (D_loss_real + D_loss_fake) / 2. * args.lr
        D_loss.backward()
        optimizer_D.step()

    if n_batches == 0:
        raise ValueError('trainLoader yielded no batches')
    trainF.write('{},{},{},{}\n'.format(epoch, HSIC, tr_loss_og, tr_loss))  # log trace loss and conditional HSIC
    trainF.flush()
    print('Train Epoch: {}, tr_loss_latent: {:.4f}, tr_loss_og: {:.4f}, HSIC: {:.4f}, VG: {:.4f}, D: {:.4f}'.format(
        epoch, tr_loss, tr_loss_og, HSIC, G_loss, D_loss))

def test(args, epoch, R_net, testLoader, optimizer, testF, device):
    """Evaluate ``R_net`` on ``testLoader`` and log batch-averaged diagnostics.

    For each batch (with autograd disabled), it accumulates:
      - the trace loss on the latent space,
      - the trace loss on the original data,
      - the conditional HSIC between the data and the latent codes.

    Each quantity is then divided by the number of batches (a per-batch mean,
    not a per-sample mean). The averages are printed, and a CSV row
    ``epoch,DC,tr_loss_og,tr_loss,tr_loss-tr_loss_og`` is appended to ``testF``.

    Args:
        args: unused here; kept for signature compatibility.
        epoch: epoch index, used only for logging.
        R_net: reducer, called as ``latent, output = R_net(data)``; switched to eval mode.
        testLoader: iterable of ``(data, target)`` mini-batches.
        optimizer: unused here; kept for signature compatibility.
        testF: writable text file receiving one CSV row.
        device: torch device the batches are moved to.

    Raises:
        ValueError: if ``testLoader`` yields no batches.
    """
    R_net.eval()
    tr_loss = 0
    tr_loss_og = 0
    HSIC = 0
    # NOTE(review): DC is never computed; the column is always 0.0 and is kept
    # only so the CSV layout stays the same.
    DC = 0.0
    n_batches = 0
    with torch.no_grad():
        for data, target in testLoader:
            data = data.to(device)
            batch_size = data.shape[0]
            flat = data.view(batch_size, -1)
            # Labels as an (N, 1) column; this is what tr/HSIC_in receive.
            target_col = target.view(-1, 1).to(device)
            latent, _ = R_net(data)
            tr_loss += tr(flat, latent, target_col, batch_size, device)
            tr_loss_og += tr(flat, flat, target_col, batch_size, device)
            HSIC += HSIC_in(flat, latent, target_col, batch_size, device)
            n_batches += 1
    if n_batches == 0:
        raise ValueError('testLoader yielded no batches')
    tr_loss /= n_batches
    tr_loss_og /= n_batches
    HSIC /= n_batches
    print('Test set:  tr_loss_latent: {:.4f}, tr_loss_og: {:.4f}, tr_loss: {:.4f}, HSIC: {:.4f}'.format(
         tr_loss, tr_loss_og, tr_loss-tr_loss_og, HSIC))
    testF.write('{},{},{},{},{}\n'.format(epoch, DC, tr_loss_og, tr_loss, tr_loss-tr_loss_og))
    testF.flush()
    
def s_curve(n_points, args, random_state=None):
    """Build an S-curve dataset embedded in 400 dimensions and split it into train/test.

    Steps:
      1. Draw ``n_points`` samples with ``datasets.make_s_curve(random_state=0)``.
      2. Save a 3-D scatter plot to ``<args.save>/original.png`` (dpi 30) and show it.
      3. Map the 3-D points linearly into 400-D with a random ``400 x 3`` matrix
         of uniform [0, 1) entries, and reshape to ``(n_points, 1, 20, 20)``.
      4. Split 90/10 with ``train_test_split(random_state=1)``.

    Args:
        n_points: number of samples to generate.
        args: namespace whose ``save`` attribute is the output directory for the plot.
        random_state: optional seed for the projection matrix. ``None`` (default)
            keeps the original behaviour of drawing it from NumPy's global RNG,
            in which case the projection differs between runs unless the global
            RNG is seeded by the caller.

    Returns:
        ``X_train, X_test, y_train, y_test, idx1, idx2``. The X and y arrays are
        float32; X arrays have shape ``(n, 1, 20, 20)``. ``idx1``/``idx2`` are the
        original sample indices that went into each split.
    """
    X, y = datasets.make_s_curve(n_points, random_state=0)

    # Add 3d scatter plot
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y, cmap=plt.cm.Spectral)
    fig.savefig(os.path.join(args.save, 'original.png'), dpi=30)
    plt.show()
    # With non-interactive backends show() does not close the figure; close it
    # explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)

    # Project on 400-dimensional space
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    X = np.matmul(rng.rand(20 * 20, 3), X.T).T
    X = X.reshape(n_points, 1, 20, 20)
    indices = np.arange(n_points)
    X_train, X_test, y_train, y_test, idx1, idx2 = train_test_split(X, y, indices, test_size=0.1, random_state=1)
    X_train = X_train.astype(np.float32)
    y_train = y_train.astype(np.float32)
    X_test = X_test.astype(np.float32)
    y_test = y_test.astype(np.float32)
    return X_train, X_test, y_train, y_test, idx1, idx2