import matplotlib.pyplot as plt

# Set the default matplotlib figure resolution to 300 DPI for every plot made by this script.
plt.rcParams['figure.dpi'] = 300
import torch
import numpy as np
import vis_utils
from utilities import utils
import os

import argparse
from pipeline.load import *
import torch.utils.data as data_utils



def visualize_per_feat_contribution(spec: dict,
                                    model_trained: torch.nn.Module,
                                    train_dataset: data_utils.Dataset,
                                    test_dataset: data_utils.Dataset = None,
                                    pos=None):
    '''
    Visualize each covariate's learned contribution as a 2D plot.

    For every covariate listed in ``spec["CovariateNames"]`` the model's
    per-feature subnetwork is queried on a grid of input values. The mean and
    the square root of the predicted variance are plotted against the
    covariate value (x axis: feature value, y axis: the feature's
    contribution), together with the train and test samples. One PNG per
    covariate is written to ``<LoggingRoot>/<ExperimentName>``. This function
    can be used for different datasets.

    Args:
        spec: experiment specification; reads "CovariateNames", "Device",
            "LoggingRoot", "ExperimentName" and "Class".
        model_trained: model exposing ``infer_with_subnetwork(grid, idx)``,
            which returns a ``(f_mu, f_var)`` pair of tensors.
        train_dataset: dataset used to build the query grid and the training
            scatter points.
        test_dataset: dataset providing the test scatter points.
            NOTE(review): it is handed straight to the grid builder, so the
            ``None`` default presumably fails there; confirm before relying on it.
        pos: optional position forwarded to the grid builder (the airway
            plotter passes one per call). When given, it is also embedded in
            the output file names.
    '''
    covariate_names = spec["CovariateNames"]
    device = spec["Device"]
    savedir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"])
    utils.cond_mkdir(savedir)
    dataset_name = spec["Class"]

    # Make the query grid (from the training set) and the train/test samples.
    make_grids = utils.make_grids_and_dps_for_2d_vis(dataset_name=dataset_name)
    arr_input_grids, x_train, y_train = make_grids(
        train_dataset,
        covs_to_plot=covariate_names,
        pos=pos,
        device=device)
    _, x_test, y_test = make_grids(
        test_dataset,
        covs_to_plot=covariate_names,
        pos=pos,
        device=device)

    # Plotting happens on CPU tensors.
    x_train, y_train = x_train.cpu(), y_train.cpu()
    x_test, y_test = x_test.cpu(), y_test.cpu()

    # Embed the position in file names only when one was supplied, so repeated
    # calls with different positions do not overwrite each other's plots.
    pos_tag = '' if pos is None else f'_{pos}'

    for ith_cov, cov_name in enumerate(covariate_names):
        f_mu, f_var = model_trained.infer_with_subnetwork(arr_input_grids, ith_cov)
        mh_map = f_mu.squeeze().detach().cpu().numpy()
        std_map = np.sqrt(f_var.squeeze().detach().cpu().numpy())
        # Covariate columns are offset by one; column 0 presumably holds a
        # non-covariate input (e.g. position). TODO confirm against the grid builder.
        col = 1 + ith_cov

        vis_utils.plot_regression_all_samples(
            x_train[..., col],
            y_train,
            arr_input_grids[..., col].cpu(),
            mh_map,
            std_map,
            x_test[..., col],
            y_test,
            savepath=os.path.join(savedir, f'{dataset_name}{pos_tag}_ctb_{cov_name}.png'))

    return





def visualize_per_feat_ctb_airway(spec: dict,
                                  model_trained: torch.nn.Module,
                                  train_dataset: data_utils.Dataset,
                                  test_dataset: data_utils.Dataset):
    '''
    Make contribution plots for the airway dataset at several depths/landmarks.

    The query positions are the 0th, 20th, 40th, 50th, 60th, 80th and 100th
    percentiles of ``train_dataset.train_valid_pos``.
    ``visualize_per_feat_contribution`` is run once per position with the
    same spec, model and datasets.

    Returns:
        0 once every position has been plotted.
    '''
    percentiles = [0, 20, 40, 50, 60, 80, 100]
    positions = np.percentile(train_dataset.train_valid_pos, percentiles)

    for position in positions:
        visualize_per_feat_contribution(spec=spec,
                                        model_trained=model_trained,
                                        train_dataset=train_dataset,
                                        test_dataset=test_dataset,
                                        pos=position)
    return 0


def visualize_per_feat_ctb(dataset_name):
    '''
    Return the per-feature contribution plotter to use for ``dataset_name``.

    Dataset names containing "Airway" get the multi-position airway plotter.
    Every other dataset gets the generic single-pass plotter.
    '''
    is_airway = "Airway" in dataset_name
    return visualize_per_feat_ctb_airway if is_airway else visualize_per_feat_contribution


def visualize(specs_filename: str,):
    '''
    Load the experiment described by ``specs_filename`` and write its
    per-feature contribution plots.

    Loads the spec JSON, the train and test splits, and the checkpoint named
    by ``spec["SavedBestCheckpointPath"]``. It then puts the model in eval
    mode, makes sure ``<LoggingRoot>/<ExperimentName>`` exists, and dispatches
    to the plotter selected for ``spec["Class"]``.
    '''
    spec = load_json(specs_filename)
    ds_train, _ = load_dataset(specs_filename=specs_filename, which_split='train')
    ds_test, _ = load_dataset(specs_filename=specs_filename, which_split='test')
    trained_model = load_trained_model(specs_filename, spec["SavedBestCheckpointPath"])  # saved checkpoint name
    trained_model.eval()
    savedir = os.path.join(spec["LoggingRoot"], spec["ExperimentName"])
    utils.cond_mkdir(savedir)
    dataset_name = spec["Class"]
    # The plotters take (spec, model, train_dataset, test_dataset) and derive
    # the output directory from ``spec`` themselves.
    visualize_per_feat_ctb(dataset_name)(spec, trained_model, ds_train, ds_test)
    return




if __name__ == "__main__":
    def _str2bool(value):
        '''Convert a command-line flag value such as "True", "false", "1" or "no" to a bool.'''
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    arg_parser = argparse.ArgumentParser(description="Train a LucidAtlas autodecoder")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        default='/playpen-raid/Author/LucidAtlas/configs/airways/lucidatlas_1d_csa_v13_0116.json',
        help="The experiment directory. This directory should include "
             + "experiment specifications in 'specs.json', and logging will be "
             + "done in this directory as well.",
    )
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicated an epoch "
             + "or 'latest' for the latest weights (this is the default)",
    )

    # Boolean flags take an explicit value (e.g. ``--vis False``). Without a
    # ``type`` converter argparse keeps the raw string, and any non-empty
    # string, including "False", is truthy.
    arg_parser.add_argument(
        "--train",
        dest="whether_train",
        type=_str2bool,
        default=True,
        help="whether to train from scratch",
    )

    arg_parser.add_argument(
        "--test",
        dest="whether_test",
        type=_str2bool,
        default=True,
        help="whether to test",
    )

    arg_parser.add_argument(
        "--vis",
        dest="whether_vis",
        type=_str2bool,
        default=True,
        help="whether to vis",
    )

    args = arg_parser.parse_args()

    # NOTE(review): --checkpoint, --train and --test are parsed but not used by
    # this script; visualize() reads the checkpoint path from the spec file.
    if args.whether_vis:
        visualize(args.experiment_directory)




