import matplotlib.pyplot as plt
import pandas as pd

plt.rcParams['figure.dpi']= 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os

import argparse
from pipeline.load import *
import torch.utils.data as data_utils
from utilities.eval_quant import quantitative_evaluation
from utilities.utils import record_prediction, denormalize, denormlize_ds
import uncertainty_toolbox as uct




def flatten_dicts(dict_eval):
    """Return a new dict with every value of *dict_eval* rounded to 4 decimals.

    Despite the name, no flattening of nested structures happens: keys are
    copied as-is and each value is passed through ``np.round(value, 4)``.
    """
    return {key: np.round(value, 4) for key, value in dict_eval.items()}




def test_(network: torch.nn.Module,
          test_csa_dataset: torch.utils.data.Dataset,
          savedir: str,
          which_set: str):
    """Predict the population trend on one split and save it next to the ground truth.

    The dataset-specific input builder returned by
    ``utils.make_input_for_eval(test_csa_dataset.DATASETNANE)`` yields the model
    input, a ground-truth tensor and the matching test-set DataFrame. The
    prediction from ``network.infer_mu_testing`` and the ground truth are both
    passed through ``denormalize`` for the dataset's target variable, merged via
    ``record_prediction`` with the test-set table (after ``denormlize_ds``), and
    written to ``<savedir>/pop_trend_pred_<which_set>.csv``.

    Args:
        network: trained model; must expose ``device`` and ``infer_mu_testing``.
        test_csa_dataset: dataset providing ``DATASETNANE`` and ``tgt_var_name``.
        savedir: output directory (created if missing).
        which_set: split name, used only in the output file name.

    Returns:
        The prediction DataFrame that was written to CSV.
    """
    utils.cond_mkdir(savedir)

    build_input = utils.make_input_for_eval(test_csa_dataset.DATASETNANE)
    model_input, gt_tensor, df_testset = build_input(
        scalar_dataset=test_csa_dataset, device=network.device)

    raw_mu = network.infer_mu_testing(model_input)
    target_var = test_csa_dataset.tgt_var_name

    # Predictions and ground truth go through the same denormalization.
    # NOTE(review): the 'normed_GT' column name is kept for downstream readers,
    # but the values stored there have already been passed through denormalize.
    pred_mu = denormalize(ds_=test_csa_dataset,
                          arr_=raw_mu.squeeze().detach().cpu().numpy(),
                          var_name=target_var)
    gt_values = denormalize(ds_=test_csa_dataset,
                            arr_=gt_tensor.cpu().numpy(),
                            var_name=target_var)

    columns = {'f_mu': pred_mu, 'normed_GT': gt_values}
    df_pred_rst = record_prediction(denormlize_ds(test_csa_dataset, df_testset), columns)
    df_pred_rst.to_csv(os.path.join(savedir, f'pop_trend_pred_{which_set}.csv'))
    return df_pred_rst


def eval_(dataset_name):
    """Return the evaluation function to use for *dataset_name*.

    'Airway' selects the per-landmark evaluator ``eval_airway``. Every other
    name (e.g. "ADNIHP") falls back to the dataset-wide ``eval_general``.
    """
    return eval_airway if dataset_name == 'Airway' else eval_general



def eval_airway(dir, which_set, tol=0.1):
    """Evaluate airway predictions over the whole split and per anatomical landmark.

    First runs ``eval_general`` on the same split. Then, for every entry of
    ``LANDMARKS`` (name -> position), it selects the rows of
    ``pop_trend_pred_<which_set>.csv`` whose ``pos`` lies within ``tol`` of that
    position. For those rows it computes uncertainty_toolbox accuracy metrics of
    ``f_mu`` against ``normed_GT``. The results are written, one row per landmark,
    to ``pos_wise_summary_<which_set>.csv`` in *dir*.

    Args:
        dir: directory holding the prediction CSV; summaries are written here too.
        which_set: split name used in the input/output file names.
        tol: absolute tolerance on ``pos`` when matching a landmark
            (default 0.1, the previously hard-coded value).

    A landmark with no rows in its window still gets a row (``pos`` and
    ``landmark`` only, metrics left as NaN). Previously, empty arrays were passed
    to the metric function and the whole summary was aborted.
    """
    eval_general(dir, which_set)
    filename_pred_rst = os.path.join(dir, f'pop_trend_pred_{which_set}.csv')
    df_rst = pd.read_csv(filename_pred_rst)

    list_summary = []
    for landmark_name, landmark_pos in LANDMARKS.items():
        # Build the selection mask once and reuse it for both columns.
        in_window = (df_rst['pos'] - landmark_pos).abs() <= tol
        df_sel = df_rst[in_window]

        row = {'pos': np.round(landmark_pos, 4), 'landmark': landmark_name}
        if df_sel.empty:
            print(f"[eval_airway] no samples within {tol} of landmark "
                  f"'{landmark_name}' (pos={landmark_pos}); metrics left empty.")
        else:
            # Accuracy metrics (point prediction vs. target) from uncertainty_toolbox.
            metrics = uct.metrics.get_all_accuracy_metrics(
                df_sel['f_mu'].values, df_sel['normed_GT'].values)
            row.update(flatten_dicts(metrics))
        list_summary.append(row)

    savepath = os.path.join(dir, f'pos_wise_summary_{which_set}.csv')
    # Rows missing metric keys are filled with NaN by from_records.
    pd.DataFrame.from_records(list_summary).to_csv(savepath)


def eval_general(dir, which_set):
    """Score a whole split's predictions and save a one-row metric summary.

    Reads ``pop_trend_pred_<which_set>.csv`` from *dir*, and evaluates column
    ``f_mu`` against ``normed_GT`` with uncertainty_toolbox's
    ``get_all_accuracy_metrics``. Each metric is rounded to 4 decimals and the
    single resulting row is written to ``all_summary_<which_set>.csv`` in *dir*.
    """
    pred_csv = os.path.join(dir, f'pop_trend_pred_{which_set}.csv')
    predictions = pd.read_csv(pred_csv)

    accuracy = uct.metrics.get_all_accuracy_metrics(
        predictions['f_mu'].values,
        predictions['normed_GT'].values,
    )
    summary_row = dict(flatten_dicts(accuracy))

    out_csv = os.path.join(dir, f'all_summary_{which_set}.csv')
    pd.DataFrame.from_records([summary_row]).to_csv(out_csv)



def pred_and_eval_simple_nam(specs_filename: str, which_set: str):
    """Run prediction and evaluation for one split, as configured by a JSON spec.

    The steps, all driven by the spec file at *specs_filename*:
      1. load the requested split with ``load_dataset`` (the second value it
         returns is not needed here);
      2. load the checkpoint named by ``spec["SavedCheckpointPath"]`` and put
         the model in eval mode;
      3. write predictions to ``<LoggingRoot>/<ExperimentName>`` via ``test_``;
      4. run the evaluator chosen by ``eval_`` for ``spec["Class"]``.
    """
    spec = load_json(specs_filename)
    ds_test, _unused_loader = load_dataset(specs_filename=specs_filename,
                                           which_split=which_set)

    model = load_trained_model(specs_filename, spec["SavedCheckpointPath"])
    model.eval()

    out_dir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"])
    utils.cond_mkdir(out_dir)

    dataset_name = spec["Class"]
    test_(model, ds_test, out_dir, which_set)
    eval_(dataset_name)(out_dir, which_set)





def parse_bool(value):
    """Convert a command-line token such as 'true'/'False'/'1'/'no' to a bool.

    Without a ``type=`` converter, argparse stores the raw string the user
    typed, so ``--test False`` produced the truthy string "False". Non-string
    bool values (e.g. an argparse default of ``True``) are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean token.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Predict with and evaluate a trained LucidAtlas model")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        #default='/playpen-raid/Author/LucidAtlas/configs/airways/v1/airway_mlp.json',
        #default='/playpen-raid/Author/LucidAtlas/configs/airways/airway_namlss_v1_0123_full.json',
        #default='/playpen-raid/Author/LucidAtlas/configs/OASISBrain/v1/brain_nam_part.json',
        #default='/playpen-raid/Author/LucidAtlas/configs/airways/airway_mlp_v1_0123.json', #
        default='/playpen-raid/Author/LucidAtlas/configs/airways/v1/airway_nam_part.json',
        help="Path to the experiment specification JSON file. Its LoggingRoot/"
             + "ExperimentName entries determine where predictions and "
             + "summaries are written.",
    )
    # NOTE: not consumed below; the checkpoint actually loaded comes from the
    # spec's "SavedCheckpointPath" entry.
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicating an epoch "
        + "or 'latest' for the latest weights (this is the default)",
    )

    # Boolean flags: parse_bool makes "--test False" actually mean False.
    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        type=parse_bool,
        default=True,
        help="whether to train from scratch",
    )

    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        type=parse_bool,
        default=True,
        help="whether to test",
    )

    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        type=parse_bool,
        default=True,
        help="whether to vis",
    )

    args = arg_parser.parse_args()

    if args.whether_test:
        pred_and_eval_simple_nam(args.experiment_directory, which_set='test')


