
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi']= 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os

import argparse
from pipeline.load import *
import torch.utils.data as data_utils
from utilities.eval_quant import quantitative_evaluation
from model.networks.AirwayDataset import get_airway_data_for_id, get_airways_for_transport, get_airways_pairs_for_transport
from model.networks.ADNIHPDataset import get_ADNIHP_pairs_for_transport, get_adnihp_data_for_id
from model.networks.OASISBrainDataset import get_OASISBrain_pairs_for_transport, get_oasisbrain_data_for_id
from utilities.utils import record_prediction, denormalize
import pandas as pd
import uncertainty_toolbox as uct



def flatten_dicts(dict_eval):
    """Return a new dict with every value of ``dict_eval`` rounded to 4 decimals.

    Keys and their order are preserved; rounding is done with ``np.round``.
    """
    return {metric_name: np.round(metric_value, 4)
            for metric_name, metric_value in dict_eval.items()}



def _tensor_to_numpy(tensor):
    """Squeeze ``tensor``, detach it from autograd and return it as a CPU NumPy array."""
    return tensor.squeeze().detach().cpu().numpy()


def individualized_prediction_timeline_pairs(specs_filename):
    """Run pairwise individualized predictions on the test split and save them to CSV.

    For every (source scan, target scan) pair produced by the dataset-specific
    pair builder, both scans are loaded, and ``trained_model.individualized_prediction``
    is called with the source input, the source ground truth and the target
    input. Its three outputs are stored as the ``Ind_Pred``, ``Pop_Pred`` and
    ``T0_Pred`` columns, the target ground truth as ``normed_GT``, via
    ``record_prediction`` on the target scan's data frame. All pairs are
    concatenated and written to ``<LoggingRoot>/<ExperimentName>/ind_pred_pair.csv``.

    Args:
        specs_filename: Path to the experiment specification JSON.

    Returns:
        The output directory (``<LoggingRoot>/<ExperimentName>``).

    Raises:
        ValueError: If ``spec["Class"]`` is not one of the supported datasets,
            or if no scan pairs were produced for the test split.
    """
    spec = load_json(specs_filename)
    dataset_name = spec["Class"]

    # Supported datasets: (load a single scan by id, build test-set scan pairs).
    dataset_handlers = {
        'Airway': (get_airway_data_for_id, get_airways_pairs_for_transport),
        'ADNIHP': (get_adnihp_data_for_id, get_ADNIHP_pairs_for_transport),
        'OASISBrain': (get_oasisbrain_data_for_id, get_OASISBrain_pairs_for_transport),
    }
    if dataset_name not in dataset_handlers:
        # Without this check the loader/pair names would stay unbound and the
        # failure would surface later as an unrelated UnboundLocalError.
        raise ValueError(
            f"Unsupported dataset class {dataset_name!r}; "
            f"expected one of {sorted(dataset_handlers)}"
        )
    load_one_case, build_pairs = dataset_handlers[dataset_name]

    device = spec["Device"]
    ds_test, _ = load_dataset(specs_filename=specs_filename, which_split='test')
    trained_model = load_trained_model(specs_filename, spec["SavedCheckpointPath"])
    trained_model.eval()
    savedir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"])
    utils.cond_mkdir(savedir)

    list_patient_scans = build_pairs(spec, ds_test, split='test_multiple')

    # NOTE(review): inference is intentionally not wrapped in torch.no_grad();
    # individualized_prediction may rely on gradients internally — confirm
    # before adding it.
    list_ind_pred = []
    for current_patient in list_patient_scans:
        src_scan_idx = current_patient['src_scan']
        print(src_scan_idx)  # progress indicator
        model_input_t0, gt_t0, _ = load_one_case(src_scan_idx, ds_test)
        model_input_t0, gt_t0 = model_input_t0.to(device), gt_t0.to(device)

        tgt_scan_idx = current_patient['tgt_scan']
        model_input_t1, gt_t1, df_data_t1 = load_one_case(tgt_scan_idx, ds_test)
        # gt_t1 is only written out, never fed to the model, so it stays on
        # its original device.
        model_input_t1 = model_input_t1.to(device)

        f_ind_t1, f_mu_t1, f_from_t0 = trained_model.individualized_prediction(
            model_input_t0, gt_t0, model_input_t1)

        # Values are saved exactly as produced (no denormalization), which is
        # why the ground-truth column is named 'normed_GT'.
        dict_current_pred = {
            'Ind_Pred': _tensor_to_numpy(f_ind_t1),
            'Pop_Pred': _tensor_to_numpy(f_mu_t1),
            'T0_Pred': _tensor_to_numpy(f_from_t0),
            'normed_GT': _tensor_to_numpy(gt_t1),
        }
        list_ind_pred.append(record_prediction(df_data_t1, dict_pred=dict_current_pred))

    if not list_ind_pred:
        # pd.concat([]) would fail with an opaque "No objects to concatenate".
        raise ValueError(
            f"No test scan pairs found for dataset {dataset_name!r}; nothing to evaluate."
        )

    df_ind_analysis = pd.concat(list_ind_pred)
    df_ind_analysis.to_csv(os.path.join(savedir, 'ind_pred_pair.csv'))

    return savedir


def eval_timeline_pairs(dataset_name):
    """Pick the evaluation function to run for ``dataset_name``.

    'Airway' maps to ``eval_airway``; any other dataset name falls back to
    ``eval_general``. The returned callable takes the output directory.
    """
    return eval_airway if dataset_name == 'Airway' else eval_general




def eval_airway(dir, percentiles=(0, 20, 40, 50, 60, 80, 100), tol=0.1):
    """Evaluate airway pair predictions overall and at selected airway positions.

    First writes the whole-test-set summary via ``eval_general``. It then reads
    ``<dir>/ind_pred_pair.csv`` and, for each requested percentile of the
    *unique* values in the ``pos`` column, selects the rows whose ``pos`` lies
    within ``tol`` of that percentile. For those rows it scores ``T0_Pred``,
    ``Pop_Pred`` and ``Ind_Pred`` against ``normed_GT``. One row per
    (pos, trend) is written to ``<dir>/pair_ind_pos_wise_summary.csv``, with
    metrics rounded to 4 decimals.

    Args:
        dir: Directory holding ``ind_pred_pair.csv``; summaries are written
            there too. (Name shadows the builtin but is kept for callers.)
        percentiles: Percentiles in [0, 100] of the unique positions to
            evaluate at. Defaults to the original fixed set.
        tol: Half-width of the position window around each percentile.

    Returns:
        None.
    """
    eval_general(dir)
    df_rst = pd.read_csv(os.path.join(dir, 'ind_pred_pair.csv'))

    # np.percentile interpolates linearly by default, so a selected position
    # need not be an observed value; the +/- tol window gathers nearby rows.
    positions = np.percentile(np.unique(df_rst['pos'].values), list(percentiles))
    trends = (('t0', 'T0_Pred'), ('pop', 'Pop_Pred'), ('ind', 'Ind_Pred'))

    list_summary = []
    for position in positions:
        # Compute the window once instead of re-filtering for every column.
        window = df_rst[(df_rst['pos'] - position).abs() <= tol]
        if window.empty:
            # Accuracy metrics are undefined on zero samples (and the
            # underlying metric code may raise), so skip this position.
            print(f'eval_airway: no rows within {tol} of pos={position:.4f}; skipping')
            continue

        normed_gt = window['normed_GT'].values
        rounded_pos = np.round(position, 4)
        for trend, column in trends:
            metrics = uct.metrics.get_all_accuracy_metrics(window[column].values, normed_gt)
            record = {'pos': rounded_pos, 'trend': trend}
            record.update(flatten_dicts(metrics))
            list_summary.append(record)

    savepath = os.path.join(dir, 'pair_ind_pos_wise_summary.csv')
    pd.DataFrame.from_records(list_summary).to_csv(savepath)




def eval_general(dir):
    """Score every prediction column over all rows of ``<dir>/ind_pred_pair.csv``.

    The ``T0_Pred``, ``Pop_Pred`` and ``Ind_Pred`` columns are each compared
    against ``normed_GT`` with ``uct.metrics.get_all_accuracy_metrics``. One
    row per trend ('t0', 'pop', 'ind') is written to
    ``<dir>/pair_ind_all_summary.csv``, with metrics rounded to 4 decimals.
    """
    results = pd.read_csv(os.path.join(dir, 'ind_pred_pair.csv'))

    trend_columns = (('t0', 'T0_Pred'), ('pop', 'Pop_Pred'), ('ind', 'Ind_Pred'))
    predictions = {trend: results[column].values for trend, column in trend_columns}
    ground_truth = results['normed_GT'].values

    scored = {
        trend: uct.metrics.get_all_accuracy_metrics(predicted, ground_truth)
        for trend, predicted in predictions.items()
    }

    rows = []
    for trend, metrics in scored.items():
        row = {'trend': trend}
        row.update(flatten_dicts(metrics))
        rows.append(row)

    pd.DataFrame.from_records(rows).to_csv(os.path.join(dir, 'pair_ind_all_summary.csv'))



def ind_pred_and_eval_timeline_pairs(specs_filename: str) -> None:
    """Predict on the test scan pairs of an experiment spec, then evaluate them.

    Predictions are produced by ``individualized_prediction_timeline_pairs``.
    The evaluator that ``eval_timeline_pairs`` selects for the spec's ``Class``
    is then applied to the output directory it returns.
    """
    output_dir = individualized_prediction_timeline_pairs(specs_filename)
    evaluate = eval_timeline_pairs(load_json(specs_filename)["Class"])
    evaluate(output_dir)



def parse_bool_flag(value):
    """Convert a command-line string such as 'true', 'False' or '0' to a bool.

    Booleans are returned unchanged, so this is safe to apply to non-string
    defaults. Unrecognized strings raise ``argparse.ArgumentTypeError`` so
    argparse reports a usage error instead of silently treating them as True.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Run individualized LucidAtlas predictions on test scan pairs and evaluate them"
    )
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        default="/playpen-raid/Author/LucidAtlas/configs/OASISBrain/brain_lucidatlas_v14_0122.json",
        help="Path to the experiment specification JSON file. Outputs are "
             + "written under the spec's LoggingRoot/ExperimentName directory.",
    )
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use: an epoch number or 'latest'. "
             + "Currently unused; the model is loaded from the spec's SavedCheckpointPath.",
    )

    # Without type=, any supplied value (even "False") is a truthy string.
    # These flags accept e.g. '--test false', or a bare '--test' meaning True.
    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=True,
        help="whether to train from scratch (currently unused by this script)",
    )
    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=True,
        help="whether to run prediction and evaluation on the test pairs",
    )
    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=True,
        help="whether to vis (currently unused by this script)",
    )

    args = arg_parser.parse_args()

    if args.whether_test:
        ind_pred_and_eval_timeline_pairs(args.experiment_directory)


