'''
Plot global feature contributions, i.e.,
assuming the other features are absent from the analysis, how a single feature influences the outcome.
This is what our marginalization produces.

Author

'''





import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os
import argparse
from pipeline.load import *
import torch.utils.data as data_utils

import model.networks.basics.workspace as ws
from utilities.utils import denormalize, denormlize_ds, normalize

def visualize_shape_per_global_feat_imp_general(spec: dict,
                        model_trained: torch.nn.Module,
                        train_dataset: data_utils.Dataset,
                        test_dataset: data_utils.Dataset=None,
                        cov=0,
                        cov_name: str = 'AGE'):
    '''
    Plot the model's global (marginal) prediction with one covariate held at
    the value ``cov``, overlaid on the train and test populations.

    The model is queried on a grid built from ``train_dataset`` through
    ``model_trained.infer_global_importance``. The predicted mean and the
    standard deviation (square root of the predicted variance) are mapped back
    to target units. The result is drawn with
    ``vis_utils.plot_airway_shape_with_population`` and saved as a PNG under
    ``<LoggingRoot>/<ExperimentName>/<ws.vis_shape_wise>``.

    Args:
        spec: experiment spec. Reads "Device", "LoggingRoot",
            "ExperimentName" and "Class".
        model_trained: trained model exposing ``infer_global_importance``.
        train_dataset: dataset used to build the query grid and the training
            population overlay, and to denormalize the model outputs.
        test_dataset: dataset used to normalize ``cov`` and to build the test
            population overlay. It is required; the ``None`` default is kept
            only for signature compatibility.
        cov: covariate value in original (unnormalized) units. It is also
            embedded in the output file name.
        cov_name: covariate to condition on. Defaults to 'AGE', the only
            covariate this function previously supported.

    Raises:
        ValueError: if ``test_dataset`` is None.
    '''
    if test_dataset is None:
        # The test population is normalized and plotted unconditionally below,
        # so fail early with a clear message instead of deep inside a helper.
        raise ValueError("test_dataset is required to plot the test population")

    device = spec["Device"]
    dataset_name = spec["Class"]
    savedir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"], ws.vis_shape_wise)
    utils.cond_mkdir(savedir)
    # The file name uses the raw covariate value, before normalization.
    savepath = f'{savedir}/{dataset_name}_global_{cov_name.lower()}_{cov}.png'

    # NOTE(review): the query value is normalized with test-set statistics,
    # while the model outputs below are denormalized with train-set
    # statistics. Confirm that both splits share normalization constants.
    cov_normalized = normalize(ds_=test_dataset, arr_=cov, var_name=cov_name)

    # Query grid plus training population points, from the train split.
    arr_input_grids, list_train_data = \
        utils.make_grids_and_dps_for_2d_vis_airway_shape(
            train_dataset,
            cov=cov_normalized,
            covs_to_plot=[cov_name],
            num_of_samples=100,
            device=device)

    # Test population points only; the test grid itself is not needed.
    _, list_test_data = \
        utils.make_grids_and_dps_for_2d_vis_airway_shape(
            test_dataset,
            cov=cov_normalized,
            covs_to_plot=[cov_name],
            num_of_samples=100,
            device=device)

    # Query the model in inference mode without building an autograd graph.
    model_trained = model_trained.eval()
    with torch.no_grad():
        f_mu, f_var = model_trained.infer_global_importance(arr_input_grids, cov_name)

    mh_map = f_mu.squeeze().detach().cpu().numpy()
    # Standard deviation derived from the predicted variance.
    sh_map = np.sqrt(f_var.squeeze().detach().cpu().numpy())

    # Map the predictions back to the target's original units.
    tgt_name = train_dataset.tgt_var_name
    mh_map = denormalize(ds_=train_dataset,
                         arr_=mh_map,
                         var_name=tgt_name)
    sh_map = denormalize(ds_=train_dataset,
                         arr_=sh_map,
                         var_name=tgt_name, WHETHER_STD=True)

    # The first grid column is denormalized as 'pos'.
    arr_x_grids = denormalize(ds_=train_dataset,
                              arr_=arr_input_grids[..., [0]].cpu().numpy(),
                              var_name='pos')

    # DICT_TGT_UNIT is not defined in this file; presumably it comes from the
    # `pipeline.load` star import.
    # NOTE(review): the x-axis label uses the covariate name, while the x
    # values are the 'pos' grid. Confirm which label the plot should show.
    dict_info = {"x_axis_name": cov_name + DICT_TGT_UNIT[dataset_name][cov_name],
                 "y_axis_name": tgt_name + DICT_TGT_UNIT[dataset_name][tgt_name],
                 }

    vis_utils.plot_airway_shape_with_population(
        list_train_data,
        list_test_data,
        arr_x_grids,
        mh_map,
        sh_map,
        dict_info,
        savepath=savepath)

    return



def visualize_shape_per_global_feat_imp_airway(spec: dict,
             model_trained: torch.nn.Module,
             train_dataset: data_utils.Dataset,
             test_dataset: data_utils.Dataset,
             list_age=(24, 72, 120, 168, 216)):
    '''
    Make global-contribution plots for the airway dataset, one plot per age.

    Each value in ``list_age`` is passed unchanged as ``cov`` to
    ``visualize_shape_per_global_feat_imp_general``, which writes one PNG.

    Args:
        spec: experiment spec, forwarded unchanged.
        model_trained: trained model, forwarded unchanged.
        train_dataset: training split, forwarded unchanged.
        test_dataset: test split, forwarded unchanged.
        list_age: ages to plot, in original (unnormalized) units. The default
            reproduces the previously hard-coded list. It is a tuple so the
            default cannot be mutated between calls.

    Returns:
        0, kept for backward compatibility.
    '''
    for age in list_age:
        visualize_shape_per_global_feat_imp_general(spec=spec,
                                 model_trained=model_trained,
                                 train_dataset=train_dataset,
                                 cov=age,
                                 test_dataset=test_dataset)
    return 0


def visualize_shape_per_global_feat_imp(dataset_name):
    '''
    Return the global-contribution plotting function for ``dataset_name``.

    Only airway datasets are supported, meaning any name containing "Airway".

    Raises:
        ValueError: if no plotting function exists for ``dataset_name``.
            Raising here gives a clear error, instead of returning None and
            failing later with "'NoneType' object is not callable".
    '''
    if "Airway" in dataset_name:
        return visualize_shape_per_global_feat_imp_airway
    raise ValueError(
        f"No global feature-importance visualization for dataset {dataset_name!r}")



def visualize_shape(specs_filename: str,):
    '''
    Load everything named by a spec file and draw the global-contribution plots.

    Steps: read the JSON spec, load the train and test datasets, load the
    model from spec["SavedBestCheckpointPath"] and switch it to eval mode,
    ensure <LoggingRoot>/<ExperimentName> exists, then dispatch on
    spec["Class"] to the dataset-specific plotting function.
    '''
    spec = load_json(specs_filename)

    # load_dataset also returns a dataloader, which is not needed here.
    ds_train, _ = load_dataset(specs_filename=specs_filename, which_split='train')
    ds_test, _ = load_dataset(specs_filename=specs_filename, which_split='test')

    checkpoint_path = spec["SavedBestCheckpointPath"]
    trained_model = load_trained_model(specs_filename, checkpoint_path)
    trained_model.eval()

    experiment_dir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"])
    utils.cond_mkdir(experiment_dir)

    plot_fn = visualize_shape_per_global_feat_imp(spec["Class"])
    plot_fn(spec, trained_model, ds_train, ds_test)
    return





if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Train a LucidAtlas autodecoder")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        #default='/playpen-raid/Author/LucidAtlas/configs/airways/airway_namlss_v1_0123_full.json', #default="/playpen-raid/Author/LucidAtlas/configs/OASISBrain/brain_lucidatlas_v14_0122.json",#default="/playpen-raid/Author/LucidAtlas/configs/adni/adnihp_lucidatlas_v14_0120.json", #'/playpen-raid/Author/LucidAtlas/configs/airways/airway_lucidatlas_v14_0121.json', #  #
        #default='/playpen-raid/Author/LucidAtlas/configs/airways/v0/airway_lucidatlas_v14_0123_part.json',
        #default="/playpen-raid/Author/LucidAtlas/configs/OASISBrain/v1/brain_lucidatlas_part.json",
        #default="/playpen-raid/Author/LucidAtlas/configs/OASISBrain/v1/brain_lucidatlas_part.json",
        default='/playpen-raid/Author/LucidAtlas/configs/airways/v2/airway_lucidatlas_part.json',

        help="The experiment directory. This directory should include "
             + "experiment specifications in 'specs.json', and logging will be "
             + "done in this directory as well.",
    )
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicated an epoch "
             + "or 'latest' for the latest weights (this is the default)",
    )

    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        default=True,
        help="whether to train from scratch",
    )

    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        default=True,
        help="whether to test",
    )

    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        default=True,
        help="whether to vis",
    )

    args = arg_parser.parse_args()

    if args.whether_vis:
        visualize_shape(args.experiment_directory)
    print('1')




