'''
Plot the inter-correlations between covariates.

'''

'''
Plot the global feature contribution, i.e.,
supposing the other features do not exist in the analysis, how a certain feature influences the outcome.
This is what we get from our marginalization.

Author

'''





import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os
import argparse
from pipeline.load import *
import torch.utils.data as data_utils

import model.networks.basics.workspace as ws
from utilities.utils import denormalize


def plot_correlations(spec: dict,
                        model_trained: torch.nn.Module,
                        ds_train: data_utils.Dataset,
                        ds_test: data_utils.Dataset=None,
                        pos=None):

    '''
    Plot the inter-correlation between every ordered pair of covariates.

    For each (source, target) pair taken from spec["CovariateNames"], the
    model's infer_cov_correlation is queried on a grid of inputs and one
    figure is written to
    <LoggingRoot>/<ExperimentName>/<ws.vis_cov_itrcor>/correlation_src_<source>_tgt_<target>.png.
    Each figure receives the train samples, the predicted mean with a
    +/- 2 standard deviation band, and the test samples.

    :param spec: experiment specification; the keys read here are
        "CovariateNames", "Device", "LoggingRoot", "ExperimentName" and "Class".
    :param model_trained: trained model exposing
        infer_cov_correlation(grids, src_name, tgt_name), whose first two
        return values are used as predictive mean and variance.
    :param ds_train: training dataset; provides valid_dict_features and is the
        dataset handed to denormalize for every array (train, grid and test).
    :param ds_test: test dataset providing valid_dict_features. Although the
        signature defaults to None, the test samples are always plotted, so a
        dataset is required.
    :param pos: unused; kept for interface compatibility.
    :raises ValueError: if ds_test is None.
    '''
    # Fail fast with a clear message: the test samples are dereferenced
    # unconditionally below, so None would otherwise surface as an opaque
    # AttributeError only after the model has already been queried.
    if ds_test is None:
        raise ValueError("plot_correlations requires ds_test: test samples are always plotted")

    cov_names = spec["CovariateNames"]
    device = spec["Device"]
    savedir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"], ws.vis_cov_itrcor)
    utils.cond_mkdir(savedir)

    arr_input_grids = utils.make_input_for_vis_correlation(ds_train,
                                                           covs_to_plot=cov_names,
                                                           num_of_samples=100,
                                                           device=device)

    def _denorm(values, cov_name):
        # Every array is denormalized against ds_train, including test data.
        return denormalize(ds_=ds_train, arr_=values, var_name=cov_name).squeeze()

    for src_idx, src_name in enumerate(cov_names):
        for tgt_name in cov_names:
            # query the model on the grid for this (source, target) pair
            f_mu, f_var, _ = model_trained.infer_cov_correlation(arr_input_grids,
                                                                 src_name,
                                                                 tgt_name)
            mean_map = f_mu.squeeze().cpu().detach().numpy()
            std_map = np.sqrt(f_var.squeeze().cpu().detach().numpy())

            # The +/- 2 std band is built before denormalization and each
            # curve is then denormalized as a target-covariate value.
            upper_map = mean_map + 2 * std_map
            lower_map = mean_map - 2 * std_map

            arr_x_train = _denorm(ds_train.valid_dict_features[src_name], src_name)
            arr_y_train = _denorm(ds_train.valid_dict_features[tgt_name], tgt_name)
            arr_x_grids = _denorm(arr_input_grids[..., src_idx].cpu().numpy(), src_name)

            mean_map = _denorm(mean_map, tgt_name)
            upper_map = _denorm(upper_map, tgt_name)
            lower_map = _denorm(lower_map, tgt_name)

            arr_x_test = _denorm(ds_test.valid_dict_features[src_name], src_name)
            arr_y_test = _denorm(ds_test.valid_dict_features[tgt_name], tgt_name)

            # axis labels: covariate name followed by its entry in DICT_TGT_UNIT
            units = DICT_TGT_UNIT[spec["Class"]]
            dict_info = {"x_axis_name": src_name + units[src_name],
                         "y_axis_name": tgt_name + units[tgt_name],
                         }

            vis_utils.plot_regression_all_samples(
                arr_x_train,
                arr_y_train,
                arr_x_grids,
                mean_map,
                upper_map,
                lower_map,
                arr_x_test,
                arr_y_test,
                dict_info=dict_info,
                savepath=f'{savedir}/correlation_src_{src_name}_tgt_{tgt_name}.png')

    return


def visualize_correlation(specs_filename: str, cv_idx: int=None):
    '''
    Load the specs, both data splits and the best checkpoint, then plot the
    covariate inter-correlations with plot_correlations.

    :param specs_filename: path passed to load_json to read the experiment specs.
    :param cv_idx: forwarded to load_json as cv_idx.
    '''
    specs = load_json(specs_filename, cv_idx=cv_idx)

    # load_dataset yields a pair; only the first element (the dataset) is
    # needed here, the second one is discarded.
    train_set, _ = load_dataset(specs=specs, which_split='train_val')
    test_set, _ = load_dataset(specs=specs, which_split='test')

    # restore the checkpoint named in the specs and switch to eval mode
    checkpoint_path = specs["SavedBestCheckpointPath"]
    trained_model = load_trained_model(specs=specs, filename_checkpoint=checkpoint_path)
    trained_model.eval()

    # make sure the experiment's logging directory exists before plotting
    experiment_dir = os.path.join(specs["LoggingRoot"], specs["ExperimentName"])
    utils.cond_mkdir(experiment_dir)

    plot_correlations(specs, trained_model, train_set, test_set)





def parse_bool_flag(value):
    '''
    Convert a command-line value to a bool, for use as an argparse `type`.

    Without a converter argparse keeps the raw string, so "False" would be
    truthy and the flag could never be switched off from the command line.

    :param value: a bool (returned unchanged) or a string such as
        "true"/"false", "yes"/"no", "on"/"off", "1"/"0" (case-insensitive).
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the value is not recognised.
    '''
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Visualize inter-correlations between covariates for a trained LucidAtlas model")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        default='/playpen-raid/Author/LucidAtlas/configs/airways/airway_namlss_v1_0123_full.json',
        help="Path to the experiment specification JSON file. It is passed "
             + "to visualize_correlation as the specs filename.",
    )
    # NOTE: --checkpoint, --train and --test are parsed but not used by this
    # script; they are kept so existing command lines keep working.
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicated an epoch "
             + "or 'latest' for the latest weights (this is the default)",
    )

    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        type=parse_bool_flag,
        default=True,
        help="whether to train from scratch",
    )

    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        type=parse_bool_flag,
        default=True,
        help="whether to test",
    )

    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        type=parse_bool_flag,
        default=True,
        help="whether to vis",
    )

    args = arg_parser.parse_args()

    if args.whether_vis:
        visualize_correlation(args.experiment_directory)




