
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi']= 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os
import argparse
from pipeline.load import *
import torch.utils.data as data_utils
from utilities.eval_quant import quantitative_evaluation
from model.networks.AirwayDataset import get_airway_data_for_id, get_airways_for_transport
from model.networks.OASISBrainDataset import get_OASISBrain_pairs_for_transport, get_oasisbrain_data_for_id, get_OASISBrain_for_transport

from utilities.utils import record_prediction, denormalize, denormlize_ds
import pandas as pd
import uncertainty_toolbox as uct

# def record_prediction(pd_data_t1, dict_pred):
#     pd_pred_data_t1 = pd_data_t1.copy()
#     for i_key, i_val in dict_pred.items():
#         pd_pred_data_t1[i_key] = i_val
#     return pd_pred_data_t1



def flatten_dicts(dict_eval):
    """Return a copy of ``dict_eval`` with every value rounded to 4 decimals.

    Despite the name, no nesting is flattened: each value is passed through
    ``np.round(value, 4)`` and stored under its original key.

    Args:
        dict_eval: Mapping from metric name to a value ``np.round`` accepts.

    Returns:
        A new ``dict`` with the same keys, in the same order as ``dict_eval``.
    """
    return {name: np.round(score, 4) for name, score in dict_eval.items()}



def individualized_prediction_from_t0(specs: dict):
    """Predict a later scan of every test patient from their youngest scan.

    For each patient returned by the dataset-specific ``*_for_transport``
    helper, the youngest scan is used as the reference (t0) and the first
    entry of ``other_scans`` (if any) as the target (t1). The three outputs of
    ``trained_model.individualized_prediction`` are stored, both as returned
    by the model and denormalized via ``denormalize``, together with the
    ground truth of the t1 scan. All per-scan tables are concatenated and
    written to ``<LoggingRoot>/<ExperimentName>/ind_pred.csv``.

    Args:
        specs: Experiment specification. The keys read here are
            ``"SavedBestCheckpointPath"``, ``"LoggingRoot"``,
            ``"ExperimentName"``, ``"Class"`` and ``"Device"``.

    Returns:
        The directory the CSV was written to.

    Raises:
        ValueError: If ``specs["Class"]`` is not a supported dataset, or if
            no patient produced a prediction (nothing to save).
    """
    # Only the dataset object is needed; the dataloader is unused here.
    ds_test, _ = load_dataset(specs, which_split='test')
    trained_model = load_trained_model(specs, specs["SavedBestCheckpointPath"])  # saved checkpoint name
    trained_model.eval()

    savedir = os.path.join(specs["LoggingRoot"], specs["ExperimentName"])
    utils.cond_mkdir(savedir)
    dataset_name = specs["Class"]
    device = specs["Device"]

    # Airway and AFQ datasets share the same loaders.
    if 'Airway' in dataset_name or 'AFQ' in dataset_name:
        load_one_case = get_airway_data_for_id
        list_patient_scans = get_airways_for_transport(specs, ds_test, split='test_multiple')
    elif dataset_name == 'OASISBrain':
        load_one_case = get_oasisbrain_data_for_id
        list_patient_scans = get_OASISBrain_for_transport(specs, ds_test, split='test_multiple')
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError; fail with an actionable message instead.
        raise ValueError(
            f"Unsupported dataset class {dataset_name!r}: expected a name containing "
            "'Airway' or 'AFQ', or exactly 'OASISBrain'."
        )

    def _to_original_scale(tensor):
        # Move a model tensor to a flat numpy array and undo the target normalization.
        return denormalize(ds_=ds_test,
                           arr_=tensor.squeeze().detach().cpu().numpy(),
                           var_name=ds_test.tgt_var_name)

    list_ind_pred = []
    for current_patient in list_patient_scans:
        youngest_scan_idx = current_patient['youngest_scan']
        print(youngest_scan_idx)
        model_input_t0, gt_t0, df_data_t0 = load_one_case(youngest_scan_idx, ds_test)
        model_input_t0, gt_t0 = model_input_t0.to(device), gt_t0.to(device)

        # Only the first follow-up scan of each patient is evaluated.
        other_scans = current_patient['other_scans']
        if len(other_scans) != 0:
            other_scans = [other_scans[0]]

        for ith_scan_t1 in other_scans:
            model_input_t1, gt_t1, df_data_t1 = load_one_case(ith_scan_t1, ds_test)
            model_input_t1, gt_t1 = model_input_t1.to(device), gt_t1.to(device)
            f_ind_t1, f_mu_t1, f_from_t0 = trained_model.individualized_prediction(
                model_input_t0, gt_t0, model_input_t1)

            # NOTE(review): the non-"_ori" entries are stored exactly as the
            # model returned them (not squeezed/detached); the evaluation
            # functions in this file only read the "_ori" columns.
            dict_current_pred = {
                'Ind_Pred': f_ind_t1,
                'Pop_Pred': f_mu_t1,
                "T0_Pred": f_from_t0,
                'Ind_Pred_ori': _to_original_scale(f_ind_t1),
                'Pop_Pred_ori': _to_original_scale(f_mu_t1),
                "T0_Pred_ori": _to_original_scale(f_from_t0),
                'normed_GT': gt_t1,
                'normed_GT_ori': _to_original_scale(gt_t1),
            }

            current_ind_pred = record_prediction(denormlize_ds(ds_test, df_data_t1),
                                                 dict_pred=dict_current_pred)
            list_ind_pred.append(current_ind_pred)

    if not list_ind_pred:
        # pd.concat would raise a generic "No objects to concatenate" here.
        raise ValueError(
            "No individualized predictions were produced: no test patient has a "
            "follow-up scan in the 'test_multiple' split."
        )

    df_ind_analysis = pd.concat(list_ind_pred)
    savepath_ind_analysis = os.path.join(savedir, 'ind_pred.csv')
    df_ind_analysis.to_csv(savepath_ind_analysis)

    return savedir


def eval_general(dir):
    """Summarise prediction accuracy over all rows of ``ind_pred.csv``.

    Reads ``<dir>/ind_pred.csv`` and, for each of the three prediction columns
    (``T0_Pred_ori``, ``Pop_Pred_ori``, ``Ind_Pred_ori``), computes
    ``uct.metrics.get_all_accuracy_metrics`` against ``normed_GT_ori``. One row
    per trend (``'t0'``, ``'pop'``, ``'ind'``), with metrics rounded by
    ``flatten_dicts``, is written to ``<dir>/ind_all_summary.csv``.

    Args:
        dir: Directory that contains ``ind_pred.csv`` and receives the summary.
    """
    df_rst = pd.read_csv(os.path.join(dir, 'ind_pred.csv'))

    # (trend label, prediction column) in the order rows appear in the summary.
    trend_columns = (
        ('t0', 'T0_Pred_ori'),
        ('pop', 'Pop_Pred_ori'),
        ('ind', 'Ind_Pred_ori'),
    )
    predictions = [(trend, df_rst[column].values) for trend, column in trend_columns]
    ground_truth = df_rst['normed_GT_ori'].values

    # Accuracy metrics of every trend against the same ground truth.
    metrics_per_trend = [
        (trend, uct.metrics.get_all_accuracy_metrics(predicted, ground_truth))
        for trend, predicted in predictions
    ]

    list_summary = [
        {'trend': trend, **flatten_dicts(metrics)}
        for trend, metrics in metrics_per_trend
    ]

    summary_path = os.path.join(dir, 'ind_all_summary.csv')
    pd.DataFrame.from_records(list_summary).to_csv(summary_path)



def eval_airway(dir):
    """Run the general evaluation, then a per-landmark evaluation.

    After calling ``eval_general(dir)``, reads ``<dir>/ind_pred.csv`` again and,
    for every entry of ``LANDMARKS_Airway``, keeps the rows whose ``pos`` value
    lies within 0.05 of the landmark's value. For those rows the accuracy
    metrics of the three prediction columns against ``normed_GT_ori`` are
    computed, and one row per (landmark, trend) is written to
    ``<dir>/ind_pos_wise_summary.csv``.

    Args:
        dir: Directory that contains ``ind_pred.csv`` and receives the summaries.
    """
    eval_general(dir)
    df_rst = pd.read_csv(os.path.join(dir, 'ind_pred.csv'))

    # (trend label, prediction column) in the order rows appear in the summary.
    trend_columns = (
        ('t0', 'T0_Pred_ori'),
        ('pop', 'Pop_Pred_ori'),
        ('ind', 'Ind_Pred_ori'),
    )

    list_summary = []
    for landmark_name, landmark_pos in LANDMARKS_Airway.items():
        # Rows close enough to this landmark; selected once and reused for all columns.
        near_landmark = (df_rst['pos'] - landmark_pos).abs() <= 0.05
        rows = df_rst[near_landmark]

        predictions = [(trend, rows[column].values) for trend, column in trend_columns]
        ground_truth = rows['normed_GT_ori'].values

        metrics_per_trend = [
            (trend, uct.metrics.get_all_accuracy_metrics(predicted, ground_truth))
            for trend, predicted in predictions
        ]

        for trend, metrics in metrics_per_trend:
            summary_row = {'landmark': landmark_name, 'trend': trend}
            summary_row.update(flatten_dicts(metrics))
            list_summary.append(summary_row)

    summary_path = os.path.join(dir, 'ind_pos_wise_summary.csv')
    pd.DataFrame.from_records(list_summary).to_csv(summary_path)


def eval_from_t0(dataset_name):
    """Pick the evaluation function for a dataset.

    Args:
        dataset_name: Dataset class name.

    Returns:
        ``eval_airway`` when ``dataset_name`` is exactly ``'Airway'``,
        otherwise ``eval_general``.
    """
    # NOTE(review): this is an exact match, whereas
    # individualized_prediction_from_t0 uses ``'Airway' in dataset_name``;
    # confirm that other Airway-like class names should get eval_general.
    return eval_airway if dataset_name == 'Airway' else eval_general


def ind_pred_and_eval_from_t0(specs_filename: str, cv_idx: int = None):
    """Run individualized prediction and then the matching evaluation.

    Args:
        specs_filename: Path handed to ``load_json`` to obtain the specs.
        cv_idx: Second argument of ``load_json``; defaults to ``None``.
    """
    specs = load_json(specs_filename, cv_idx)
    output_dir = individualized_prediction_from_t0(specs)

    # The evaluation routine depends on the dataset class named in the specs.
    evaluate = eval_from_t0(specs["Class"])
    evaluate(output_dir)



def str2bool(value):
    """Parse a command-line boolean.

    Without this, ``--test False`` would be stored as the non-empty string
    ``'False'``, which is truthy, so the flag could never be switched off.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``true/false``, ``yes/no``, ``t/f``, ``y/n``, ``1/0``
            (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Build the command-line parser of this script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``experiment_directory``,
        ``checkpoint``, ``whether_train``, ``whether_test``, ``whether_vis``
        and ``which_fold``.
    """
    # This script only predicts and evaluates; it never trains.
    arg_parser = argparse.ArgumentParser(
        description="Run individualized prediction from the first scan and evaluate it")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        default='/playpen-raid/Author/LucidAtlas/configs/airways/cv5_es/airway_lucidatlas_full.json',
        help="The experiment directory. This directory should include "
             + "experiment specifications in 'specs.json', and logging will be "
             + "done in this directory as well.",
    )
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicated an epoch "
        + "or 'latest' for the latest weights (this is the default)",
    )

    # Boolean switches take an explicit value (e.g. ``--test false``) so that
    # existing invocations such as ``--test True`` keep working.
    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        type=str2bool,
        default=True,
        help="whether to train from scratch",
    )
    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        type=str2bool,
        default=True,
        help="whether to test",
    )
    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        type=str2bool,
        default=True,
        help="whether to vis",
    )

    # type=int keeps a fold given on the command line consistent with the
    # integer default instead of arriving as a string.
    arg_parser.add_argument(
        "--fold",
        "-f",
        dest="which_fold",
        type=int,
        default=2,
        help="cross-validation fold index passed to the specs loader",
    )
    return arg_parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    if args.whether_test:
        ind_pred_and_eval_from_t0(args.experiment_directory, cv_idx=args.which_fold)
    # NOTE(review): looks like leftover debug output; kept so the script's
    # stdout is unchanged.
    print('1')


