import sys
sys.path.append("/playpen-raid/Author/LucidAtlas/")
import matplotlib.pyplot as plt
import pandas as pd

plt.rcParams['figure.dpi']= 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os
import argparse
from pipeline.load import *
import torch.utils.data as data_utils
from utilities.eval_quant import quantitative_evaluation
from utilities.utils import record_prediction, denormalize, denormlize_ds, denormalize_from_distribution
import uncertainty_toolbox as uct




def flatten_dicts(dict_eval):
    """Collapse a ``{group: {metric: value}}`` mapping into ``{"group_metric": value}``.

    Only values that are ``float`` instances are kept (``np.float64`` qualifies, since it
    subclasses ``float``). Each kept value is rounded to 4 decimals with ``np.round``.
    Every other entry (ints, arrays, strings, nested containers) is silently dropped.
    Keys are built as ``group + '_' + metric``.
    """
    flat = {}
    for group, metrics in dict_eval.items():
        # The key is only built for float values, so non-float entries never touch it.
        flat.update(
            (group + '_' + metric, np.round(value, 4))
            for metric, value in metrics.items()
            if isinstance(value, float)
        )
    return flat



def test_(network: torch.nn.Module,
          test_csa_dataset: torch.utils.data.Dataset,
          savedir: str,
          which_set: str):
    """Predict the population trend for every covariate of ``test_csa_dataset`` and save it.

    The network is switched to eval mode. It is queried once per name in
    ``test_csa_dataset.covariate_names`` through ``utils.batched_infer`` in ``"global"``
    mode (batch size 400, under ``torch.no_grad()``). The per-covariate predictions are
    stacked. The evaluation frame and the normalized ground truth are repeated once per
    covariate, so that row ``k`` of each stacked array is meant to describe the same
    sample. This assumes ``batched_infer`` returns one prediction per evaluation row;
    TODO confirm.

    Predicted variances are clamped to >= 1e-4 before taking the square root. The
    following are all recorded through ``record_prediction``:
    - normalized predictions and ground truth (``f_mu``, ``f_std``, ``normed_GT``);
    - de-normalized counterparts (``*_ori``);
    - a ``Feat`` column naming the covariate that produced each row.

    The result is written to ``<savedir>/INTR_1_pop_trend_pred_<which_set>.csv``.

    Args:
        network: trained model; ``network.device`` is used to place the inputs.
        test_csa_dataset: dataset exposing ``DATASETNANE``, ``covariate_names`` and
            ``tgt_var_name``.
        savedir: output directory, passed to ``utils.cond_mkdir`` first.
        which_set: split tag used in the output filename.

    Returns:
        The DataFrame produced by ``record_prediction`` (also saved as CSV).
    """
    utils.cond_mkdir(savedir)

    build_eval_input = utils.make_input_for_eval(test_csa_dataset.DATASETNANE)
    model_input, gt_normed, df_testset = build_eval_input(scalar_dataset=test_csa_dataset,
                                                          device=network.device)

    model_eval = network.eval()
    covariates = test_csa_dataset.covariate_names

    mu_parts, var_parts, feat_column = [], [], []
    with torch.no_grad():
        for cov_name in covariates:
            mu_part, var_part = utils.batched_infer("global", model_eval, model_input,
                                                    cov_name, batch_size=400)
            mu_parts.append(mu_part)
            var_parts.append(var_part)
            # One 'Feat' label per predicted row of this covariate.
            feat_column.extend([cov_name] * mu_part.shape[0])

    n_cov = len(covariates)
    mu_t = torch.cat(mu_parts, dim=0)
    # The variance floor keeps sqrt (and anything dividing by std later) well-defined.
    std_t = torch.sqrt(torch.clamp(torch.cat(var_parts, dim=0), min=1e-4))
    gt_t = torch.cat([gt_normed] * n_cov, dim=0)

    mu_ori, std_ori, _low_bd, _high_bd = denormalize_from_distribution(
        ds_=test_csa_dataset, mu=mu_t, sigma=std_t)
    gt_ori = denormalize(ds_=test_csa_dataset,
                         arr_=gt_t.cpu().numpy(),
                         var_name=test_csa_dataset.tgt_var_name)

    columns = {
        'Feat': feat_column,
        'f_mu': mu_t.squeeze().detach().cpu().numpy(),
        'f_std': std_t.squeeze().detach().cpu().numpy(),
        'normed_GT': gt_t.cpu().numpy(),
        'f_mu_ori': mu_ori,
        'f_std_ori': std_ori,
        'normed_GT_ori': gt_ori,
    }
    df_inputs = denormlize_ds(test_csa_dataset, pd.concat([df_testset] * n_cov, axis=0))
    df_pred_rst = record_prediction(df_inputs, columns)

    df_pred_rst.to_csv(os.path.join(savedir, f'INTR_1_pop_trend_pred_{which_set}.csv'))
    return df_pred_rst


def mean_filter(df_pred_rst: pd.DataFrame,
                savedir: str,
                which_set: str,
                *,
                exclude_pos: float = 0.98,
                pos_tol: float = 0.005):
    """Attach windowed per-(id, pos) averages of the prediction columns and save them.

    For each subject ``id``:
    1. Rows whose ``pos`` lies within ``pos_tol`` of ``exclude_pos`` are dropped.
    2. For every remaining distinct ``pos`` value, the rows with
       ``|pos' - pos| < pos_tol`` are averaged. Windows of neighbouring positions may
       overlap.

    The averages are stored as ``v_<column>`` for ``f_mu``, ``f_std``, ``f_mu_ori``,
    ``f_std_ori``, ``normed_GT`` and ``normed_GT_ori``. They are then inner-merged back
    onto ``df_pred_rst`` on ``['id', 'pos']``. Rows at excluded positions therefore
    disappear from the result.

    NOTE(review): the averaging does not look at ``Feat``. Rows of all covariates sharing
    an (id, pos) are pooled into the same ``v_*`` values. Confirm this is intended.

    Args:
        df_pred_rst: predictions with at least ``id``, ``pos`` and the six value columns.
        savedir: directory for the output CSV.
        which_set: split tag used in the output filename.
        exclude_pos: position whose neighbourhood is discarded (default keeps the
            previous hard-coded 0.98).
        pos_tol: half-width of both the exclusion and the averaging window
            (default keeps the previous hard-coded 0.005).

    Returns:
        The merged DataFrame, also written to ``<savedir>/vol_pop_trend_<which_set>.csv``.
    """
    value_cols = ['f_mu', 'f_std', 'f_mu_ori', 'f_std_ori', 'normed_GT', 'normed_GT_ori']
    v_cols = ['v_' + col for col in value_cols]

    records = []
    # Use the raw id values: casting them to str would never match a numeric 'id' column.
    for id_val in np.unique(df_pred_rst['id'].to_numpy()):
        df_cur_id = df_pred_rst[df_pred_rst['id'] == id_val]
        df_cur_id = df_cur_id[np.abs(df_cur_id['pos'] - exclude_pos) > pos_tol]
        positions = np.unique(np.array(df_cur_id['pos'])).astype('float').tolist()
        for pos in positions:
            window = df_cur_id[np.abs(df_cur_id['pos'] - pos) < pos_tol]
            record = {'id': id_val, 'pos': pos}
            for col, v_col in zip(value_cols, v_cols):
                record[v_col] = window[col].mean()
            records.append(record)

    if records:
        df_v = pd.DataFrame.from_records(records, columns=['id', 'pos'] + v_cols)
    else:
        # Keep the key dtypes identical to the input so the merge yields an empty frame
        # instead of failing on missing/incompatible key columns.
        df_v = df_pred_rst[['id', 'pos']].iloc[:0].assign(
            **{v_col: pd.Series(dtype='float64') for v_col in v_cols})

    df_pred_rst = df_pred_rst.merge(df_v, how='inner', on=['id', 'pos'])

    savepath_pred_rst = os.path.join(savedir, f'vol_pop_trend_{which_set}.csv')
    df_pred_rst.to_csv(savepath_pred_rst)

    return df_pred_rst





def selector(test_csa_dataset: torch.utils.data.Dataset,
             df_pred_rst: pd.DataFrame,
             savedir: str,
             which_set: str):
    """Pick, per covariate and subject, the least likely position under the prediction.

    Rows are first restricted with ``filter_sgs``. Then, for every name in
    ``test_csa_dataset.covariate_names`` (matched against the ``Feat`` column) and every
    subject ``id``, the per-row Gaussian log-likelihood is computed as

        log N(v_normed_GT | v_f_mu, v_f_std**2)
            = -0.5*log(2*pi) - log(sigma) - (x - mu)**2 / (2*sigma**2)

    with ``v_f_std`` floored at 1e-6. The row with the lowest log-likelihood is kept.
    Its ``OOD_rating`` is the 5th percentile of that subject's log-likelihoods, so
    lower values mean more atypical. (id, covariate) pairs with no rows are skipped.

    The collected rows are written to ``<savedir>/INTR_1_id_ood_pred_<which_set>.csv``.

    Args:
        test_csa_dataset: dataset providing ``covariate_names``.
        df_pred_rst: frame with ``id``, ``pos``, ``Feat``, ``v_normed_GT``, ``v_f_mu``,
            ``v_f_std`` columns.
        savedir: directory for the output CSV.
        which_set: split tag used in the output filename.

    Returns:
        None.
    """
    min_std = 1e-6  # floor on the std so both log(sigma) and the division stay finite
    log_norm_const = -0.5 * np.log(2 * np.pi)

    df_pred_rst = filter_sgs(df_pred_rst)
    # Raw id values (sorted): a str cast would never match a numeric 'id' column.
    list_id = np.unique(df_pred_rst['id'].to_numpy())
    list_suspicious_ood = []

    for cov_name in test_csa_dataset.covariate_names:
        is_cov = df_pred_rst['Feat'] == cov_name
        for id_val in list_id:
            # Parenthesised on purpose: `&` binds tighter than `==` in Python.
            df_cur_id = df_pred_rst[(df_pred_rst['id'] == id_val) & is_cov]
            if df_cur_id.empty:
                continue

            normed_gt = df_cur_id['v_normed_GT'].to_numpy(dtype=float)
            v_f_mu = df_cur_id['v_f_mu'].to_numpy(dtype=float)
            v_f_std = np.maximum(df_cur_id['v_f_std'].to_numpy(dtype=float), min_std)

            log_lik = (log_norm_const
                       - np.log(v_f_std)
                       - (normed_gt - v_f_mu) ** 2 / (2.0 * v_f_std ** 2))

            dict_cur_id = dict(df_cur_id.iloc[int(np.argmin(log_lik))])
            dict_cur_id['OOD_rating'] = float(np.quantile(log_lik, 0.05))
            list_suspicious_ood.append(dict_cur_id)

    df_ood = pd.DataFrame.from_records(list_suspicious_ood)
    savepath_pred_rst = os.path.join(savedir, f'INTR_1_id_ood_pred_{which_set}.csv')
    df_ood.to_csv(savepath_pred_rst)

    return

def filter_sgs(df_pred_rst: pd.DataFrame, tvc=None, carina: float = 0.7):
    """Keep rows whose ``pos`` lies strictly between ``tvc`` and ``carina``.

    Despite the name, the lower bound is the TVC landmark, not the subglottis.

    Args:
        df_pred_rst: frame with a numeric ``pos`` column.
        tvc: lower (exclusive) bound. When ``None``, ``LANDMARKS["Airway"]['TVC']`` is
            used. ``LANDMARKS`` is not defined in this file; it presumably comes from
            the ``pipeline.load`` star import (verify).
        carina: upper (exclusive) bound. The default 0.7 is the fixed value the previous
            implementation used in place of a landmark-table lookup.

    Returns:
        The filtered view of ``df_pred_rst`` (original index preserved).
    """
    if tvc is None:
        tvc = LANDMARKS["Airway"]['TVC']

    pos = df_pred_rst['pos']
    return df_pred_rst[(pos > tvc) & (pos < carina)]



def eval_general(dir, which_set):
    """Compute uncertainty-toolbox metrics for saved predictions and write a summary.

    Reads ``<dir>/INTR_1_pop_trend_pred_<which_set>.csv`` and passes its ``f_mu``,
    ``f_std`` and ``normed_GT`` columns to ``uct.metrics.get_all_metrics``. It keeps the
    float-valued metrics flattened by ``flatten_dicts`` and writes them as a single-row
    CSV to ``<dir>/all_summary_<which_set>.csv``.

    Returns:
        None.
    """
    df_rst = pd.read_csv(os.path.join(dir, f'INTR_1_pop_trend_pred_{which_set}.csv'))

    all_metrics = uct.metrics.get_all_metrics(df_rst['f_mu'].values,
                                              df_rst['f_std'].values,
                                              df_rst['normed_GT'].values)
    summary_row = dict(flatten_dicts(all_metrics))

    out_path = os.path.join(dir, f'all_summary_{which_set}.csv')
    pd.DataFrame.from_records([summary_row]).to_csv(out_path)


def pred_and_eval_ood_dec_with_feat(specs_filename: str, which_set: str, cv_idx: int = None):
    """Run the prediction / averaging / OOD-selection pipeline for one experiment spec.

    Loads the spec (optionally for cross-validation fold ``cv_idx``). Specs whose
    ``"Class"`` does not contain ``"Airway"`` are skipped. Otherwise the spec is passed
    through ``get_ood_shapetype`` and the ``which_set`` split and the checkpoint at
    ``spec["SavedBestCheckpointPath"]`` are loaded. The pipeline then runs
    ``test_`` -> ``mean_filter`` -> ``selector``, with all outputs under
    ``<LoggingRoot>/<ExperimentName>/<Class>``.

    Returns:
        None.
    """
    spec = load_json(specs_filename, cv_idx=cv_idx)
    dataset_name = spec["Class"]
    if "Airway" not in dataset_name:
        # Non-airway datasets are not handled by this pipeline.
        return

    spec = get_ood_shapetype(spec)

    ds_test, _test_loader = load_dataset(specs=spec, which_split=which_set)
    model = load_trained_model(specs=spec, filename_checkpoint=spec["SavedBestCheckpointPath"])
    model.eval()

    out_dir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"], dataset_name)
    utils.cond_mkdir(out_dir)

    df_raw = test_(model, ds_test, out_dir, which_set)
    df_windowed = mean_filter(df_raw, out_dir, which_set)
    selector(ds_test, df_windowed, out_dir, which_set)
    return





if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Train a LucidAtlas autodecoder")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        #default="/playpen-raid/Author/LucidAtlas/configs/airways/v2_ood/airway_lucidatlas_full_test_prs.json",
        default="/playpen-raid/Author/LucidAtlas/configs/airways/cv5_all/lucidatlas_full.json",
        help="The experiment directory. This directory should include "
             + "experiment specifications in 'specs.json', and logging will be "
             + "done in this directory as well.",
    )
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicated an epoch "
        + "or 'latest' for the latest weights (this is the default)",
    )

    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        default=True,
        help="whether to train from scratch",
    )

    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        default=True,
        help="whether to test",
    )

    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        default=True,
        help="whether to vis",
    )
    arg_parser.add_argument(
        "--fold",
        "-f",
        dest="which_fold",
        default=2,
        help="whether to vis",
    )

    args = arg_parser.parse_args()


    if args.whether_test:
        pred_and_eval_ood_dec_with_feat(args.experiment_directory, which_set='test', cv_idx=args.which_fold)
    print('1')


