'''
Plot local feature contributions, i.e., how a single feature
influences the outcome while all other features are held fixed.
Author

'''

import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os
import argparse
from pipeline.load import *
import torch.utils.data as data_utils
import model.networks.basics.workspace as ws


def visualize_per_local_feat_ctb_general(spec: dict,
                        model_trained: torch.nn.Module,
                        train_dataset: data_utils.Dataset,
                        test_dataset: data_utils.Dataset=None,
                        pos=None):
    '''
    Plot each covariate's local contribution to the model output.

    For every covariate in ``spec["CovariateNames"]`` one 2D plot is written:
    the x axis is the covariate value and the y axis is the contribution
    returned by ``model_trained.infer_with_subnetwork`` for that covariate,
    together with its upper/lower bounds and the train/test samples.
    This function is dataset-agnostic; dataset-specific grid construction is
    delegated to ``utils.make_grids_and_dps_for_2d_vis``.

    Args:
        spec: experiment specification. Reads "CovariateNames", "Device",
            "LoggingRoot", "ExperimentName", "Class" and "InGeoFeatures".
        model_trained: model exposing
            ``infer_with_subnetwork(grid, ith_cov)``, which returns four
            tensors (x values, y values, upper bound, lower bound).
        train_dataset: dataset whose samples are drawn in the plot.
        test_dataset: dataset whose samples are drawn in the plot.
            NOTE(review): although it defaults to None, it is passed
            unconditionally to the grid helper and its outputs are moved
            with ``.cpu()``, so it appears to be required in practice.
        pos: optional position used when building the query grid
            (presumably an airway depth/landmark). When given, it is
            included in the file name, rounded to 2 decimals.

    Returns:
        None. One PNG per covariate is written under
        ``<LoggingRoot>/<ExperimentName>/<ws.vis_local_ctb_dir>``.
    '''
    covariate_names = spec["CovariateNames"]
    device = spec["Device"]
    dataset_name = spec["Class"]
    in_geo_features = spec["InGeoFeatures"]
    savedir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"], ws.vis_local_ctb_dir)
    utils.cond_mkdir(savedir)

    # The grid/data-point builder depends only on the dataset name,
    # so resolve it once instead of twice per covariate.
    make_grids_and_dps = utils.make_grids_and_dps_for_2d_vis(dataset_name=dataset_name)

    for ith_cov, cov_name in enumerate(covariate_names):
        # Build a query grid over this covariate plus the matching samples.
        arr_input_grids, x_train, y_train = make_grids_and_dps(
            train_dataset,
            covs_to_plot=[cov_name],
            pos=pos,
            device=device)
        _, x_test, y_test = make_grids_and_dps(
            test_dataset,
            covs_to_plot=[cov_name],
            pos=pos,
            device=device)

        x_train, y_train = x_train.cpu(), y_train.cpu()
        x_test, y_test = x_test.cpu(), y_test.cpu()

        x_val, y_val, y_up, y_lo = model_trained.infer_with_subnetwork(arr_input_grids, ith_cov)
        # Detach so the curves can be converted to NumPy for plotting even if
        # the sub-network output is still attached to the autograd graph.
        x_val, y_val, y_up, y_lo = (
            t.detach().squeeze().cpu() for t in (x_val, y_val, y_up, y_lo))

        if pos is not None:
            savename = f'{dataset_name}_{np.round(pos, 2)}_local_{cov_name}.png'
        else:
            savename = f'{dataset_name}_local_{cov_name}.png'
        savepath = f'{savedir}/{savename}'

        # Column of this covariate in the raw inputs. The index is offset by
        # in_geo_features, which presumably are leading geometric columns.
        cov_column = in_geo_features + ith_cov
        vis_utils.plot_regression_all_samples_with_bounds(
            x_train[..., cov_column],
            y_train,
            x_val,
            y_val,
            y_up,
            y_lo,
            x_test[..., cov_column],
            y_test,
            savepath=savepath)

    return





def visualize_per_local_feat_ctb_airway(spec: dict,
             model_trained: torch.nn.Module,
             train_dataset: data_utils.Dataset,
             test_dataset: data_utils.Dataset):
    '''
    Draw local contribution plots for the airway dataset at several positions.

    The positions are the 0/20/40/50/60/80/100th percentiles of
    ``train_dataset.train_valid_pos`` (depth/landmarks along the airway).
    For each position, ``visualize_per_local_feat_ctb_general`` is run with
    ``pos`` set to that value.

    Returns:
        0.
    '''
    percentile_levels = [0, 20, 40, 50, 60, 80, 100]
    positions = np.percentile(train_dataset.train_valid_pos, percentile_levels)

    for position in positions:
        visualize_per_local_feat_ctb_general(
            spec=spec,
            model_trained=model_trained,
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            pos=position,
        )
    return 0


def visualize_per_local_feat_ctb(dataset_name):
    '''
    Return the contribution-plotting routine suited to ``dataset_name``.

    Names containing "Airway" get the multi-position airway routine.
    Every other dataset gets the generic routine.
    '''
    is_airway = "Airway" in dataset_name
    return visualize_per_local_feat_ctb_airway if is_airway else visualize_per_local_feat_ctb_general



def visualize_local(specs_filename: str,):
    '''
    Load the experiment described by ``specs_filename`` and write local
    feature-contribution plots for it.

    Args:
        specs_filename: path to the JSON experiment specification. Its
            "SavedBestCheckpointPath", "LoggingRoot", "ExperimentName" and
            "Class" entries are read here.

    Returns:
        None.
    '''
    spec = load_json(specs_filename)
    # Only the datasets are needed; the dataloaders returned alongside them
    # are discarded. Previously the train dataloader was bound to a variable
    # named for the test split and then silently overwritten.
    ds_train, _ = load_dataset(specs_filename=specs_filename, which_split='train')
    ds_test, _ = load_dataset(specs_filename=specs_filename, which_split='test')
    trained_model = load_comp_trained_model(specs_filename, spec["SavedBestCheckpointPath"])

    savedir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"])
    utils.cond_mkdir(savedir)

    dataset_name = spec["Class"]
    visualize_per_local_feat_ctb(dataset_name)(spec, trained_model, ds_train, ds_test)
    return




def str2bool(value):
    '''
    Convert a command-line string such as "true"/"False"/"1"/"no" to a bool.

    ``bool("False")`` is True, so plain ``default=True`` flags would be
    impossible to switch off from the command line. Use this as an argparse
    ``type=``. Bool inputs are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    '''
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Train a LucidAtlas autodecoder")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        default='/playpen-raid/Author/LucidAtlas/configs/airways/comp/airway_ebm.json',
        help="The experiment directory. This directory should include "
             + "experiment specifications in 'specs.json', and logging will be "
             + "done in this directory as well.",
    )
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicated an epoch "
             + "or 'latest' for the latest weights (this is the default)",
    )

    # The boolean flags use str2bool so that e.g. "--vis False" actually
    # disables visualization; without a type, any given string is truthy.
    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        type=str2bool,
        default=True,
        help="whether to train from scratch",
    )

    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        type=str2bool,
        default=True,
        help="whether to test",
    )

    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        type=str2bool,
        default=True,
        help="whether to vis",
    )

    args = arg_parser.parse_args()

    if args.whether_vis:
        visualize_local(args.experiment_directory)




