# CLOE ablation experiments (--study): 0 removes the joint training, 1 removes the pre-training, 2 uses no training at all.
import os
import torch
import numpy as np
import math as m
import pickle as pkl
import matplotlib.pyplot as plt
import argparse

from torch import optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from christoffel import CLOE, ChristoffelScore_loss
from CLOE.autoencoder import train, Autoencoder
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

def test(christoffel_model, x_enc, y, args, X=None):
    """Evaluate a fitted Christoffel support and save the resulting metrics.

    AU-ROC and average precision are computed from the unregularised
    Christoffel score (``score_samples_noreg``); F1 is computed from the hard
    predictions (``predict``). The three values are printed and stored as a
    dict in ``CLOE/results/<data_name>/<seed>/CLOE_ablation_<study>/result.npy``.
    When ``args.umap`` is set, a side-by-side UMAP plot of the ground-truth
    labels and the predicted labels is also saved.

    Args:
        christoffel_model: fitted support exposing ``score_samples_noreg`` and
            ``predict``; both are expected to return torch tensors (they are
            ``.detach()``-ed here).
        x_enc: samples passed to the model (may live on the GPU).
        y: ground-truth labels, ``1`` marking outliers (``pos_label=1``).
        args: parsed CLI namespace; reads ``data_name``, ``seed``, ``study``,
            ``umap``, ``dataset_path``, ``type_conc`` and ``lambda_CLOE``.
        X: optional feature matrix to embed with UMAP. When omitted, the raw
            ``X`` array is reloaded from ``<dataset_path><data_name>.npz``.
    """
    # sklearn converts its inputs through NumPy, which cannot read CUDA
    # tensors, so move the model outputs back to host memory first.
    pred = christoffel_model.score_samples_noreg(x_enc).detach().cpu().numpy()
    aucroc = roc_auc_score(y_true=y, y_score=pred)
    aucap = average_precision_score(y_true=y, y_score=pred, pos_label=1)
    score = christoffel_model.predict(x_enc).detach().cpu().numpy()
    f1Score = f1_score(y_true=y, y_pred=score)

    print(f'AU-ROC for Christoffel score: {aucroc}')
    print(f'AP AUC for Christoffel score: {aucap}')
    print(f'F1 Score for Christoffel score: {f1Score}')

    result_path = f"CLOE/results/{args.data_name}/{args.seed}/CLOE_ablation_{args.study}/"
    os.makedirs(result_path, exist_ok=True)
    # A dict is stored as a 0-d object array: reload it with
    # np.load(path, allow_pickle=True).item().
    np.save(
        result_path + "result.npy",
        {
            "AUC ROC": aucroc,
            "AP AUC": aucap,
            "F1 Score": f1Score,
        },
    )

    if args.umap:
        import umap  # optional dependency, only needed for this plot

        if X is None:
            # Same (unscaled) array the script's main section loads.
            data = np.load(f'{args.dataset_path}{args.data_name}.npz', allow_pickle=True)
            X = data['X']

        reducer = umap.UMAP(random_state=args.seed, n_neighbors=30, metric='euclidean',
                            min_dist=0.1, n_components=2)
        reducer.fit(X)
        X_embedded = reducer.transform(X)

        y = np.asarray(y)
        fig, ax = plt.subplots(1, 2)
        # Left: ground truth, right: model predictions; inliers green, outliers red.
        for axis, labels, title in ((ax[0], y, 'Ground truth'), (ax[1], score, "Name method")):
            inliers = labels == 0
            outliers = labels == 1
            axis.scatter(X_embedded[inliers, 0], X_embedded[inliers, 1], c="g", marker='x')
            axis.scatter(X_embedded[outliers, 0], X_embedded[outliers, 1], c="r", marker='x')
            axis.set_title(title)
        fig.set_figwidth(15)
        fig.set_figheight(7)

        plot_dir = f'CLOE/models_abl/{args.data_name}_{args.type_conc}_{int(args.lambda_CLOE)}'
        # Nothing guarantees this directory exists (e.g. with --study 2 no
        # training step writes into it), so create it before saving.
        os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(f'{plot_dir}/umap_score.png')
        plt.close(fig)  # release the figure; repeated calls would otherwise accumulate them


if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` turns every non-empty string (including
        ``"False"``) into ``True``; this accepts the usual spellings instead.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    parser.add_argument('--seed', type=int, default=49,
                        help='seed')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--n', type=int, default=None,
                        help='degree of the poynomial to compute the support')
    # `choices` rejects unknown studies up front instead of failing later with
    # an undefined `dropout_rate`.
    parser.add_argument('--study', type=int, default=0, choices=[0, 1, 2],
                        help='0 for removing the joint training, 1 for removing the pre training, 2 for no training at all')
    parser.add_argument('--nb-epochs', type=int, default=10,
                        help='Number of epochs for this training step')
    parser.add_argument('--nb-class', type=int, default=8,
                        help='Dimension of the latent space of the autoencoder')
    parser.add_argument('--dataset-path', type=str, default='CLOE/datasets/',
                        help='Path to the dataset (numpy file)')
    parser.add_argument('--data-name', type=str, default='6_cardio',
                        help='Name of the dataset (numpy file)')
    # Default matches the value previously hard-coded in the train() call.
    parser.add_argument('--patience', type=int, default=20,
                        help='patience for earlystopping')
    parser.add_argument('--lambda_CLOE', type=int, default=1,
                        help='coefficient lambda in the training loss for the Christoffel function part')
    parser.add_argument('--dim', type=int, default=[500, 500, 2000], nargs='+',
                        help='Dimension of the hidden layer of the encoder in the order')
    parser.add_argument('--type-conc', type=str, default='mean',
                        help='type of the concatenation for all the Christoffel value in the loss : mean, sum or max')
    parser.add_argument('--num-worker', type=int, default=0,
                        help='Number of worker used to train the model')
    # `--umap`, `--umap True` and `--umap False` all behave as written.
    parser.add_argument('--umap', type=_str2bool, nargs='?', const=True, default=False,
                        help='Save the image of the UMAP representation of the data with inliers in green and outliers in red')

    args = parser.parse_args()

    # The training batch size is C(NUM_CLASSES + n, n), so a valid degree is required.
    if args.n is None or args.n < 0:
        parser.error('--n must be a non-negative polynomial degree (it sizes the training batches)')

    torch.manual_seed(args.seed)

    # Choose CPU or GPU automatically.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        args.device = 'cuda:0'
    else:
        args.device = 'cpu'

    print(f"Type ablation study: {args.study}")

    # Hyperparameters
    RANDOM_SEED = args.seed
    LEARNING_RATE = args.lr
    n = args.n
    NUM_EPOCHS = args.nb_epochs
    NUM_CLASSES = args.nb_class

    type_conc = args.type_conc
    lambda_CLOE = args.lambda_CLOE
    DIM = args.dim

    # Thread settings for CPU execution.
    NUM_WORKER = args.num_worker
    if NUM_WORKER > 1:
        torch.set_num_threads(NUM_WORKER)
        torch.set_num_interop_threads(NUM_WORKER)

    # Dataset preprocessing: standardise every feature, then move to the device.
    data_name = args.data_name
    data = np.load(f'{args.dataset_path}{args.data_name}.npz', allow_pickle=True)
    X, y = data['X'], data['y']
    x = torch.from_numpy(StandardScaler().fit_transform(X)).to(args.device)

    # Batch size = C(NUM_CLASSES + n, n).
    BATCH_SIZE = m.comb(NUM_CLASSES + n, n)

    train_mode = {
        "pre-training": 0,
        "joint-training": 1,
        "compute-support": 2
    }

    model_dir = f'CLOE/models_abl/{data_name}_{type_conc}_{int(lambda_CLOE)}'
    if args.study in (0, 2):
        dropout_rate = 0.2
        training_step = "pre-training"
        file_save = f'{model_dir}/pretrain'
    else:  # study == 1 (guaranteed by `choices`)
        dropout_rate = 0.0
        training_step = "joint-training"
        file_save = f'{model_dir}/jointrained'

    autoencoder = Autoencoder(in_shape=x.shape[1], enc_shape=NUM_CLASSES, DIM=DIM, dropout_rate=dropout_rate).double().to(args.device)

    # NOTE(review): lambda_CLOE only appears in output paths here; it is never
    # passed to the loss despite its help text — confirm this is intended.
    error = ChristoffelScore_loss(n=n, type_conc=type_conc, step_training=train_mode[training_step], random_seed=RANDOM_SEED)

    optimizer = optim.Adam(autoencoder.parameters(), lr=LEARNING_RATE)
    # Anomaly detection is a debugging aid that makes backward passes much
    # slower; kept so numerical anomalies still raise as before.
    torch.autograd.set_detect_anomaly(True)

    # Only inliers (y == 0) are used for training and for fitting the support.
    x_success = x[y == 0]
    # Cap the train+validation pool at roughly 5000 inliers. `<=` avoids
    # test_size == 0.0 (rejected by train_test_split) with exactly 5000 inliers.
    if x_success.shape[0] <= 5000:
        test_size = 0.1
    else:
        test_size = 1 - 5000 / x_success.shape[0]
    X_train_valid, _ = train_test_split(x_success, test_size=test_size, random_state=RANDOM_SEED)
    X_train, X_valid = train_test_split(X_train_valid, test_size=0.2, random_state=RANDOM_SEED)

    if args.study != 2:
        train(
            model              = autoencoder,
            train_loader       = DataLoader(X_train.to(dtype=torch.float64), batch_size=BATCH_SIZE, num_workers=NUM_WORKER),
            validation_loader  = DataLoader(X_valid.to(dtype=torch.float64)),
            epochs             = NUM_EPOCHS,
            device             = args.device,
            optimizer          = optimizer,
            loss_function      = error,
            patience_max       = args.patience,
            train_mode         = train_mode[training_step],
            path_save          = model_dir)

        autoencoder.load_state_dict(torch.load(f'{file_save}.pt', map_location=args.device))

    # Compute the support in the latent space.
    with torch.no_grad():
        autoencoder.eval()
        x_encoded = autoencoder.encode(x)
    # Same row count and random_state as above, so these splits select the
    # same inlier rows, now in encoded form.
    X_train_valid, _ = train_test_split(x_encoded[y == 0], test_size=test_size, random_state=RANDOM_SEED)
    X_train, X_valid = train_test_split(X_train_valid, test_size=0.2, random_state=RANDOM_SEED)
    # NOTE(review): n=None here rather than args.n; presumably CLOE selects
    # the degree itself in that case — confirm.
    christoffel_support = CLOE(n=None, regularization="max", polynomial_basis="monomials", inv='fpd_inv')
    christoffel_support.fit(X_train, X_valid)

    # Compute the metrics
    test(christoffel_support, x_encoded, y, args)
            