import torch
import torch.nn as nn
# import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import seaborn as sns
import pandas as pd
from tr import tr, HSIC_in, cor

class Discriminator(nn.Module):
    """MLP discriminator mapping an ``ndim``-dimensional vector to a probability.

    Architecture: ndim -> 32 -> 64 -> 32 -> 1, with LeakyReLU(0.2) after each
    hidden Linear layer and a Sigmoid on the output, so every output lies in
    (0, 1). All layers live in the ``self.model`` Sequential.
    """

    def __init__(self, ndim):
        super().__init__()
        # Input width followed by the three hidden widths.
        widths = (ndim, 32, 64, 32)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        # Single-logit head squashed to a probability.
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, X):
        """Return the discriminator's probability for each row of ``X``."""
        return self.model(X)

def to_onehot(target, num_classes=None):
    """Convert a tensor of integer class labels into a one-hot float32 matrix.

    Args:
        target: tensor of labels. It is flattened before encoding and may live
            on any device or carry autograd history.
        num_classes: optional fixed number of columns.
            - When None (the default, original behaviour), the width is
              ``max - min + 1`` over *this* input. Labels are shifted so the
              smallest label maps to column 0.
            - When given, labels are used directly as column indices in
              ``[0, num_classes)``. The encoding then stays stable across
              batches that are missing some classes.

    Returns:
        torch.float32 tensor of shape (N, C) with a single 1 per row.
        An empty ``target`` gives a (0, 0) tensor when ``num_classes`` is None,
        and a (0, num_classes) tensor otherwise.

    Raises:
        ValueError: if ``num_classes`` is given and a label falls outside
            ``[0, num_classes)``.
    """
    # detach()/cpu() so GPU tensors or tensors requiring grad can go to numpy.
    labels = np.ravel(target.detach().cpu().numpy()).astype(int)
    if num_classes is None:
        if labels.size == 0:
            return torch.zeros((0, 0), dtype=torch.float32)
        offset = labels.min()
        width = labels.max() - offset + 1
    else:
        # Negative labels would otherwise wrap around silently via numpy indexing.
        if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
            raise ValueError(
                'labels must lie in [0, {}), got range [{}, {}]'.format(
                    num_classes, labels.min(), labels.max()))
        offset = 0
        width = num_classes
    onehot = np.zeros((labels.size, width), dtype=np.float32)
    onehot[np.arange(labels.size), labels - offset] = 1
    return torch.from_numpy(onehot)

def npLoader(Loader, net, device):
    """Run every batch of ``Loader`` through ``net`` and return numpy arrays.

    Each batch ``(X, y)`` is moved to ``device`` and passed through ``net``.
    The first element of ``net``'s output is kept as the reduced features.
    Everything runs under ``torch.no_grad()`` because only the values are
    needed.

    Each batch is evaluated exactly once. The results are concatenated a
    single time at the end instead of growing an array per batch. This also
    works for one-shot iterators.

    Args:
        Loader: iterable of ``(X, y)`` tensor pairs, e.g. a DataLoader.
        net: callable whose output is indexable; ``output[0]`` is used.
        device: device the inputs are moved to before calling ``net``.

    Returns:
        Tuple ``(features, labels)`` of numpy arrays concatenated along axis 0.

    Raises:
        ValueError: if ``Loader`` yields no batches.
    """
    features, labels = [], []
    with torch.no_grad():
        for X_batch, y_batch in Loader:
            reduced = net(X_batch.to(device))[0]
            features.append(reduced.cpu().detach().numpy())
            labels.append(y_batch.numpy())
            # Release cached device memory between batches.
            torch.cuda.empty_cache()
    if not features:
        raise ValueError('Loader yielded no batches')
    return np.concatenate(features), np.concatenate(labels)
    
def train(args, epoch, R_net, D_net, trainLoader, optimizer_R, optimizer_D, trainF, f, device):
    """Run one adversarial training epoch of the reducer ``R_net`` against ``D_net``.

    For every batch:
      1. The reducer is updated with ``G_loss + tr_loss``.
         - ``G_loss`` is the BCE of ``D_net(latent)`` against label 0, the label
           the discriminator gives to prior samples. It is scaled by ``args.lr``.
         - ``tr_loss`` is the trace loss on the latent features.
      2. The discriminator is updated to output 1 for the (detached) latent
         features and 0 for prior samples ``z``. These are Gaussian draws
         normalised to unit L2 norm, i.e. uniform on the unit sphere in
         ``args.latent_dim`` dimensions.

    When ``batch_idx % 100 == 0`` a progress line is written to ``f``. At the
    end of the epoch one CSV row ``epoch,HSIC,tr_loss_og,tr_loss`` is written
    to ``trainF``. As in the original implementation, these metrics describe
    the LAST batch of the epoch only. HSIC and tr_loss_og are therefore
    computed once, after the loop, without autograd.

    Args:
        args: namespace providing ``latent_dim`` and ``lr``.
        epoch: epoch number written into the CSV row.
        R_net: reducer; ``R_net(data)`` returns ``(latent, output)``.
        D_net: discriminator mapping latent vectors to probabilities.
        trainLoader: iterable of ``(data, target)`` batches.
        optimizer_R, optimizer_D: optimizers for ``R_net`` and ``D_net``.
        trainF: text stream receiving the per-epoch CSV row.
        f: text stream receiving progress lines.
        device: device the batches are moved to.

    Raises:
        ValueError: if ``trainLoader`` yields no batches.
    """
    R_net.train()
    D_net.train()
    adversarial_loss = torch.nn.BCELoss()
    # (flattened data, detached latent, one-hot targets, tr_loss) of the latest batch.
    last = None
    for batch_idx, (data, target) in enumerate(trainLoader):
        n = data.shape[0]
        ones_label = torch.ones(n, 1).to(device)
        zeros_label = torch.zeros(n, 1).to(device)
        # Prior samples: Gaussian draws normalised to unit L2 norm per row.
        z = torch.randn(n, args.latent_dim)
        z = (z / z.norm(dim=1, keepdim=True)).to(device)
        data = data.to(device)
        flat = data.reshape(n, -1)
        target_onehot = to_onehot(target).to(device)

        # --- update the reducer ---
        optimizer_R.zero_grad()
        latent, _ = R_net(data)
        # NOTE(review): the adversarial term is scaled by the learning rate,
        # which looks like it doubles as a loss weight -- confirm intent.
        G_loss = adversarial_loss(D_net(latent), zeros_label) * args.lr
        tr_loss = tr(flat, latent, target_onehot, n, device)  # trace loss on the latent space
        R_loss = G_loss + tr_loss
        R_loss.backward()
        optimizer_R.step()

        # --- update the discriminator ---
        # R_loss.backward() also flowed through D_net, so clear those grads first.
        optimizer_D.zero_grad()
        D_real = D_net(latent.detach())
        D_fake = D_net(z)
        D_loss_real = adversarial_loss(D_real, ones_label)
        D_loss_fake = adversarial_loss(D_fake, zeros_label)
        D_loss = (D_loss_real + D_loss_fake) / 2.
        D_loss.backward()
        optimizer_D.step()

        if batch_idx % 100 == 0:
            f.write('Train iter: {},  loss: {:.6f}, tr_loss: {:.6f}, G_loss: {:.6f}\n'.format(
                batch_idx, float(R_loss), float(tr_loss), float(G_loss)))
            f.flush()
        last = (flat, latent.detach(), target_onehot, tr_loss.detach())

    if last is None:
        raise ValueError('trainLoader yielded no batches')
    flat, latent, target_onehot, tr_loss = last
    n = flat.shape[0]
    with torch.no_grad():
        tr_loss_og = tr(flat, flat, target_onehot, n, device)  # trace loss on the original data
        HSIC = HSIC_in(flat, latent, target_onehot, n, device)  # conditional HSIC
    trainF.write('{},{},{},{}\n'.format(
        epoch, float(HSIC), float(tr_loss_og), float(tr_loss)))  # log trace losses and conditional HSIC
    trainF.flush()

def test(args, epoch, R_net, testLoader, optimizer, testF, f, device):
    """Evaluate ``R_net`` on ``testLoader`` and log batch-averaged metrics.

    For each batch the following are computed and then averaged over the
    number of batches:
      - the trace loss on the latent space (``tr_loss``);
      - the trace loss on the original data (``tr_loss_og``);
      - the conditional HSIC (``HSIC``);
      - ``cor`` between latent features and one-hot targets (``DC``).

    Logging:
      - ``f`` receives ``Epoch: .., DC: .., loss: ..``;
      - ``testF`` receives the CSV row ``epoch,DC,HSIC,tr_loss_og,tr_loss``.

    Args:
        args: unused; kept for interface compatibility.
        epoch: epoch number written to the logs.
        R_net: reducer; ``R_net(data)`` returns ``(latent, output)``.
        testLoader: iterable of ``(data, target)`` batches.
        optimizer: unused; kept for interface compatibility.
        testF: text stream receiving the CSV row.
        f: text stream receiving the human-readable summary.
        device: device the batches are moved to.

    Raises:
        ValueError: if ``testLoader`` yields no batches.
    """
    R_net.eval()
    tr_loss = tr_loss_og = HSIC = DC = 0.0
    n_batches = 0
    with torch.no_grad():
        for data, target in testLoader:
            data = data.to(device)
            n = data.shape[0]
            flat = data.reshape(n, -1)
            target_onehot = to_onehot(target).to(device)
            latent, _ = R_net(data)
            # Accumulate as Python floats so the logs hold plain numbers.
            tr_loss += float(tr(flat, latent, target_onehot, n, device))
            tr_loss_og += float(tr(flat, flat, target_onehot, n, device))
            HSIC += float(HSIC_in(flat, latent, target_onehot, n, device))
            DC += float(cor(latent, target_onehot, n, device))
            n_batches += 1
    if n_batches == 0:
        raise ValueError('testLoader yielded no batches')
    tr_loss /= n_batches
    tr_loss_og /= n_batches
    HSIC /= n_batches
    DC /= n_batches
    f.write('Epoch: {}, DC: {:.4f}, loss: {:.4f}\n'.format(
        epoch, DC, tr_loss))
    testF.write('{},{},{},{},{}\n'.format(epoch, DC, HSIC, tr_loss_og, tr_loss))
    testF.flush()
    f.flush()