import sys
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/dataloader")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/models")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/configurations")
#########################################################################################
import os
import torch
import pickle
import numpy as np
import global_variables

from tqdm import tqdm
from numpy.linalg import matrix_rank


def evaluate_ranking(dataloader_test, model, model_baseline, training_params):
    """Evaluate the DIDA model and the handcrafted-metafeature baseline on a test set.

    Both models are switched to eval mode and run under ``torch.no_grad()``.
    For each item yielded by ``dataloader_test`` the cross-entropy loss of each
    model against ``perf`` (cast to ``long``) is computed, along with the
    fraction of rows whose arg-max prediction equals ``perf``.

    Args:
        dataloader_test: iterable supporting ``len()`` that yields tuples
            ``(X, y, params_1, params_2, perf, mf_hc, list_score)`` of tensors.
            ``squeeze(0)`` is applied to each one (presumably dropping a
            DataLoader batch axis of size 1 — confirm against the dataset).
            ``list_score`` is not used by this function.
        model: DIDA model, called as ``model(X, y, params_1, params_2)`` and
            returning a pair whose first element is the class logits.
        model_baseline: baseline model, called as
            ``model_baseline(mf_hc, params_1, params_2)`` and returning logits.
        training_params: mapping with at least the keys ``"device"`` and
            ``"extractor"``. When ``"extractor" == "dida"``, ``y`` is replaced
            by its arg-max over dim 2 as a float tensor.

    Returns:
        ``(mean_loss_dida, mean_loss_baseline, mean_accuracy_dida,
        mean_accuracy_baseline)``, each an unweighted mean over dataloader
        items (every item counts equally regardless of its size).

    Raises:
        AssertionError: if X, y, params_1, params_2 or perf contain NaN/Inf.
    """
    device = training_params["device"]
    model.eval()
    model_baseline.eval()
    # Remember the caller's flag so it is restored even if evaluation raises.
    previous_is_training = getattr(global_variables, "is_training", True)
    global_variables.is_training = False
    criterion = torch.nn.CrossEntropyLoss().to(device)
    criterion_baseline = torch.nn.CrossEntropyLoss().to(device)

    list_loss = []
    list_loss_baseline = []
    accuracy_baseline, accuracy_dida = [], []

    try:
        with torch.no_grad():
            for X, y, params_1, params_2, perf, mf_hc, _ in tqdm(dataloader_test, total=len(dataloader_test)):
                X, y, params_1, params_2, perf, mf_hc = (
                    t.to(device).squeeze(0)
                    for t in (X, y, params_1, params_2, perf, mf_hc)
                )

                # Index access, consistent with training_params["device"] above.
                if training_params["extractor"] == "dida":
                    y = y.argmax(2).float()

                for tensor in (X, y, params_1, params_2, perf):
                    assert not (torch.isnan(tensor).any() or torch.isinf(tensor).any())

                target = perf.long()
                pred, _ = model(X, y, params_1, params_2)
                pred_baseline = model_baseline(mf_hc, params_1, params_2)

                # Accuracy = correct / number of compared targets.
                n_targets = target.numel()
                top_dida = pred.max(1, keepdim=True)[1]
                correct_dida = top_dida.eq(target.view_as(top_dida)).sum().item()
                top_baseline = pred_baseline.max(1, keepdim=True)[1]
                correct_baseline = top_baseline.eq(target.view_as(top_baseline)).sum().item()

                accuracy_dida.append(correct_dida / n_targets)
                accuracy_baseline.append(correct_baseline / n_targets)

                list_loss.append(criterion(pred, target).item())
                list_loss_baseline.append(criterion_baseline(pred_baseline, target).item())
    finally:
        global_variables.is_training = previous_is_training

    return np.mean(list_loss), np.mean(list_loss_baseline), np.mean(accuracy_dida), np.mean(accuracy_baseline)




def evaluate_ranking_old(dataloader_test, model, model_baseline,
                         training_params, X_train_dida, X_train_hc, y_train):
    """Legacy evaluation that also scores BOHAMIANN surrogates on the metafeatures.

    First runs the DIDA model and the handcrafted-metafeature baseline over
    ``dataloader_test`` (eval mode, ``torch.no_grad()``). During this pass it
    collects:
    - per-item losses and accuracies;
    - the DIDA metafeatures returned by ``model``;
    - ``mf_hc`` and ``params_1``;
    - the first column of ``list_score``.

    The collected arrays are concatenated into test matrices
    ``[metafeatures | params_1]``. These matrices are passed, together with the
    given training data, to ``global_variables.compute_score_bohamiann``. The
    resulting scores, scatter plots and the matrix ranks of both metafeature
    matrices are logged to ``global_variables.writer``. Results are also
    written under ``global_variables.working_dir``.

    Args:
        dataloader_test: iterable supporting ``len()`` that yields tuples
            ``(X, y, params_1, params_2, perf, mf_hc, list_score)`` of tensors.
            ``squeeze(0)`` is applied to each one.
        model: DIDA model, called as ``model(X, y, params_1, params_2)`` and
            returning ``(logits, metafeatures)``.
        model_baseline: baseline model, called as
            ``model_baseline(mf_hc, params_1, params_2)`` and returning logits.
        training_params: mapping with at least ``"device"`` and ``"extractor"``.
        X_train_dida, X_train_hc, y_train: training data forwarded to
            ``compute_score_bohamiann`` for the DIDA and handcrafted features.

    Returns:
        ``(mean_loss_dida, mean_accuracy_dida, surrogate)``. ``surrogate`` is
        the model object returned by ``compute_score_bohamiann`` for the DIDA
        features.

    Raises:
        AssertionError: if X, y, params_1, params_2 or perf contain NaN/Inf.
        ValueError: if ``dataloader_test`` yields nothing.
    """
    device = training_params["device"]
    model.eval()
    model_baseline.eval()
    # Remember the caller's flag so it is restored even if the loop raises.
    previous_is_training = getattr(global_variables, "is_training", True)
    global_variables.is_training = False
    criterion = torch.nn.CrossEntropyLoss().to(device)
    criterion_baseline = torch.nn.CrossEntropyLoss().to(device)

    working_dir = global_variables.working_dir
    path_x = os.path.join(working_dir, "X_data")
    path_y = os.path.join(working_dir, "Y_data")
    path_dida_mf = os.path.join(working_dir, "MF_Dida")
    path_hc_mf = os.path.join(working_dir, "MF_Hc")
    path_pred = os.path.join(working_dir, "predictions")
    # X_data / Y_data are still created as before, although nothing is written there now.
    for path in (path_x, path_y, path_dida_mf, path_hc_mf, path_pred):
        os.makedirs(path, exist_ok=True)

    list_dida_mf, list_params, list_hc_mf, list_perf = [], [], [], []
    accuracy_dida, accuracy_baseline = [], []
    list_loss, list_loss_baseline = [], []

    try:
        with torch.no_grad():
            for X, y, params_1, params_2, perf, mf_hc, list_score in tqdm(dataloader_test, total=len(dataloader_test)):
                X, y, params_1, params_2, perf, mf_hc, list_score = (
                    t.to(device).squeeze(0)
                    for t in (X, y, params_1, params_2, perf, mf_hc, list_score)
                )

                # Index access, consistent with training_params["device"] above.
                if training_params["extractor"] == "dida":
                    y = y.argmax(2).float()

                for tensor in (X, y, params_1, params_2, perf):
                    assert not (torch.isnan(tensor).any() or torch.isinf(tensor).any())

                target = perf.long()
                logits_dida, dida_metafeatures = model(X, y, params_1, params_2)
                logits_baseline = model_baseline(mf_hc, params_1, params_2)

                # Accuracy = correct / number of compared targets.
                n_targets = target.numel()
                top_dida = logits_dida.max(1, keepdim=True)[1]
                correct_dida = top_dida.eq(target.view_as(top_dida)).sum().item()
                top_baseline = logits_baseline.max(1, keepdim=True)[1]
                correct_baseline = top_baseline.eq(target.view_as(top_baseline)).sum().item()

                list_dida_mf.append(dida_metafeatures.data.cpu().numpy())
                list_params.append(params_1.data.cpu().numpy())
                list_hc_mf.append(mf_hc.data.cpu().numpy())
                list_perf.append(list_score.data.cpu().numpy()[:, 0])

                accuracy_dida.append(correct_dida / n_targets)
                accuracy_baseline.append(correct_baseline / n_targets)

                list_loss.append(criterion(logits_dida, target).item())
                list_loss_baseline.append(criterion_baseline(logits_baseline, target).item())
    finally:
        global_variables.is_training = previous_is_training

    if not list_dida_mf:
        raise ValueError("dataloader_test yielded no batches; nothing to evaluate")

    # Stack the per-item arrays into one row-per-sample matrix; the surrogate
    # scoring, matrix_rank and np.save below all need these 2-D arrays.
    dida_mf = np.concatenate(list_dida_mf, axis=0)
    params = np.concatenate(list_params, axis=0)
    hc_mf = np.concatenate(list_hc_mf, axis=0)
    y_test = np.concatenate(list_perf, axis=0)
    X_test_dida = np.concatenate([dida_mf, params], axis=1)
    X_test_hc = np.concatenate([hc_mf, params], axis=1)

    # Add scatter plot of prediction random forest
    # score_dida, pred_dida, surrogate = global_variables.compute_score_rf(X_train=X_train_dida, y_train=y_train, X_test=X_test_dida, y_test=y_test, return_pred=True, return_model=True)
    # score_hc, pred_baseline = global_variables.compute_score_rf(X_train=X_train_hc, y_train=y_train, X_test=X_test_hc, y_test=y_test, return_pred=True)
    # image_dida = global_variables.create_scatter_plot(vtrue=y_test, vpred=pred_dida, title="")
    # image_baseline = global_variables.create_scatter_plot(vtrue=y_test, vpred=pred_baseline, title="")
    # global_variables.writer.add_image("Dida RF prediction", image_dida, global_variables.batch_idx)
    # global_variables.writer.add_image("Baseline RF prediction", image_baseline, global_variables.batch_idx)
    # global_variables.writer.add_scalars("Loss_random_forest/test", {
    #     "dida": np.mean(score_dida), "baseline": np.mean(score_hc)
    # }, global_variables.batch_idx)
    # with open(os.path.join(global_variables.working_dir, "test_loss_rf_dida.txt"), "a") as file:
    #     file.write("{0},{1}\n".format(global_variables.batch_idx, score_dida))
    # with open(os.path.join(global_variables.working_dir, "test_loss_rf_hc.txt"), "a") as file:
    #     file.write("{0},{1}\n".format(global_variables.batch_idx, score_hc))
    # np.save(os.path.join(path_pred, "RF_y_pred_dida"), pred_dida)
    # np.save(os.path.join(path_pred, "RF_y_pred_baseline"), pred_baseline)
    # np.save(os.path.join(path_pred, "RF_y_test"), y_test)

    # Add scatter plot of prediction BOHAMIANN
    # NOTE(review): assumes compute_score_bohamiann mirrors compute_score_rf above,
    # i.e. with return_model=True it returns (score, predictions, fitted_model) — confirm.
    # Requesting the model is what defines `surrogate`, which this function returns.
    score_dida, pred_dida, surrogate = global_variables.compute_score_bohamiann(
        X_train=X_train_dida, y_train=y_train, X_test=X_test_dida, y_test=y_test,
        return_pred=True, return_model=True)
    score_hc, pred_baseline = global_variables.compute_score_bohamiann(
        X_train=X_train_hc, y_train=y_train, X_test=X_test_hc, y_test=y_test,
        return_pred=True)
    image_dida = global_variables.create_scatter_plot(vtrue=y_test, vpred=pred_dida, title="")
    image_baseline = global_variables.create_scatter_plot(vtrue=y_test, vpred=pred_baseline, title="")
    global_variables.writer.add_image("Dida BOHAMIANN prediction", image_dida, global_variables.batch_idx)
    global_variables.writer.add_image("Baseline BOHAMIANN prediction", image_baseline, global_variables.batch_idx)
    global_variables.writer.add_scalars("Loss_bohamiann/test", {
        "dida": np.mean(score_dida), "baseline": np.mean(score_hc)
    }, global_variables.batch_idx)
    with open(os.path.join(global_variables.working_dir, "test_loss_bohamiann_dida.txt"), "a") as file:
        file.write("{0},{1}\n".format(global_variables.batch_idx, score_dida))
    with open(os.path.join(global_variables.working_dir, "test_loss_bohamiann_hc.txt"), "a") as file:
        file.write("{0},{1}\n".format(global_variables.batch_idx, score_hc))
    np.save(os.path.join(path_pred, "BO_y_pred_dida"), pred_dida)
    np.save(os.path.join(path_pred, "BO_y_pred_baseline"), pred_baseline)
    np.save(os.path.join(path_pred, "BO_y_test"), y_test)

    # Matrix Rank (of the stacked metafeature matrices, one row per sample)
    global_variables.writer.add_scalars("Maxtrix_rank", {
        "dida": matrix_rank(dida_mf), "baseline": matrix_rank(hc_mf)
    }, global_variables.batch_idx)

    np.save(os.path.join(path_dida_mf, "dida_metafeatures"), dida_mf)
    np.save(os.path.join(path_hc_mf, "handcrafted_metafeatures"), hc_mf)

    return np.mean(list_loss), np.mean(accuracy_dida), surrogate
    # return np.mean(list_loss), np.mean(list_loss_baseline), np.mean(accuracy_dida), np.mean(accuracy_baseline), surrogate
