import sys
sys.path.append("/playpen-raid/Author/LucidAtlas")
import matplotlib.pyplot as plt
import pandas as pd

plt.rcParams['figure.dpi']= 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os
import argparse
from pipeline.load import *
from utilities.utils import record_prediction, denormalize, denormlize_ds, denormalize_from_distribution
import uncertainty_toolbox as uct
from functools import partial

def flatten_dicts(dict_eval):
    """Flatten a two-level metrics dict into ``{"<group>_<name>": value}``.

    Only entries whose value is a ``float`` (this includes ``np.float64``)
    are kept, each rounded to 4 decimals with ``np.round``; every other
    value type (ints, arrays, nested dicts, strings, ...) is silently dropped.
    Keys are joined with ``+``, so both levels must be ``str``.
    """
    return {
        group + '_' + metric_name: np.round(metric_value, 4)
        for group, group_metrics in dict_eval.items()
        for metric_name, metric_value in group_metrics.items()
        if isinstance(metric_value, float)
    }



def test_(model_trained: torch.nn.Module,
          test_csa_dataset: torch.utils.data.Dataset,
          savedir: str,
          which_set: str):
    """Query the trained model once per covariate and save the predictions.

    The evaluation input is built with ``utils.make_input_for_eval`` (after
    setting ``test_csa_dataset.padding_muter = False``). For every name in
    ``test_csa_dataset.covariate_names`` the model is queried through
    ``utils.batched_infer_global_importance``. The per-covariate means and
    variances are then concatenated and mapped back with
    ``denormalize_from_distribution`` / ``denormalize``. Everything is written
    with ``record_prediction`` to
    ``<savedir>/local_pop_trend_pred_<which_set>.csv``.

    Returns:
        The DataFrame produced by ``record_prediction`` (also saved to disk).
    """
    utils.cond_mkdir(savedir)

    test_csa_dataset.padding_muter = False
    build_eval_input = utils.make_input_for_eval(test_csa_dataset.DATASETNANE)
    model_input, gt_normed, df_inputs = build_eval_input(
        scalar_dataset=test_csa_dataset, device=model_trained.device)

    covariate_names = test_csa_dataset.covariate_names
    model_trained.eval()

    mu_chunks, var_chunks = [], []
    df_chunks, gt_chunks = [], []
    feat_labels = []
    for cov_name in covariate_names:
        with torch.no_grad():
            cov_mu, cov_var = utils.batched_infer_global_importance(
                model_trained, model_input, cov_name,
                batch_size=100, IGNORE_CORR=True)
        mu_chunks.append(cov_mu)
        var_chunks.append(cov_var)
        feat_labels.extend([cov_name] * cov_mu.shape[0])

        # The input table and ground truth are repeated once per covariate so
        # they stay aligned with the stacked per-covariate predictions.
        df_chunks.append(df_inputs)
        gt_chunks.append(gt_normed)

    mu_all = torch.cat(mu_chunks, dim=0)
    var_all = torch.cat(var_chunks, dim=0)
    gt_all = torch.cat(gt_chunks, dim=0)

    std_all = torch.sqrt(var_all)
    mu_ori, std_ori, _low_bd, _high_bd = denormalize_from_distribution(
        ds_=test_csa_dataset, mu=mu_all, sigma=std_all)

    gt_all_ori = denormalize(ds_=test_csa_dataset,
                             arr_=gt_all.cpu().numpy(),
                             var_name=test_csa_dataset.tgt_var_name)

    mu_np = mu_all.squeeze().detach().cpu().numpy()
    std_np = var_all.squeeze().detach().sqrt().cpu().numpy()
    gt_np = gt_all.cpu().numpy()

    columns = {
        'Feat': feat_labels,
        'f_mu': mu_np, 'f_std': std_np, 'normed_GT': gt_np,
        'f_mu_ori': mu_ori, 'f_std_ori': std_ori, 'normed_GT_ori': gt_all_ori,
    }
    df_pred = record_prediction(
        denormlize_ds(test_csa_dataset, pd.concat(df_chunks, axis=0)), columns)

    out_csv = os.path.join(savedir, f'local_pop_trend_pred_{which_set}.csv')
    df_pred.to_csv(out_csv)

    return df_pred


def eval_(dataset_name):
    """Pick the evaluation routine for a dataset.

    ``'Airway'`` and any dataset whose name contains ``"AFQ"`` get the
    landmark-wise ``eval_spatial`` (with ``dataset_name`` bound via
    ``functools.partial``). All other datasets get ``eval_general``.
    """
    is_spatial = dataset_name == 'Airway' or "AFQ" in dataset_name
    if not is_spatial:
        return eval_general
    return partial(eval_spatial, dataset_name=dataset_name)



def eval_spatial(dir, which_set, dataset_name, pos_tol=0.05):
    """Summarise uncertainty metrics per covariate and per landmark.

    First runs ``eval_general`` to write the per-covariate summary. Then
    reads ``local_pop_trend_pred_{which_set}.csv`` from *dir*. For every
    (covariate ``Feat``, landmark) pair it selects the rows whose ``pos``
    lies within *pos_tol* of the landmark position and computes
    ``uct.metrics.get_all_metrics`` on ``f_mu`` / ``f_std`` / ``normed_GT``.
    One row per pair is written to
    ``local_pos_wise_summary_{which_set}.csv`` in *dir*.

    Pairs with no matching rows are skipped with a warning, because the
    metrics cannot be computed on empty arrays.

    Args:
        dir: Directory holding the prediction CSV; the summary is written here too.
        which_set: Split tag used in the file names (e.g. ``'test'``).
        dataset_name: Key into ``LANDMARKS``. The selected entry maps landmark
            names to positions.
        pos_tol: Absolute tolerance on ``pos`` when matching rows to a
            landmark. Defaults to the previously hard-coded 0.05.
    """
    import warnings

    eval_general(dir, which_set)
    current_ldms = LANDMARKS[dataset_name]
    filename_pred_rst = os.path.join(dir, f'local_pop_trend_pred_{which_set}.csv')
    df_rst = pd.read_csv(filename_pred_rst)

    list_cov_names = np.unique(list(df_rst["Feat"].values))

    list_summary = []
    for ith_feat in list_cov_names:
        # The covariate filter does not depend on the landmark, so apply it once.
        df_feat = df_rst[df_rst["Feat"] == ith_feat]
        for landmark_name, landmark_pos in current_ldms.items():
            slt_data = df_feat[(df_feat['pos'] - landmark_pos).abs() <= pos_tol]
            if slt_data.empty:
                warnings.warn(
                    f"eval_spatial: no rows within {pos_tol} of landmark "
                    f"'{landmark_name}' (pos={landmark_pos}) for Feat "
                    f"'{ith_feat}'; skipping.")
                continue

            f_mu = slt_data['f_mu'].values
            f_std = slt_data['f_std'].values
            normed_GT = slt_data['normed_GT'].values

            # Compute all uncertainty metrics.
            metrics = uct.metrics.get_all_metrics(f_mu, f_std, normed_GT)
            dict_flattened = {
                'pos': np.round(landmark_pos, 4),
                'landmark': landmark_name,
                'Feat': ith_feat,
            }
            dict_flattened.update(flatten_dicts(metrics))
            list_summary.append(dict_flattened)

    savepath = os.path.join(dir, f'local_pos_wise_summary_{which_set}.csv')
    pd.DataFrame.from_records(list_summary).to_csv(savepath)


def eval_general(dir, which_set):
    """Write one uncertainty-metrics summary row per covariate.

    Reads ``local_pop_trend_pred_{which_set}.csv`` from *dir*. For each
    distinct ``Feat`` value it runs ``uct.metrics.get_all_metrics`` on the
    ``f_mu`` / ``f_std`` / ``normed_GT`` columns of that covariate's rows.
    The flattened scalar metrics are saved to
    ``local_all_summary_{which_set}.csv`` in *dir*.
    """
    df_pred = pd.read_csv(os.path.join(dir, f'local_pop_trend_pred_{which_set}.csv'))

    records = []
    for feat_name in np.unique(list(df_pred["Feat"].values)):
        feat_rows = df_pred[df_pred["Feat"] == feat_name]
        all_metrics = uct.metrics.get_all_metrics(
            feat_rows['f_mu'].values,
            feat_rows['f_std'].values,
            feat_rows['normed_GT'].values,
        )
        summary_row = {'Feat': feat_name}
        summary_row.update(flatten_dicts(all_metrics))
        records.append(summary_row)

    out_path = os.path.join(dir, f'local_all_summary_{which_set}.csv')
    pd.DataFrame.from_records(records).to_csv(out_path)



def pred_and_eval_single_local(specs_filename: str, which_set: str, cv_idx: int=None):
    """Run prediction and evaluation for one experiment spec on one split.

    Loads the spec (``load_json``), the requested split (``load_dataset``) and
    the checkpoint at ``specs["SavedBestCheckpointPath"]``. Predictions are
    written under ``<LoggingRoot>/<ExperimentName>`` by ``test_``, and the
    evaluator chosen by ``eval_(specs["Class"])`` then runs on that directory.
    """
    specs = load_json(specs_filename, cv_idx)

    split_dataset, _split_loader = load_dataset(specs, which_split=which_set)
    model = load_trained_model(specs, specs["SavedBestCheckpointPath"])
    model.eval()

    out_dir = os.path.join(specs["LoggingRoot"], specs["ExperimentName"])
    utils.cond_mkdir(out_dir)
    dataset_class = specs["Class"]

    test_(model, split_dataset, out_dir, which_set)
    run_eval = eval_(dataset_class)
    run_eval(out_dir, which_set)





def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` passes the raw string through unchanged when no ``type`` is
    given, so ``--test False`` would otherwise yield the truthy string
    ``"False"``. Accepts true/false, t/f, yes/no, y/n and 1/0 (case- and
    whitespace-insensitive). Real ``bool`` values, such as non-string
    defaults, are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Evaluate a trained LucidAtlas model")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        # Alternative configs:
        #default='/playpen-raid/Author/LucidAtlas/configs/airways/v1/airway_mlp.json',
        #default='/playpen-raid/Author/LucidAtlas/configs/airways/airway_namlss_v1_0123_full.json',
        #default='/playpen-raid/Author/LucidAtlas/configs/airways/v1/airway_namlss.json',
        #default='/playpen-raid/Author/LucidAtlas/configs/airways/airway_mlp_v1_0123.json',
        default='/playpen-raid/Author/LucidAtlas/configs/ToyData/lucidatlas_full.json',
        help="Path to the experiment specification JSON file; predictions and "
             + "summaries are written under its LoggingRoot/ExperimentName.",
    )
    # NOTE(review): currently unused -- pred_and_eval_single_local loads
    # specs["SavedBestCheckpointPath"] instead of this value.
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicated an epoch "
        + "or 'latest' for the latest weights (this is the default)",
    )

    # type=str2bool so that e.g. "--test False" actually disables the step
    # (without it, any given value is a non-empty, truthy string).
    # NOTE(review): --train and --vis are parsed but not used below.
    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        default=True,
        type=str2bool,
        help="whether to train from scratch",
    )

    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        default=True,
        type=str2bool,
        help="whether to test",
    )

    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        default=True,
        type=str2bool,
        help="whether to vis",
    )

    args = arg_parser.parse_args()

    if args.whether_test:
        pred_and_eval_single_local(args.experiment_directory, which_set='test')
    print('1')


