import numpy as np
import matplotlib.pyplot as plt

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import *

import math
from tqdm import tqdm

import time
import os
import json
import pathlib

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sb

from sklearn.model_selection import train_test_split

from torch.utils.tensorboard import SummaryWriter

from torchvision.utils import make_grid
import torchvision.transforms.functional as FF

import torchvision
from torchvision import datasets, models, transforms
from torchvision.datasets import ImageFolder
import torchvision.utils as vutils
from PIL import Image
from tqdm import tqdm

import torch.nn.functional as F
from torch.utils.data import Dataset

from module import *

import argparse

from scipy.special import digamma
from sklearn.neighbors import NearestNeighbors, KDTree
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import roc_auc_score
from sklearn.utils import shuffle


def show(imgs, img_name, save_fig = False):
    """
    Draw one or more image tensors side by side in a single row of subplots.

    Args:
        imgs: a single image tensor or a list of them; each is detached and
            converted with `torchvision.transforms.functional.to_pil_image`.
        img_name: file stem used when saving; the file written is
            ``f'{img_name}.png'``.
        save_fig: when True, save the current figure at 300 dpi.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps a 2-D axes array even for a single image.
    _, axes = plt.subplots(ncols=len(images), squeeze=False, figsize=(15, 9))
    for ax, tensor in zip(axes[0], images):
        pil_image = FF.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_image))
        # Hide ticks and tick labels on both axes.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if save_fig:
        plt.savefig(f'{img_name}.png', dpi=300, transparent=False)

class CelebaLoader():

    """
    Dataset loader for CelebA dataset [1].

    [1] Liu, Ziwei, et al. "Deep learning face attributes in the wild." Proceedings of the IEEE international conference on computer vision. 2015.
    """

    def __init__(self, data_dir, target = 'Male', sensitive_attr = 'Eyeglasses'):

        """
        Store the dataset location and the attribute columns to extract.

        Input: data_dir, string, path to the dataset directory. Download the dataset from here: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, and check the following link for details: https://www.kaggle.com/datasets/jessicali9530/celeba-dataset.
            The directory is expected to contain `list_attr_celeba.csv` and
            the images under `img_align_celeba/img_align_celeba/`.

            target: string, column of `list_attr_celeba.csv` used as the
                target (default `Male`, i.e. gender).

            sensitive_attr: string, column used as the protected attribute
                (default `Eyeglasses`).

        The arrays themselves are built by `create_celebA`.
        """

        super().__init__()

        self.data_dir = data_dir

        self.target = target

        self.sensitive_attr = sensitive_attr

    def create_celebA(self):

        """
        Load every image listed in `list_attr_celeba.csv`.

        Each image is resized to 64, center-cropped to 64x64 and converted
        with `transforms.ToTensor()`.

        Return, Numpy arrays (data_X, data_y, data_s):
            data_X: stacked image arrays, one per CSV row, in CSV order
                (an empty array if the CSV has no rows).
            data_y: int values of the `self.target` column.
            data_s: int values of the `self.sensitive_attr` column.
        """

        attribute_file = pd.read_csv(os.path.join(self.data_dir, 'list_attr_celeba.csv'))
        image_dir = os.path.join(self.data_dir, 'img_align_celeba', 'img_align_celeba')
        n_rows = len(attribute_file)

        transform = transforms.Compose([transforms.Resize(64),
                                        transforms.CenterCrop(64),
                                        transforms.ToTensor()])

        data_X = None
        data_y = []
        data_s = []

        print('Dataset building started')

        # iterrows() is a generator, so pass total= for a progress bar with ETA.
        rows = tqdm(attribute_file.iterrows(), total=n_rows)
        for position, (_, row) in enumerate(rows):

            # Context manager releases the file handle deterministically.
            with Image.open(os.path.join(image_dir, row['image_id'])) as image:
                array = transform(image).numpy()

            if data_X is None:
                # Allocate the output once, sized from the first image. This avoids
                # keeping ~n_rows arrays in a list and copying them all again with
                # np.array, which would double peak memory.
                data_X = np.empty((n_rows,) + array.shape, dtype=array.dtype)

            data_X[position] = array

            data_y.append(int(row[self.target]))

            data_s.append(int(row[self.sensitive_attr]))

        if data_X is None:
            # Empty CSV: match np.array([]) from the list-based construction.
            data_X = np.array([])

        print('Dataset Ready')

        return data_X, np.array(data_y), np.array(data_s)


class VectorDataset(Dataset):
    """
    Map-style dataset over three parallel, index-aligned containers.

    Item ``i`` is the tuple ``(X[i], S[i], Y[i])`` -- features, sensitive
    attribute and target, in that order. The dataset length is taken from
    ``X.shape[0]``.

    Helper function, adapted from https://github.com/SoftWiser-group/FairDisCo/blob/main/utils.py
    """

    def __init__(self, X, S, Y):
        # Containers are stored as-is (no copy, no conversion).
        self.X, self.S, self.Y = X, S, Y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, i):
        return (self.X[i], self.S[i], self.Y[i])


def visualize(ori_data, fake_data, epoch, direc, writer, sample_size=120, perplexity=40):
    """
    Plot PCA and t-SNE projections of a random paired subsample of two
    datasets, save the figure as a PNG and log it to TensorBoard.

    Rows of `ori_data` are labelled 'Baseline' and rows of `fake_data`
    'Distilled'. The same random row indices are drawn from both arrays.

    Args:
        ori_data: array-like of reference samples; axis 0 indexes samples.
            Any trailing axes are flattened into features.
        fake_data: array-like of generated samples, laid out like `ori_data`.
        epoch: used in the PNG file name and as the TensorBoard global step.
        direc: existing directory the PNG is written to.
        writer: TensorBoard SummaryWriter; the figure is logged under the tag
            'visualization'.
        sample_size: maximum number of paired rows to plot (default 120).
        perplexity: upper bound for the t-SNE perplexity (default 40). It is
            lowered automatically when there are too few points, because
            scikit-learn requires perplexity < n_samples.
    """
    ori_data = np.asarray(ori_data)
    fake_data = np.asarray(fake_data)

    # Truncate so both arrays share row indices; row i of each forms a pair.
    ori_data = ori_data[:fake_data.shape[0]]

    idx = np.random.permutation(len(ori_data))[:sample_size]

    # PCA / t-SNE expect (n_samples, n_features); flattening is a no-op for
    # inputs that are already 2-D.
    real_sample = ori_data[idx].reshape(len(idx), -1)
    fake_sample = fake_data[idx].reshape(len(idx), -1)

    mode = 'visualization'

    ### PCA: fit on the real samples only, then project both sets.
    pca = PCA(n_components=2)
    pca.fit(real_sample)
    pca_real = pd.DataFrame(pca.transform(real_sample)).assign(Data='Baseline')
    pca_synthetic = pd.DataFrame(pca.transform(fake_sample)).assign(Data='Distilled')
    # pd.concat replaces the private DataFrame._append (DataFrame.append was
    # removed in pandas 2.0).
    pca_result = pd.concat([pca_real, pca_synthetic]).rename(
        columns={0: '1st Component', 1: '2nd Component'})

    ### TSNE on the joint set of real and distilled points.
    tsne_data = np.concatenate((real_sample, fake_sample), axis=0)

    # With 20 or fewer paired rows a fixed perplexity of 40 would exceed the
    # number of points and TSNE would raise ValueError.
    tsne = TSNE(n_components=2,
                verbose=0,
                perplexity=min(perplexity, tsne_data.shape[0] - 1))
    tsne_result = tsne.fit_transform(tsne_data)

    tsne_result = pd.DataFrame(tsne_result, columns=['X', 'Y']).assign(Data='Baseline')
    # The second half of the joint set comes from fake_sample.
    tsne_result.loc[len(real_sample):, 'Data'] = 'Distilled'

    fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(10, 5))

    sb.scatterplot(x='1st Component', y='2nd Component', data=pca_result,
                   hue='Data', style='Data', ax=axs[0])
    axs[0].set_title('PCA Result')

    sb.scatterplot(x='X', y='Y',
                   data=tsne_result,
                   hue='Data',
                   style='Data',
                   ax=axs[1])
    axs[1].set_title('t-SNE Result')

    # Remove top/right spines on both panels.
    sb.despine(fig=fig)

    fig.suptitle('Assessing Diversity: Qualitative Comparison of Baseline and Distilled Data Distributions',
                 fontsize=14)
    fig.tight_layout()
    fig.subplots_adjust(top=.88)

    fig.savefig(os.path.join(f'{direc}', f'{time.time()}-tsne-result-{epoch}.png'))

    # SummaryWriter.add_figure closes `fig` by default (close=True).
    writer.add_figure(mode, fig, epoch)


def main(args):
    """
    Train or load the FairDisCoImage teacher, then distil it into the smaller
    FairDisCoImage_small student.

    The student never sees real images. Its `encode_small` is fed Gaussian
    noise and trained with an L1 loss against `fair_model.encode(x)` for the
    current real batch, plus the student's own `calculate_kl()` term.

    Args:
        args: parsed CLI namespace providing
            - dataset_name: 'celeba-rgb' builds CelebA via CelebaLoader; any
              other value loads colour MNIST via `loadMnist(color=True)`.
            - sensitive_attr: CelebA attribute column. It is also embedded in
              output file names and TensorBoard image tags.

    Side effects:
        - creates `saved_files/<timestamp>-<run name>/`, which holds the
          TensorBoard logs, the PCA/t-SNE figures (under `output/ae`) and the
          distilled student weights;
        - loads the teacher weights from `./model/`, or trains them and
          writes them there.
    """

    dataset_name = args.dataset_name
    sensitive_attr = args.sensitive_attr

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Size of the last two dims of the student's noise input (n_chan x 64 x 64).
    noise_dim = 64

    if dataset_name == 'celeba-rgb':

        # Placeholder: point this at the local CelebA directory.
        data_dir = '/add/the/dataset/here/'

        data_loader = CelebaLoader(data_dir = data_dir, sensitive_attr = sensitive_attr)

        x, y, s = data_loader.create_celebA()

        # Re-encode the raw attribute values as contiguous class indices.
        y = LabelEncoder().fit_transform(y)
        s = LabelEncoder().fit_transform(s)

        dataset = VectorDataset(x, s, y)

        train_data, test_data = train_test_split(dataset, test_size = 0.2, shuffle= True)

        batch_size = 2048
        s_dim = 2
        teacher_path = f'./model/FairDisCoImage_celeba_sen_{sensitive_attr}.pkl'

    else:

        train_data, test_data = loadMnist(color=True)

        batch_size = 4096
        s_dim = 3
        teacher_path = './model/FairDisCoImage_color.pkl'

    # Hyper-parameters shared by both datasets.
    verbose = 100
    lr = 1e-5
    n_chan = 3
    z_dim = 10
    fairdisco_epochs = 5000

    epochs = 5000
    loss_func = 'L1'

    arch = 'Fair4Free'

    file_name = f'{dataset_name}-distillation-{arch}-{loss_func}-sen-{sensitive_attr}-tiny'

    folder_name = f'saved_files/{time.time():.4f}-{file_name}'

    pathlib.Path(folder_name).mkdir(parents=True, exist_ok=True)

    ae_fig_dir_path = f'{folder_name}/output/ae'

    pathlib.Path(ae_fig_dir_path).mkdir(parents=True, exist_ok=True)

    writer = SummaryWriter(log_dir = folder_name, comment = f'{file_name}', flush_secs = 45)

    ### Fair Representation training (teacher)

    beta = 0

    fair_model = FairDisCoImage(z_dim=z_dim, s_dim=s_dim, n_chan=n_chan)
    if os.path.exists(teacher_path):
        print('model loading')
        fair_model.load(teacher_path)
        fair_model.eval()
        print('Pre-trained Model Loaded!')
    else:
        print('model training')
        fair_model.fit(train_data=train_data, epochs=fairdisco_epochs, lr=lr, batch_size=batch_size, verbose=verbose, beta=beta, device=device)
        # Create ./model first: otherwise torch.save raises FileNotFoundError
        # only after the whole teacher training has run.
        pathlib.Path(teacher_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(fair_model.state_dict(), teacher_path)
        print('Pretraining done!')

    ### Using smaller model for Distillation

    base_model = FairDisCoImage_small(z_dim=z_dim, s_dim=s_dim, n_chan=n_chan)

    print('Distillation Started')

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)

    # Only the sensitive attributes of one test batch are used, as
    # conditioning labels when decoding sample images.
    _, test_s, _ = next(iter(test_loader))

    # Move models to the device before building the optimizer, as the
    # PyTorch docs recommend.
    base_model = base_model.to(device)
    fair_model = fair_model.to(device)

    base_model.train()
    fair_model.eval()

    optimm = torch.optim.Adam(base_model.parameters(), lr = lr, betas = (0.9, 0.999), weight_decay = 0.0001)

    criterion = torch.nn.L1Loss()

    # Number of generated samples per visualization step; capped by the test
    # batch so the three decode calls below get equally sized inputs.
    n_gen = min(256, len(test_s))

    # epochs + 1 iterations so the final epoch (a multiple of 100) is also visualized.
    for it in tqdm(range(epochs + 1)):

        train_loss = 0.0

        for x, _, _ in train_loader:

            x = x.to(device)

            n_batch = x.shape[0]

            base_model.zero_grad()

            noise = torch.randn(n_batch, n_chan, noise_dim, noise_dim, dtype = torch.float, device = device)

            student_enc = base_model.encode_small(noise)

            # The teacher is frozen; no gradients flow into it.
            with torch.no_grad():
                teacher_out = fair_model.encode(x)

            loss = criterion(student_enc, teacher_out) + base_model.calculate_kl()

            loss.backward()

            optimm.step()

            train_loss += loss.item()

        # Sum of per-batch losses over the epoch.
        writer.add_scalar('Loss', train_loss, it)

        if it % 100 == 0:

            # Compares the last batch of teacher outputs with the student's.
            visualize(teacher_out.detach().cpu(), student_enc.detach().cpu(), it, ae_fig_dir_path, writer)

            ## Generating Samples: decode student codes with the teacher
            ## decoder under fixed and test-set sensitive labels.
            with torch.no_grad():

                noise = torch.randn(n_gen, n_chan, noise_dim, noise_dim, dtype = torch.float, device = device)

                student_codes = base_model.encode_small(noise)

                imgs_s1 = fair_model.decode(student_codes, torch.ones(n_gen).long().to(device))

                imgs_s0 = fair_model.decode(student_codes, torch.zeros(n_gen).long().to(device))

                sample_imgs = fair_model.decode(student_codes, test_s[:n_gen].to(device))

                grid = make_grid(sample_imgs[:64])

                grid_s1 = make_grid(imgs_s1[:64])

                grid_s0 = make_grid(imgs_s0[:64])

            writer.add_image('Sample reconstructed images', grid, it)

            # s == 0 -> "Without", s == 1 -> "With" (same tag mapping as before).
            writer.add_image(f'Without {sensitive_attr}', grid_s0, it)

            writer.add_image(f'With {sensitive_attr}', grid_s1, it)

            print(f'loss: {train_loss}')

    print('Distillation complete')

    torch.save(base_model.state_dict(), f'./{folder_name}/distilled_model_{dataset_name}.pkl')

    # Flush any pending TensorBoard events and release the event file.
    writer.close()

def _build_arg_parser():
    """
    Build the command-line parser for this script.

    Options:
        --dataset_name: 'celeba-rgb' or 'cmnist' (default 'cmnist').
        --sensitive_attr: CelebA attribute name, or 'rgb' (default 'No_Beard').
    """
    parser = argparse.ArgumentParser()

    dataset_choices = ['celeba-rgb', 'cmnist']
    parser.add_argument('--dataset_name', type=str, default='cmnist',
                        choices=dataset_choices)

    sensitive_choices = ['Eyeglasses', 'No_Beard', 'Smiling', 'Wearing_Hat',
                         'Wearing_Necklace', 'Narrow_Eyes', 'Gray_Hair', 'rgb']
    parser.add_argument('--sensitive_attr', type=str, default='No_Beard',
                        choices=sensitive_choices)

    return parser


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())

