

import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score

import torch as th
import torch.utils.data
from tqdm import tqdm

from utils.model_plot import plot_grid, plot_global_feature
from src.model_clf import XTab as XTab_clf
from utils.arguments import get_arguments, get_config
from utils.arguments import print_config_summary
from utils.eval_utils import linear_model_eval, plot_clusters, append_tensors_to_lists, concatenate_lists, aggregate
from utils.load_data import Loader
from utils.utils import set_dirs, run_with_profiler, update_config_with_model_dims

# Seed the global torch RNG once at import time so repeated evaluations are reproducible.
torch.manual_seed(1)

import seaborn as sns

# Shared plot styling: small (3x2 inch) figures at 100 dpi on white backgrounds.
_SNS_RC = {
    'figure.dpi': 100,
    'savefig.dpi': 100,
    'figure.figsize': (3, 2),
    'axes.facecolor': 'white',
    'figure.facecolor': 'white',
}
sns.set(rc=_SNS_RC, font_scale=1.0)
sns.set_style(style='white')

def eval(data_loader, config, mask_g=None):
    """Load the trained XTab classifier and evaluate it on the training and held-out splits.

    Args:
        data_loader (Loader): Data loader wrapper whose split loaders are consumed by
            ``evalulate_models``.
        config (dict): Dictionary containing options and arguments. ``config["test_mode"]``
            selects the test split when truthy, otherwise the validation split.
        mask_g (optional): Global-mask module to use instead of the one stored in the
            loaded model; forwarded to ``evalulate_models`` as ``mask_g_given``.

    Returns:
        tuple: ``(feature_imp_tr, feature_imp_te, acc, auc, instancewise_ranking_te,
        feature_imp_te_mf)``, where the first two are the global-mask feature importances
        of the training and held-out splits, ``acc``/``auc`` are the classifier scores on
        the held-out split, and ``feature_imp_te_mf`` is the held-out feature importance of
        the mask actually applied to the inputs.
    """
    # Build the classifier and restore its trained weights.
    xtab = XTab_clf(config)
    xtab.load_models()

    with th.no_grad():
        # Embeddings / labels of the training split (kept for an optional linear probe).
        z_train, y_train, feature_imp_tr, _unused = evalulate_models(
            data_loader, xtab, config,
            plot_suffix="training", mode="train", mask_g_given=mask_g,
        )

        held_out_mode = 'test' if config["test_mode"] else 'validation'

        # Score the classifier head on the held-out split.
        (_, _, feature_imp_te, acc, auc,
         instancewise_ranking_te, feature_imp_te_mf) = evalulate_models(
            data_loader, xtab, config,
            plot_suffix="test", mode=held_out_mode,
            z_train=z_train, y_train=y_train, mask_g_given=mask_g,
        )

        results_dir = f"./results/{config['framework']}/evaluation/"
        print(f"Evaluation results are saved under {results_dir}\n")
        print("=" * 100 + "\n")

    return (feature_imp_tr, feature_imp_te, acc, auc, instancewise_ranking_te, feature_imp_te_mf)



def evalulate_models(data_loader, model, config, plot_suffix="_Test", mode='train', z_train=None, y_train=None, mask_g_given=None):
    """Embed one data split with the trained model, score its classifier head and plot masks.

    For every batch the effective feature mask is built (the local mask, optionally combined
    with the global mask ``mask_g``), masked subsets are generated and encoded, their latents
    are aggregated into a joint embedding, and the classifier head is applied to it.

    Args:
        data_loader (Loader): Object exposing ``train_loader``, ``validation_loader`` and
            ``test_loader``.
        model (object): Loaded XTab model. Its ``encoder``, ``mask``, ``mask_g``, ``clf``
            and ``gate`` modules and its ``_tensor``, ``subset_generator`` and
            ``_results_path`` members are used.
        config (dict): Options. Keys read here: ``dataset``, ``device``, ``use_mask_g``,
            ``fold_num`` (helpers may read more).
        plot_suffix (str): Split label shown in the progress message.
        mode (str): ``'train'`` or ``'validation'``; any other value selects the test loader.
        z_train (ndarray, optional): Training-set embeddings. Unused while the
            logistic-regression probe (``linear_model_eval``) is disabled.
        y_train (array-like, optional): Training-set labels. Unused (see ``z_train``).
        mask_g_given (optional): Global-mask module to use instead of ``model.mask_g``.

    Returns:
        tuple: If ``mode`` is ``'test'`` or ``'validation'``:
            ``(z, clabels, feature_importance_global, acc, auc, instancewise_ranking,
            feature_importance)``; otherwise
            ``(z, clabels, feature_importance_global, feature_importance)``.
            ``feature_importance_global`` is what ``plot_global_feature`` returns for the
            squared global mask, ``feature_importance`` for the mask actually applied.
    """
    def break_line(sym):
        """Return two 100-character rule lines made of ``sym``."""
        return f"{100 * sym}\n{100 * sym}\n"

    # Print which split we are embedding.
    description = break_line('#') + f"Getting the joint embeddings of {plot_suffix} set...\n" + \
                  break_line('=') + f"Dataset used: {config['dataset']}\n" + break_line('=')
    print(description)

    encoder, mask, clf, gate = model.encoder, model.mask, model.clf, model.gate
    if mask_g_given is None:
        mask_g = model.mask_g
    else:
        print("Using given mask g ==========================================")
        mask_g = mask_g_given

    # Move every sub-module to the target device and switch it to inference mode.
    # `clf` was previously only set to eval mode; moving it too guarantees it sits on the
    # same device as the `latent` tensors it is applied to.
    device = config["device"]
    for module in (encoder, mask, mask_g, clf, gate):
        module.to(device)
        module.eval()

    # Choose the loader of the requested split.
    if mode == 'train':
        split_loader = data_loader.train_loader
    elif mode == 'validation':
        split_loader = data_loader.validation_loader
    else:
        split_loader = data_loader.test_loader

    # Progress bar over batches; "leave=True" keeps it on screen afterwards.
    train_tqdm = tqdm(split_loader, total=len(split_loader), leave=True)

    # Per-batch accumulators: embeddings, labels, classifier outputs, masks and raw inputs.
    z_l, clabels_l, clf_preds_l = [], [], []
    mask_l, mask_l2, x_l = [], [], []

    for x, label in train_tqdm:
        xt = model._tensor(x)
        mask_x = mask(xt)
        # Global mask for this batch, computed once and reused below (previously it was
        # re-evaluated three times; modules are in eval mode — assumed deterministic).
        mgt = mask_g(xt)

        if config["use_mask_g"]:
            # Combine global and local masks; rescale only when the sum exceeds 1, then
            # square. (A gated variant g*mask_g + (1-g)*mask_x using `gate` was tried before.)
            mask_e = mgt + mask_x
            mask_e = mask_e / mask_e.max().clamp(min=1.0)
            mask_e = mask_e * mask_e
            print("Global * Local")
        else:
            mask_e = mask_x * mask_x

        # Squared global mask, tracked separately for the global feature-importance plot.
        mask_gg = mgt * mgt

        mask_l.append(mask_e.cpu().numpy())
        mask_l2.append(mask_gg.cpu().numpy())
        x_l.append(x)

        # Generate masked subsets (labels are not used).
        x_tilde_list, submask_list = model.subset_generator(x, mask_e)

        # Encode each subset. The submasks are derived from mask_e, which is already
        # squared, so they are applied to the inputs exactly once here.
        latent_list = []
        for xi, me in zip(x_tilde_list, submask_list):
            xi = model._tensor(me) * model._tensor(xi)
            _, latent, _ = encoder(xi)
            latent_list.append(latent)

        # Aggregate subset latents into the joint embedding and classify it.
        latent = aggregate(latent_list, config)
        clf_preds = clf(latent)

        # Append tensors to the corresponding lists as numpy arrays.
        z_l, clabels_l, clf_preds_l = append_tensors_to_lists([z_l, clabels_l, clf_preds_l],
                                                              [latent, label.int(), clf_preds])

    # Turn the per-batch lists into single arrays.
    z, mask_arr, mask_arr2, x_arr, clf_preds_arr = concatenate_lists([z_l, mask_l, mask_l2, x_l, clf_preds_l])
    clabels = concatenate_lists([clabels_l])

    # Accuracy from the arg-max class of the classifier outputs.
    clf_preds_list = np.argmax(clf_preds_arr, axis=1).tolist()
    acc = accuracy_score(clabels, clf_preds_list)

    # ROC-AUC and AP need a continuous ranking score; feeding hard arg-max labels reduces
    # AUC to balanced accuracy. With two output columns the margin preds[:,1] - preds[:,0]
    # is monotonic in P(class 1) for logits, log-probabilities and probabilities alike.
    if clf_preds_arr.ndim == 2 and clf_preds_arr.shape[1] == 2:
        pos_score = clf_preds_arr[:, 1] - clf_preds_arr[:, 0]
    else:
        pos_score = clf_preds_list
    # np.unique also works when labels come back as a 2-D column array.
    n_classes = np.unique(clabels).size
    auc = roc_auc_score(clabels, pos_score) if n_classes == 2 else 0
    ap = average_precision_score(clabels, pos_score) if n_classes <= 2 else 0

    print(f"Classifier acc: {acc}, auc: {auc}, ap: {ap}")

    if mask_g_given is not None:
        filename = config["dataset"] + "_Global_Fold" + str(config["fold_num"]) + "_" + config["dataset"] + mode
        fn2 = config["dataset"] + "_GlobalMask_"
    else:
        filename = "PostGlobal_Fold" + str(config["fold_num"]) + "_" + config["dataset"] + mode
        fn2 = "PostGlobalMask_"

    feature_importance2 = plot_global_feature(x_arr, mask_arr2, clabels, config, fn2 + config["dataset"] + mode)
    feature_importance = plot_global_feature(x_arr, mask_arr, clabels, config, filename)

    print(feature_importance)
    print(feature_importance2)

    if mode not in ['test', 'validation']:
        # Training split: return z_train = z and y_train = clabels for the caller.
        return z, clabels, feature_importance2, feature_importance

    # Instance-wise mask plot for the last sample of the split (jj = 1 -> index -1).
    jj = 1
    local_filename = "Local_Fold" + str(config["fold_num"]) + "_" + mode + str(jj) + str(clabels[-jj]) + str(x_arr[-jj,-2:].tolist()) +"_mf"
    instancewise_ranking = plot_grid(x_arr[-jj,:].reshape(1,-1), mask_arr[-jj,:].reshape(1,-1), [clabels[-jj]], config, local_filename)

    if mode == 'test':
        acc_pd = pd.DataFrame({'test_acc':acc, 'test_auc':auc}, index=[0])
        acc_pd.to_csv(model._results_path + '/test_acc.csv')

    print(20 * "*" + " Running evaluation using Logistic Regression trained on the joint embeddings" \
                   + " of training set and tested on that of test set" + 20 * "*")
    # NOTE: the logistic-regression probe is currently disabled; to re-enable it call
    # linear_model_eval(config, z_train, y_train, z_test=z, y_test=clabels, description=...).

    return z, clabels, feature_importance2, acc, auc, instancewise_ranking, feature_importance


def main(config, mask_g=None):
    """Prepare directories and data for ``config["dataset"]``, then run the evaluation.

    Args:
        config (dict): Dictionary containing options and arguments.
        mask_g (optional): Replacement global-mask module, forwarded to ``eval``.

    Returns:
        tuple: Whatever ``eval`` returns (feature importances and classifier scores).
    """
    # Make sure result/model directories exist.
    set_dirs(config)

    # Loader for the configured dataset; keep the last partial batch so no sample is skipped.
    loader = Loader(config, dataset_name=config["dataset"], drop_last=False)

    # Record the dataset's feature dimension in the model configuration.
    model_config = update_config_with_model_dims(loader, config)

    return eval(loader, model_config, mask_g=mask_g)


if __name__ == "__main__":

    # Get parser / command line arguments
    args = get_arguments()
    # One separately trained model per seed; the seed only selects the results folder name.
    seeds = [211, 317, 79, 54, 654, 34, 167, 4468, 369, 770]

    # Architecture flags in priority order, mapped to their results-folder name fragment.
    # A lookup (instead of an if/elif chain) guarantees `arc_name` is always bound or a
    # clear error is raised, rather than a NameError further down.
    arc_name_by_flag = (
        ("linear", "_linear_seed_"),
        ("relu", "_linear_relu_seed_"),
        ("relux2", "_linear_relux2_seed_"),
        ("relux3", "_linear_relux3_seed_"),
        ("relux5", "_linear_relux5_seed_"),
    )

    for SEED in seeds:
        # Fresh configuration for every seed so per-run edits do not leak across runs.
        config = get_config(args)
        config["linear"] = False
        config["relu"] = True
        config["relux2"] = False
        config["relux3"] = False
        config["relux5"] = False
        config["c_hdim"] = 1024

        arc_name = next((name for flag, name in arc_name_by_flag if config[flag]), None)
        if arc_name is None:
            raise ValueError(
                "No architecture flag set: enable one of linear/relu/relux2/relux3/relux5."
            )

        # Overwrite the parent folder name for saving results
        config["framework"] = config["dataset"] + arc_name + "class_" + str(config["c_hdim"]) + "_nv_" + str(SEED) + "_gaussian0.30.3nFront"

        # Turn off noise when evaluating the performance
        config["add_noise"] = False
        # Disable training the mask
        config["train_mask"] = False
        # Use global mask
        config["use_mask_g"] = True
        config["fold_num"] = 0
        config["test_mode"] = True
        config["validate"] = False
        config["training_data_ratio"] = 1.0

        # Summarize config and arguments on the screen as a sanity check
        print_config_summary(config, args)
        # Run evaluation for this seed
        main(config)
