'''
Plot global feature contributions, i.e.,
how a given feature influences the outcome when the other features are left out of the analysis.
This is what we obtain from our marginalization.

Author

'''

import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os

import argparse
from pipeline.load import *
import torch.utils.data as data_utils
from copy import deepcopy
import model.networks.basics.workspace as ws
from utilities.utils import denormalize, denormlize_ds, normalize, denormalize_from_distribution, denormlize_arr
from functools import partial
from  pipeline.load import DICT_TGT_UNIT
from pipeline.visualize.vis_3d_in_2d_interpret import interpret_3din2d_feat_per_pos, interpret_3din2d_per_feat_spatial, interpret_3din2d_per_feat_toydata
from pipeline.load import make_gt_model
def interpret_feat_per_pos(
    type_of_interpret: str,
    type_of_vis: str,
    spec: dict,
    model_trained: torch.nn.Module,
    train_dataset: data_utils.Dataset,
    test_dataset: data_utils.Dataset=None,
    dict_pos=None):

    '''
    Draw one 2D contribution plot per covariate in spec["CovariateNames"]:
    x axis is the (denormalized) feature value, y axis is the (denormalized)
    model prediction together with its lower/upper bounds.

    Parameters
    ----------
    type_of_interpret : forwarded to utils.batched_infer as `name_predictor`
        and used in the output directory name.
    type_of_vis : only used in the output directory name.
    spec : experiment specification; the keys "CovariateNames", "Device",
        "InGeoFeatures", "LoggingRoot", "ExperimentName" and "Class" are read.
    model_trained : model that is switched to eval mode and queried.
    train_dataset : provides the training samples and is the dataset handed to
        every normalize/denormalize call (also for the test samples).
    test_dataset : provides the test samples drawn on the plot.
        NOTE(review): the default None is forwarded unchanged to the grid
        builder - confirm that helper accepts it.
    dict_pos : optional mapping {landmark name: landmark value}; only its first
        entry is used. When None, no position is passed to the grid builder.

    The figure path (a .png under
    <LoggingRoot>/<ExperimentName>/<type_of_interpret>_<type_of_vis>) is handed
    to vis_utils.plot_regression_all_samples. Returns None.
    '''
    from pipeline.load import DICT_DISPLAY_NAMES

    covariate_names = spec["CovariateNames"]
    device = spec["Device"]
    n_geo_feats = spec["InGeoFeatures"]

    out_dir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"], f'{type_of_interpret}_{type_of_vis}')
    utils.cond_mkdir(out_dir)
    dataset_name = spec["Class"]

    if dict_pos is not None:
        # only the first (name, value) pair of the mapping is considered
        name, ldm_value = list(dict_pos.items())[0]
        normalized_ldm = normalize(ds_=train_dataset,
                                   arr_=np.array([[ldm_value]]),
                                   var_name=train_dataset.geo_var_name)
        pos = np.array(normalized_ldm[0])
        print(pos)
        path_prefix = f'{out_dir}/{dataset_name}_{name}_global_'
    else:
        pos, ldm_value = None, 'none'
        path_prefix = f'{out_dir}/{dataset_name}_global_'

    def _sample(dataset, cov_name):
        # query grid plus the (x, y) samples of `dataset` for a single covariate
        build = utils.make_grids_and_dps_for_2d_vis(dataset_name=dataset_name)
        return build(dataset,
                     covs_to_plot=[cov_name],
                     pos=pos,
                     num_of_samples=100,
                     device=device)

    for cov_offset, cov_name in enumerate(covariate_names):
        # covariate columns come after the geometric input columns
        col = n_geo_feats + cov_offset

        arr_input_grids, x_train, y_train = _sample(train_dataset, cov_name)
        _, x_test, y_test = _sample(test_dataset, cov_name)

        x_train, y_train = x_train.cpu(), y_train.cpu()
        x_test, y_test = x_test.cpu(), y_test.cpu()

        model_trained = model_trained.eval()
        with torch.no_grad():
            pred_mean, pred_var = utils.batched_infer(
                name_predictor=type_of_interpret,
                model_trained=model_trained,
                arr_input_grids=arr_input_grids,
                cov_name=cov_name,
                batch_size=10)
            # negative variances are clipped so the square root below is defined
            pred_var = torch.clamp(pred_var, min=0)

        mh_map, _sh_map, low_bd_map, high_bd_map = denormalize_from_distribution(
            ds_=train_dataset, mu=pred_mean, sigma=torch.sqrt(pred_var))

        savepath = f'{path_prefix}{cov_name}.png'

        # both converters use train_dataset, as in every denormalize call here
        target_name = train_dataset.tgt_var_name
        to_cov_units = partial(denormalize, ds_=train_dataset, var_name=cov_name)
        to_tgt_units = partial(denormalize, ds_=train_dataset, var_name=target_name)

        arr_x_train = to_cov_units(arr_=x_train[..., col]).squeeze()
        arr_y_train = to_tgt_units(arr_=y_train).squeeze()
        arr_x_grids = to_cov_units(arr_=arr_input_grids[..., col].cpu().numpy()).squeeze()
        arr_x_test = to_cov_units(arr_=x_test[..., col]).squeeze()
        arr_y_test = to_tgt_units(arr_=y_test).squeeze()

        # axis labels are "<display name><unit>" for the covariate and the target
        x_label = DICT_DISPLAY_NAMES[dataset_name][cov_name] + DICT_TGT_UNIT[dataset_name][cov_name]
        y_label = DICT_DISPLAY_NAMES[dataset_name][target_name] + DICT_TGT_UNIT[dataset_name][target_name]
        dict_info = {"x_axis_name": x_label,
                     "y_axis_name": y_label,
                     "pos": ldm_value}

        vis_utils.plot_regression_all_samples(
            arr_x_train,
            arr_y_train,
            arr_x_grids,
            mh_map,
            high_bd_map,
            low_bd_map,
            arr_x_test,
            arr_y_test,
            dict_info,
            savepath=savepath)

    return

def interpret_feat_per_pos_toydata(
    type_of_interpret: str,
    type_of_vis: str,
    spec: dict,
    model_trained: torch.nn.Module,
    train_dataset: data_utils.Dataset,
    test_dataset: data_utils.Dataset=None,
    dict_pos=None):

    '''
    Draw one 2D contribution plot per covariate in spec["CovariateNames"] and
    overlay the curve of the ground-truth model built by make_gt_model.
    x axis is the (denormalized) feature value, y axis is the prediction of the
    trained model (with its bounds) and of the ground-truth model (mean +/- 2 std).

    Parameters
    ----------
    type_of_interpret : forwarded to utils.batched_infer as `name_predictor`
        for the trained model and used in the output directory name.
        NOTE(review): the ground-truth model is always queried with
        name_predictor='global', whatever this argument is - confirm intended.
    type_of_vis : only used in the output directory name.
    spec : experiment specification; the keys "CovariateNames", "Device",
        "InGeoFeatures", "LoggingRoot", "ExperimentName" and "Class" are read.
    model_trained : model that is switched to eval mode and queried; it is also
        passed to make_gt_model.
    train_dataset : provides the training samples and is the dataset handed to
        every normalize/denormalize call (also for the test samples).
    test_dataset : provides the test samples drawn on the plot.
        NOTE(review): the default None is forwarded unchanged to the grid
        builder - confirm that helper accepts it.
    dict_pos : optional mapping {landmark name: landmark value}; only its first
        entry is used. When None, no position is passed to the grid builder.

    The figure path (a .png under
    <LoggingRoot>/<ExperimentName>/<type_of_interpret>_<type_of_vis>) is handed
    to vis_utils.plot_regression_all_samples_with_gt. Returns None.
    '''
    # Imported explicitly (as the non-toy variant does) so this function does
    # not depend on the wildcard import of pipeline.load exporting the name.
    from pipeline.load import DICT_DISPLAY_NAMES

    LIST_USED_COVARIATES = spec["CovariateNames"]
    device = spec["Device"]
    in_geo_features = spec["InGeoFeatures"]

    savedir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"], f'{type_of_interpret}_{type_of_vis}')
    utils.cond_mkdir(savedir)
    dataset_name = spec["Class"]

    if dict_pos is None:
        pos = None
        name = 'all'
        ldm_value = 'none'
    else:
        # only the first (name, value) pair of the mapping is considered
        name = list(dict_pos.keys())[0]
        ldm_value = list(dict_pos.values())[0]
        pos = np.array(normalize(ds_=train_dataset, arr_=np.array([[ldm_value]]), var_name=train_dataset.geo_var_name)[0])
        print(pos)

    for ith_cov, cov_name in enumerate(LIST_USED_COVARIATES):
        # covariate columns come after the geometric input columns
        cov_col = in_geo_features + ith_cov

        # make query samples
        arr_input_grids, x_train, y_train = \
            utils.make_grids_and_dps_for_2d_vis(dataset_name=dataset_name)(
                train_dataset,
                covs_to_plot=[cov_name],
                pos=pos,
                num_of_samples=100,
                device=device)

        _, x_test, y_test = \
            utils.make_grids_and_dps_for_2d_vis(dataset_name=dataset_name)(
                test_dataset,
                covs_to_plot=[cov_name],
                pos=pos,
                num_of_samples=100,
                device=device)

        x_train, y_train = x_train.cpu(), y_train.cpu()
        x_test, y_test = x_test.cpu(), y_test.cpu()
        print(x_train.shape)

        model_trained = model_trained.eval()
        with torch.no_grad():
            # query the trained model
            f_mu, f_var = utils.batched_infer(name_predictor=type_of_interpret,
                                              model_trained=model_trained,
                                              arr_input_grids=arr_input_grids,
                                              cov_name=cov_name,
                                              batch_size=10)
            # negative variances are clipped so the square root below is defined
            f_var = torch.clamp(f_var, min=0)
        mh_map, sh_map = f_mu, torch.sqrt(f_var)
        mh_map, sh_map, low_bd_map, high_bd_map = denormalize_from_distribution(ds_=train_dataset, mu=mh_map, sigma=sh_map)

        if pos is not None:
            savepath = f'{savedir}/{dataset_name}_{name}_global_{cov_name}.png'
        else:
            savepath = f'{savedir}/{dataset_name}_global_{cov_name}.png'

        # every denormalization below uses train_dataset, including the test samples
        arr_x_train = denormalize(ds_=train_dataset,
                                  arr_=x_train[..., cov_col],
                                  var_name=cov_name).squeeze()
        arr_y_train = denormalize(ds_=train_dataset,
                                  arr_=y_train,
                                  var_name=train_dataset.tgt_var_name).squeeze()
        arr_x_grids = denormalize(ds_=train_dataset,
                                  arr_=arr_input_grids[..., cov_col].cpu().numpy(),
                                  var_name=cov_name).squeeze()

        arr_x_test = denormalize(ds_=train_dataset,
                                 arr_=x_test[..., cov_col],
                                 var_name=cov_name).squeeze()
        arr_y_test = denormalize(ds_=train_dataset,
                                 arr_=y_test,
                                 var_name=train_dataset.tgt_var_name).squeeze()

        # ground-truth curve: the gt model is fed the denormalized query grid
        gt_model = make_gt_model(dataset_=train_dataset, model_trained=model_trained)
        with torch.no_grad():
            gt_input_grids = torch.from_numpy(
                denormlize_arr(train_dataset, arr_input_grids)).to(arr_input_grids.device)
            f_mu_gt, f_var_gt = utils.batched_infer(name_predictor='global',
                                                    model_trained=gt_model,
                                                    arr_input_grids=gt_input_grids,
                                                    cov_name=cov_name,
                                                    batch_size=10)
            f_var_gt = torch.clamp(f_var_gt, min=0)

        mh_map_gt = f_mu_gt.detach().cpu().numpy().squeeze()
        sh_map_gt = torch.sqrt(f_var_gt).detach().cpu().numpy().squeeze()
        # ground-truth band is mean +/- 2 standard deviations
        high_bd_gt_map = mh_map_gt + 2 * sh_map_gt
        low_bd_gt_map = mh_map_gt - 2 * sh_map_gt

        # axis labels are "<display name><unit>" for the covariate and the target
        x_axis_name, y_axis_name = cov_name, train_dataset.tgt_var_name
        dict_info = {"x_axis_name": DICT_DISPLAY_NAMES[dataset_name][x_axis_name] + DICT_TGT_UNIT[dataset_name][x_axis_name],
                     "y_axis_name": DICT_DISPLAY_NAMES[dataset_name][y_axis_name] + DICT_TGT_UNIT[dataset_name][y_axis_name],
                     "pos": ldm_value
                     }

        vis_utils.plot_regression_all_samples_with_gt(
            arr_x_train,
            arr_y_train,
            arr_x_grids,
            mh_map,
            high_bd_map,
            low_bd_map,
            mh_map_gt,
            high_bd_gt_map,
            low_bd_gt_map,
            arr_x_test,
            arr_y_test,
            dict_info,
            savepath=savepath)

    return


def interpret_per_feat_spatial(
    type_of_interpret: str,
    type_of_vis: str,
    spec: dict,
    model_trained: torch.nn.Module,
    train_dataset: data_utils.Dataset,
    test_dataset: data_utils.Dataset,
    dataset_name: str):

    '''
    Run interpret_feat_per_pos once for every landmark registered in
    LANDMARKS[dataset_name], passing each landmark as a single-entry
    dict_pos ({landmark name: landmark value}).

    All other arguments are forwarded unchanged. Always returns 0.
    '''
    landmarks = LANDMARKS[dataset_name]

    for ldm_name, ldm_value in landmarks.items():
        interpret_feat_per_pos(
            type_of_interpret=type_of_interpret,
            type_of_vis=type_of_vis,
            spec=spec,
            model_trained=model_trained,
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            dict_pos={ldm_name: ldm_value})

    return 0

def interpret_per_feat_spatial_toydata(
    type_of_interpret: str,
    type_of_vis: str,
    spec: dict,
    model_trained: torch.nn.Module,
    train_dataset: data_utils.Dataset,
    test_dataset: data_utils.Dataset,
    dataset_name: str):

    '''
    Run interpret_feat_per_pos_toydata once for every landmark registered in
    LANDMARKS[dataset_name], passing each landmark as a single-entry
    dict_pos ({landmark name: landmark value}).

    The landmark name and value are both read from LANDMARKS[dataset_name],
    as interpret_per_feat_spatial does, instead of taking the names from that
    table and the values from a separate LANDMARKS_TOY table.
    NOTE(review): assumes LANDMARKS[dataset_name] holds the same values as
    LANDMARKS_TOY for the toy dataset - confirm in pipeline.load.

    All other arguments are forwarded unchanged. Always returns 0.
    '''
    CURRENT_LDMS = LANDMARKS[dataset_name]

    for ldm_name, ldm_value in CURRENT_LDMS.items():
        interpret_feat_per_pos_toydata(
            type_of_interpret=type_of_interpret,
            type_of_vis=type_of_vis,
            spec=spec,
            model_trained=model_trained,
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            dict_pos={ldm_name: ldm_value})
    return 0


def visualize_interpret_per_feat(dataset_name: str, type_of_vis: str):
    '''
    Select the interpretation routine for a dataset and a visualization type.

    Parameters
    ----------
    dataset_name : dataset identifier. Names containing "Airway" or "AFQ" get
        the per-landmark routines, the exact name "ToyData" gets the toy-data
        routines, anything else gets the single-position routines.
    type_of_vis : '1d' or '2d'.

    Returns
    -------
    A functools.partial of the selected routine with `type_of_vis` (and, for
    the routines that take it, `dataset_name`) already bound.

    Raises
    ------
    ValueError : if `type_of_vis` is neither '1d' nor '2d'. Without this check
        the function would return None and the caller would fail later with
        "'NoneType' object is not callable".
    '''
    if type_of_vis not in ('1d', '2d'):
        raise ValueError(
            f"unsupported type_of_vis {type_of_vis!r}; expected '1d' or '2d'")

    # which dataset to use
    is_landmark_dataset = "Airway" in dataset_name or "AFQ" in dataset_name

    if type_of_vis == '1d':
        if is_landmark_dataset:
            return partial(interpret_per_feat_spatial, type_of_vis=type_of_vis, dataset_name=dataset_name)
        if dataset_name == "ToyData":
            return partial(interpret_per_feat_spatial_toydata, type_of_vis=type_of_vis, dataset_name=dataset_name)
        return partial(interpret_feat_per_pos, type_of_vis=type_of_vis)

    # type_of_vis == '2d'
    if is_landmark_dataset:
        return partial(interpret_3din2d_per_feat_spatial, type_of_vis=type_of_vis, dataset_name=dataset_name)
    if dataset_name == "ToyData":
        return partial(interpret_3din2d_per_feat_toydata, type_of_vis=type_of_vis)
    return partial(interpret_3din2d_feat_per_pos, type_of_vis=type_of_vis)


def visualize_interpret_feat(
        specs_filename: str,
        cv_idx: int=None,
        type_of_interpret: str='global',
        type_of_vis: str='1d',
        sample_size: int=None):
    '''
    Load an experiment and produce its feature-interpretation plots.

    The specs are read with load_json(specs_filename, cv_idx, sample_size); the
    'train_val' and 'test' splits are loaded with load_dataset; the checkpoint
    named by specs["SavedBestCheckpointPath"] is loaded and put in eval mode;
    the directory <LoggingRoot>/<ExperimentName> is created if needed; finally
    the routine chosen by visualize_interpret_per_feat for specs["Class"] and
    `type_of_vis` is called with `type_of_interpret`. Returns None.
    '''
    specs = load_json(specs_filename, cv_idx, sample_size)

    # the second value returned by load_dataset is not needed here
    ds_train, _ = load_dataset(specs=specs, which_split='train_val')
    ds_test, _ = load_dataset(specs=specs, which_split='test')

    # saved checkpoint name comes from the specs
    trained_model = load_trained_model(
        specs=specs,
        filename_checkpoint=specs["SavedBestCheckpointPath"])
    trained_model.eval()

    output_root = os.path.join(specs["LoggingRoot"], specs["ExperimentName"])
    utils.cond_mkdir(output_root)

    interpret_fn = visualize_interpret_per_feat(specs["Class"], type_of_vis)
    interpret_fn(
        type_of_interpret=type_of_interpret,
        spec=specs,
        model_trained=trained_model,
        train_dataset=ds_train,
        test_dataset=ds_test)
    return





if __name__ == "__main__":

    def _str2bool(value):
        '''Parse a command-line boolean; argparse would otherwise keep any
        supplied string (even "False") as a truthy value.'''
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    arg_parser = argparse.ArgumentParser(description="Visualize LucidAtlas feature interpretations")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        default='/playpen-raid/Author/LucidAtlas/configs/airways/cv5_sc/lucidatlas_full.json',
        #default='/playpen-raid/Author/LucidAtlas/configs/OASISBrain/cv5_sc/lucidatlas_full.json',
        #default='/playpen-raid/Author/LucidAtlas/configs/ToyData/cv5_sc_toy/lucidatlas_full.json',
        help="Path to the experiment specification JSON file; it is passed "
             + "to visualize_interpret_feat as specs_filename.",
    )
    # --checkpoint, --train and --test are parsed but not read by this script.
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicating an epoch "
             + "or 'latest' for the latest weights (this is the default)",
    )

    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        type=_str2bool,
        default=True,
        help="whether to train from scratch",
    )

    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        type=_str2bool,
        default=True,
        help="whether to test",
    )

    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        type=_str2bool,
        default=True,
        help="whether to vis",
    )
    arg_parser.add_argument(
        "--fold",
        "-f",
        dest="which_fold",
        type=int,  # without this a fold given on the command line stays a string
        default=0,
        help="index of the fold to visualize (passed on as cv_idx)",
    )

    args = arg_parser.parse_args()

    if args.whether_vis:
        # same order as before: all '1d' plots first, then all '2d' plots
        for vis_kind in ('1d', '2d'):
            for interpret_kind in ('global', 'indp', 'local'):
                visualize_interpret_feat(
                    args.experiment_directory,
                    cv_idx=args.which_fold,
                    type_of_interpret=interpret_kind,
                    type_of_vis=vis_kind)

    print('1')




