import numpy as np
import matplotlib.pyplot as plt

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import *

import math
from tqdm import tqdm

import time
import os
import json
import pathlib

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sb

from fairx.metrics import FairnessUtils, DataUtilsMetrics

from sklearn.model_selection import train_test_split

from torch.utils.tensorboard import SummaryWriter

from module import *
import argparse



def visualize(ori_data, fake_data, epoch, direc, writer, *, sample_size=120, perplexity=40):
    """Plot PCA and t-SNE projections of baseline versus distilled samples.

    A random subset of rows (the same row indices for both inputs) is
    projected to two dimensions with PCA (fitted on the baseline rows only)
    and with t-SNE (fitted on both sets stacked together). The two scatter
    plots are drawn side by side, saved as a PNG inside ``direc`` and logged
    to ``writer``.

    Args:
        ori_data: Array-like of baseline samples, one sample per row.
        fake_data: Array-like of distilled samples, one sample per row.
        epoch: Step value used in the PNG file name and as the global step
            of the logged figure.
        direc: Existing directory in which the PNG is written.
        writer: Object exposing ``add_figure(tag, figure, global_step)``
            (``main`` passes a ``SummaryWriter``).
        sample_size: Maximum number of rows drawn from each input.
        perplexity: Requested t-SNE perplexity; lowered automatically when
            too few rows are available.
    """
    ori_data = np.asarray(ori_data)
    fake_data = np.asarray(fake_data)

    # Truncate the baseline so a single index set is valid for both arrays.
    ori_data = ori_data[:fake_data.shape[0]]

    idx = np.random.permutation(len(ori_data))[:sample_size]
    real_sample = ori_data[idx]
    fake_sample = fake_data[idx]

    mode = 'visualization'

    ### PCA

    pca = PCA(n_components=2)
    pca.fit(real_sample)
    pca_real = pd.DataFrame(pca.transform(real_sample)).assign(Data='Baseline')
    pca_synthetic = pd.DataFrame(pca.transform(fake_sample)).assign(Data='Distilled')
    # pd.concat is the public replacement for the private DataFrame._append.
    pca_result = pd.concat([pca_real, pca_synthetic], ignore_index=True).rename(
        columns={0: '1st Component', 1: '2nd Component'})

    ### TSNE

    tsne_data = np.concatenate((real_sample, fake_sample), axis=0)

    # scikit-learn requires perplexity < n_samples; a short final batch would
    # otherwise make the fixed default raise.
    tsne = TSNE(n_components=2,
                verbose=0,
                perplexity=min(perplexity, len(tsne_data) - 1))
    tsne_result = pd.DataFrame(tsne.fit_transform(tsne_data), columns=['X', 'Y'])
    tsne_result['Data'] = (['Baseline'] * len(real_sample)
                           + ['Distilled'] * len(fake_sample))

    fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(10, 5))

    sb.scatterplot(x='1st Component', y='2nd Component', data=pca_result,
                   hue='Data', style='Data', ax=axs[0])
    axs[0].set_title('PCA Result')

    sb.scatterplot(x='X', y='Y', data=tsne_result,
                   hue='Data', style='Data', ax=axs[1])
    axs[1].set_title('t-SNE Result')

    sb.despine(fig=fig)

    fig.suptitle('Assessing Diversity: Qualitative Comparison of Baseline and Distilled Data Distributions',
                 fontsize=14)
    fig.tight_layout()
    fig.subplots_adjust(top=.88)

    # Save through the figure handle rather than pyplot's "current figure".
    fig.savefig(os.path.join(f'{direc}', f'{time.time()}-tsne-result-{epoch}.png'))

    writer.add_figure(mode, fig, epoch)



def main(args):
    """Distil a fair-representation teacher into a small student and evaluate it.

    Steps, in order:

    1. Create a timestamped output folder under ``saved_files/`` and a
       TensorBoard writer that logs into it.
    2. Load the dataset named by ``args.dataset_name`` and either load a
       cached ``FairDisCo`` teacher from ``./model/`` or fit one and cache it.
    3. Train a ``DiscoFrisko`` student so that its encoding of Gaussian noise
       matches the teacher's encoding of real batches (L1 loss plus the
       student's ``calculate_kl()`` term), plotting every 100 epochs.
    4. Save the student weights, decode student test encodings with the
       teacher to ``synthetic_samples.npy``, and write fairness/utility
       metrics to a CSV file in the output folder.

    Args:
        args: Namespace with a ``dataset_name`` attribute.
    """
    dataset_name = args.dataset_name

    arch = 'Fair4Free'

    file_name = f'{dataset_name}-distillation-{arch}-tiny'

    folder_name = f'saved_files/{time.time():.4f}-{file_name}'

    pathlib.Path(folder_name).mkdir(parents=True, exist_ok=True)

    ae_fig_dir_path = f'{folder_name}/output/ae'

    pathlib.Path(ae_fig_dir_path).mkdir(parents=True, exist_ok=True)

    writer = SummaryWriter(log_dir=folder_name, comment=f'{file_name}', flush_secs=45)

    try:
        train_data, test_data, D = load_dataset(dataset_name)
        S_train, S_test = train_data.S.numpy(), test_data.S.numpy()
        Y_train, Y_test = train_data.Y.numpy(), test_data.Y.numpy()

        batch_size = 2048
        epochs = 5000
        verbose = 100

        lr = 1e-3
        x_dim = train_data.X.shape[1]
        s_dim = train_data.S.max().item() + 1
        h_dim = 64
        h_dim_tiny = 32  ## For smaller distillation model
        z_dim = 8
        noise_dim = 64

        # Fall back to the CPU instead of crashing when no GPU is present.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        ### Fair Representation training

        lg_beta = 6
        beta = 10 ** lg_beta
        fair_model = FairDisCo(x_dim, h_dim, z_dim, s_dim, D)

        teacher_path = './model/FairDisCo_{}_{}.pkl'.format(dataset_name, lg_beta)

        if os.path.exists(teacher_path):
            fair_model.load(teacher_path)
            fair_model.eval()
            print(f'Eval Started for beta {lg_beta}')

        else:
            fair_model.fit(train_data=train_data, epochs=epochs, lr=lr, batch_size=batch_size,
                           verbose=verbose, beta=beta, device=device)
            # Make sure the cache directory exists so a finished training run
            # is not lost to a failing save.
            pathlib.Path(teacher_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(fair_model.state_dict(), teacher_path)
            fair_model.eval()

            print(f'Fair Representation Training done!')

        ### Fair Distillation using smaller model

        # NOTE(review): the literal 64 equals noise_dim and is presumably the
        # student's input width -- confirm against DiscoFrisko's signature.
        base_model = DiscoFrisko(x_dim, h_dim_tiny, z_dim, 64, D)

        print(f'Distillation Started')

        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

        # Move the models to the target device before building the optimizer.
        base_model = base_model.to(device)
        fair_model = fair_model.to(device)

        base_model.train()
        fair_model.eval()

        optimm = torch.optim.Adam(base_model.parameters(), lr=1e-5)

        criterion = torch.nn.L1Loss()

        # Hold the most recent batch outputs for the periodic plot; they stay
        # None only if the loader yields no batches.
        teacher_out = None
        student_enc = None

        for it in tqdm(range(epochs + 1)):

            train_loss = 0.0

            for x, c, s, y in train_loader:

                x = x.to(device)
                s = s.to(device)

                # The final batch may be shorter than the configured batch
                # size, so size the noise from the batch itself without
                # overwriting the batch_size hyper-parameter.
                n_rows = x.shape[0]

                base_model.zero_grad()

                noise = torch.randn(n_rows, noise_dim, dtype=torch.float, device=device)

                student_enc = base_model.encode(noise)

                # The teacher is frozen: no gradients are tracked for it.
                with torch.no_grad():
                    teacher_out = fair_model.encode(x, s)

                loss = criterion(student_enc, teacher_out) + base_model.calculate_kl()

                loss.backward()

                optimm.step()

                train_loss += loss.item()

            # Logged value is the sum of batch losses over the epoch.
            writer.add_scalar('Loss', train_loss, it)

            if it % 100 == 0:

                if teacher_out is not None:
                    visualize(teacher_out.detach().cpu(), student_enc.detach().cpu(),
                              it, ae_fig_dir_path, writer)

                print(f'loss: {train_loss}')

        print(f'Distillation complete')

        torch.save(base_model.state_dict(), f'./{folder_name}/distilled_model_{dataset_name}_{lg_beta}.pkl')

        base_model.eval()

        ## Evaluation and Generated Samples

        with torch.no_grad():

            noise_train = torch.randn(len(S_train), noise_dim, dtype=torch.float, device=device)

            noise_test = torch.randn(len(S_test), noise_dim, dtype=torch.float, device=device)

            z_train = base_model.encode(noise_train)

            z_test = base_model.encode(noise_test)

            synthetic_samples = fair_model.decode(z_test, test_data.S.to(device))

            np.save(f'{folder_name}/synthetic_samples.npy', synthetic_samples.detach().cpu().numpy())

        # Convert once and hand the same arrays to both metric suites.
        z_train_np = z_train.detach().cpu().numpy()
        z_test_np = z_test.detach().cpu().numpy()

        fair_utils = FairnessUtils((z_train_np, z_test_np, Y_train, Y_test, S_train, S_test))

        fair_utils_res = fair_utils.evaluate_fairness()

        data_utils = DataUtilsMetrics((z_train_np, z_test_np, Y_train, Y_test, S_train, S_test))

        data_utils_res = data_utils.evaluate_utility()

        data_utils_res.update(fair_utils_res)

        data_utils_res['Methods'] = 'Distillation'

        print(data_utils_res)

        res_df = pd.DataFrame([data_utils_res])

        res_df.to_csv(f'./{folder_name}/fairx-res_distil_{dataset_name}.csv', index=False)

    finally:
        # Flush pending events and release the event file even on failure.
        writer.close()


if __name__ == "__main__":

    # Command-line entry point: the only option selects which dataset /
    # sensitive-attribute pairing the run uses.
    supported_datasets = ['Adult-sex', 'Adult-race', 'Compas-sex', 'Compas-race']

    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset_name',
                     type=str,
                     default='Adult-sex',
                     choices=supported_datasets)

    main(cli.parse_args())

