import argparse
import json
import os
import torch
import model.networks.basics.workspace as ws
from model.networks import *
from utilities.utils import load_model_in_pkl
from interpret.glassbox import ExplainableBoostingRegressor
from model.comparisons import *
import torch.utils.data as data_utils
from copy import deepcopy
from functools import partial
def load_trained_model(specs: dict,
                       filename_checkpoint: str):
    """Build the network named in ``specs`` and restore its weights from a checkpoint.

    Args:
        specs: Experiment specification mapping. Keys read here:
            ``CovariateNames``, ``InCovFeatures``, ``InGeoFeatures``,
            ``HiddenFeatures``, ``HiddenLayers``, ``OutFeatures``,
            ``PriorKnowledge``, ``CorPriorKnowledge``, ``Device``,
            ``LoggingRoot``, ``ExperimentName`` and ``Network``.
        filename_checkpoint: Checkpoint file name without the ``.pth``
            extension. It is resolved as
            ``<LoggingRoot>/<ExperimentName>/<ws.model_params_subdir>/<name>.pth``.

    Returns:
        The constructed network with the checkpoint state dict loaded, moved
        to ``specs['Device']``.
    """
    device = specs['Device']
    root_path = os.path.join(specs['LoggingRoot'], specs['ExperimentName'])

    network_kwargs = dict(
        list_covariates=specs["CovariateNames"],
        dict_prior=specs['PriorKnowledge'],
        dict_cor_prior=specs["CorPriorKnowledge"],
        in_cov_features=specs["InCovFeatures"],
        in_geo_features=specs["InGeoFeatures"],
        hidden_features=specs["HiddenFeatures"],
        hidden_layers=specs["HiddenLayers"],
        out_features=specs["OutFeatures"],  # mean and variance
        # Build on the configured device; this used to be hard-coded to
        # 'cuda:0' even though the weights are mapped to specs['Device'].
        device=device,
        head='gaussian',
        head_activation='softplus',
    )

    # NOTE(review): eval() resolves the class from a config string, so the
    # specification file must be trusted. A name -> class registry would be safer.
    network_cls = eval(specs['Network'])
    decoder = network_cls(**network_kwargs)

    pytorch_total_params = sum(p.numel() for p in decoder.parameters() if p.requires_grad)
    print('The total number of parameters: ' + str(pytorch_total_params))

    checkpoint_path = os.path.join(
        root_path, ws.model_params_subdir, filename_checkpoint + ".pth"
    )
    saved_model_state = torch.load(checkpoint_path, map_location=torch.device(device))

    decoder.load_state_dict(saved_model_state)
    decoder = decoder.to(device)
    return decoder


def load_model(specs):
    """Build an untrained network from the experiment specification.

    Args:
        specs: Experiment specification mapping. Keys read here:
            ``CovariateNames``, ``InCovFeatures``, ``InGeoFeatures``,
            ``HiddenFeatures``, ``HiddenLayers``, ``OutFeatures``,
            ``PriorKnowledge``, ``CorPriorKnowledge``, ``Device`` and
            ``Network``.

    Returns:
        The freshly constructed network, moved to ``specs['Device']``.
    """
    device = specs['Device']

    network_kwargs = dict(
        list_covariates=specs["CovariateNames"],
        dict_prior=specs['PriorKnowledge'],
        dict_cor_prior=specs["CorPriorKnowledge"],
        in_cov_features=specs["InCovFeatures"],
        in_geo_features=specs["InGeoFeatures"],
        hidden_features=specs["HiddenFeatures"],
        hidden_layers=specs["HiddenLayers"],
        out_features=specs["OutFeatures"],  # mean and variance
        # Build on the configured device; this used to be hard-coded to
        # 'cuda:0' even though the model is moved to specs['Device'] below.
        device=device,
        head='gaussian',
        head_activation='softplus',
    )

    # NOTE(review): eval() resolves the class from a config string, so the
    # specification file must be trusted. A name -> class registry would be safer.
    network_cls = eval(specs['Network'])
    decoder = network_cls(**network_kwargs)

    pytorch_total_params = sum(p.numel() for p in decoder.parameters() if p.requires_grad)
    print('The total number of parameters: ' + str(pytorch_total_params))

    decoder = decoder.to(device)
    return decoder



def load_comp_model(specs: dict, val_loader: data_utils.DataLoader=None):
    """Instantiate the comparison model named in ``specs``.

    Args:
        specs: Experiment specification mapping. Keys read here: ``Settings``,
            ``CovariateNames``, ``InGeoFeatures`` and ``Network``; additionally
            ``LoggingRoot`` and ``ExperimentName`` when the network name
            contains ``"NAMESB"``.
        val_loader: Forwarded to the constructor only when the network name
            contains ``"LANAM"``.

    Returns:
        The object produced by calling the evaluated ``specs['Network']``.

    Note:
        For ``"NAMESB"`` networks ``specs['Settings']`` is updated in place
        with a ``log_dir`` entry.
    """
    settings = specs["Settings"]
    network_name = specs['Network']

    ctor_kwargs = {
        "list_covariates": specs["CovariateNames"],
        "in_geo_features": specs["InGeoFeatures"],
    }

    if "NAMESB" in network_name:
        settings["log_dir"] = os.path.join(specs["LoggingRoot"], specs["ExperimentName"], 'log')
    if "LANAM" in network_name:
        ctor_kwargs["val_loader"] = val_loader
    ctor_kwargs["settings"] = settings

    # NOTE(review): eval() resolves the class from a config string, so the
    # specification file must be trusted.
    model_class = eval(network_name)(**ctor_kwargs)
    return model_class


def get_X_y(train_dataset: torch.utils.data.Dataset):
    """Build the ``(X, y)`` model input for a dataset using its dataset-specific helper.

    Args:
        train_dataset: Dataset object exposing ``DATASETNANE``,
            ``prepared_data`` and ``tgt_var_name``.

    Returns:
        The ``(X, y)`` pair returned by the selected ``make_*_model_input``
        helper.
    """
    from model.networks.AirwayDataset import make_airway_model_input
    from model.networks.OASISBrainDataset import make_oasisbrain_model_input
    from model.networks.ToyDataset import make_toydata_model_input

    # NOTE(review): the attribute is spelled ``DATASETNANE`` (not ...NAME);
    # kept as-is because the dataset classes are not visible here - confirm.
    dataset_name = train_dataset.DATASETNANE

    # The previous first condition was ``if torch.nn.Module:``, which is always
    # truthy, so the airway builder was used for every dataset. Dispatch on the
    # name instead and keep the airway builder as the fallback so airway
    # datasets behave exactly as before.
    builders = {
        "OASISBrain": make_oasisbrain_model_input,
        "ToyData": make_toydata_model_input,
    }
    make_model_input = builders.get(dataset_name, make_airway_model_input)

    X, y = make_model_input(train_dataset, train_dataset.prepared_data, train_dataset.tgt_var_name)

    return X, y

def load_comp_trained_model(specs: dict, filename_checkpoint: str, ds_train: torch.utils.data.Dataset=None):
    """Instantiate a comparison model and restore its trained state.

    Args:
        specs: Experiment specification mapping. Keys read here:
            ``LoggingRoot``, ``ExperimentName``, ``Settings``,
            ``CovariateNames``, ``InGeoFeatures`` and ``Network``; plus
            ``SavedOptimizer`` for ``"LANAM"`` networks.
        filename_checkpoint: Checkpoint file name without the ``.pth``
            extension, resolved inside
            ``<LoggingRoot>/<ExperimentName>/<ws.model_params_subdir>``.
        ds_train: Training dataset. Required when the network name contains
            ``"LANAM"`` (used to re-fit the preprocessors and the curvature);
            ignored otherwise.

    Returns:
        The model wrapper with its ``model`` attribute restored.

    Raises:
        ValueError: If a ``"LANAM"`` network is requested without ``ds_train``.
    """
    network_name = specs["Network"]
    is_lanam = "LANAM" in network_name

    # Fail early with a clear message instead of an AttributeError on None
    # deep inside get_X_y.
    if is_lanam and ds_train is None:
        raise ValueError("ds_train is required to restore a LANAM model")

    root_path = os.path.join(specs['LoggingRoot'], specs['ExperimentName'])
    params_dir = os.path.join(root_path, ws.model_params_subdir)
    checkpoint_path = os.path.join(params_dir, filename_checkpoint + ".pth")

    # NOTE(review): eval() resolves the class from a config string, so the
    # specification file must be trusted.
    model_class = eval(network_name)(list_covariates=specs["CovariateNames"],
                                     in_geo_features=specs["InGeoFeatures"],
                                     settings=specs["Settings"])

    if is_lanam:
        X_train, y_train = get_X_y(train_dataset=ds_train)
        X_np = X_train.numpy()
        y_np_flat = y_train.squeeze().numpy()

        # Initialise the estimator and fit its input/target transformers
        # before loading the parameters.
        model_class.model.initialize(X=X_np, y=y_np_flat)
        model_class.model.column_transformer_.fit(X=X_np, y=y_np_flat)
        # The target transformer is fitted on the un-squeezed targets.
        model_class.model.target_transformer_.fit(X=y_train.numpy())

        optimizer_path = os.path.join(params_dir, specs["SavedOptimizer"] + ".pth")
        model_class.model.load_params(f_params=checkpoint_path, f_optimizer=optimizer_path)

        # The curvature is re-fitted on the training data after loading.
        model_class.model.initialize_curvature()
        train_iterator = model_class.model.get_iterator(
            model_class.model.dataset(X_train, y_train), training=False
        )
        model_class.model.curvature_.fit(train_iterator)
    elif "NAMESB" not in network_name:
        model_class.model = load_model_in_pkl(checkpoint_path)
    # NOTE(review): for "NAMESB" networks nothing is loaded here, and unlike
    # load_comp_model no settings["log_dir"] is set - confirm this is intended.

    return model_class





def gt_f_cov_pred(dict_gt_func, coords_with_cov, which_cov_name):
    """Evaluate the ground-truth mean and variance functions of one covariate.

    Args:
        dict_gt_func: Mapping with ``"f_mu"`` and ``"f_var"`` entries, each a
            mapping from covariate name to a two-argument callable.
        coords_with_cov: Tensor whose last axis holds the two callable
            arguments at index 0 and index 1.
        which_cov_name: Covariate name used to select the callables.

    Returns:
        The mean and variance outputs concatenated along the last axis
        (mean first).
    """
    mean_fns = dict_gt_func["f_mu"]
    var_fns = dict_gt_func["f_var"]

    # Mean first, then variance; each call gets freshly indexed inputs that
    # keep the trailing axis.
    outputs = [
        fn_table[which_cov_name](coords_with_cov[..., [0]], coords_with_cov[..., [1]])
        for fn_table in (mean_fns, var_fns)
    ]
    return torch.cat(outputs, dim=-1)

def gt_g_cov_pred(dict_gt_func, covariate, which_covariate):
    """Evaluate the ground-truth ``g`` function of one covariate.

    Args:
        dict_gt_func: Mapping whose ``"g"`` entry maps covariate names to
            callables returning three values.
        covariate: Argument passed to the selected callable.
        which_covariate: Covariate name, or a list whose first element is the
            name to use.

    Returns:
        The ``(mu, var, covariance)`` triple returned by the selected callable.
    """
    name = which_covariate[0] if isinstance(which_covariate, list) else which_covariate
    g_fn = dict_gt_func["g"][name]
    mu, var, covariance = g_fn(covariate)
    return mu, var, covariance

def gt_global_importance(self,
                         dict_gt_func,
                         model_input,
                         set_S,
                         IGNORE_CORR=False,
                         num_of_samples=5000):
    """Compute the expectation and variance of y given the covariates in ``set_S``,
    using ground-truth functions instead of the model's learned predictors.

    As a side effect ``self.f_cov_pred`` and ``self.g_corr_pred`` are replaced
    by the ground-truth predictors built from ``dict_gt_func``.

    Args:
        self: Model object providing ``covariate_names``,
            ``dict_idx_covariates``, ``concat_geo_and_cov_from_input``,
            ``E_of_dis_y_j_indp``, ``E_of_dis_y_j_given_c_i`` and
            ``V_of_E_given_c_i_part2``.
        dict_gt_func: Mapping with ``"f_mu"``, ``"f_var"`` and ``"g"`` entries.
        model_input: Forwarded unchanged to the model's helper methods.
        set_S: Name, or list of names, of the covariates that are conditioned on.
        IGNORE_CORR: If true, the other covariates are sampled with
            ``E_of_dis_y_j_indp`` and the covariance term (part 2) is skipped.
        num_of_samples: Number of samples passed to the sampling helpers.

    Returns:
        ``(arr_E, arr_Var)``, each summed over the covariate axis with the
        last axis kept (size 1).
    """
    if isinstance(set_S, str):
        set_S = [set_S]

    # Swap the model's predictors for the ground-truth functions.
    self.f_cov_pred = partial(gt_f_cov_pred, dict_gt_func=dict_gt_func)
    self.g_corr_pred = partial(gt_g_cov_pred, dict_gt_func=dict_gt_func)

    # Collected along the feature dimension.
    # corresponding to E[y|c_i, x]
    list_E_of_E = []
    # corresponding to E_c[Var(y|c, x) | c_i, x], E_c is obtained by sampling over c
    # for different f^v(c), i.e., predicted uncertainties
    list_E_of_Var = []
    # corresponding to part 1 of V_c(E[y|c, x] | c_i, x), V_c is obtained by sampling
    # over c for f^m(c), i.e., predicted expectations
    list_Var_of_E_part1 = []

    # Pick the sampler once: independent sampling vs. sampling from p(c_j|c_i).
    sample_other_covariate = self.E_of_dis_y_j_indp if IGNORE_CORR else self.E_of_dis_y_j_given_c_i

    # Iterate over all covariates; sample those not in set_S.
    for ith_feat in range(len(self.covariate_names)):
        k_name = self.dict_idx_covariates[ith_feat]
        print(k_name)
        if k_name in set_S:
            # For a source covariate, evaluate f(c_i) directly.
            coords_cov = self.concat_geo_and_cov_from_input(model_input, [k_name])
            arr_mu_var = self.f_cov_pred(coords_with_cov=coords_cov, which_cov_name=k_name)
            list_E_of_E.append(arr_mu_var[..., [0]])
            list_E_of_Var.append(arr_mu_var[..., [1]])
        else:
            # For the j-th covariate we need to sample.
            E_of_E, E_of_Var, Var_of_E_part1 = sample_other_covariate(
                model_input, set_S, k_name, num_of_samples=num_of_samples)
            # 1.1, the expectation of f^m(c, x)
            list_E_of_E.append(E_of_E)
            # 1.2, the expectation of f^v(c, x)
            list_E_of_Var.append(E_of_Var)
            # 1.3, the variance of f^e(c, x), the part depending on p(c_j|c_i)
            list_Var_of_E_part1.append(Var_of_E_part1)

    if not IGNORE_CORR:
        # 2, the variance of f^e(c, x), the part depending on covariances p(c_j, c_k|c_i)
        arr_Var_of_E_part2 = self.V_of_E_given_c_i_part2(
            model_input, set_S, list_E_of_E, num_of_samples=num_of_samples)

    # Sum along the feature dimension.
    arr_E = torch.sum(torch.cat(list_E_of_E, dim=-1), dim=-1, keepdim=True)
    arr_E_of_Var = torch.sum(torch.cat(list_E_of_Var, dim=-1), dim=-1, keepdim=True)
    if list_Var_of_E_part1:
        arr_Var_of_E_part1 = torch.sum(torch.cat(list_Var_of_E_part1, dim=-1), dim=-1, keepdim=True)
    else:
        # Every covariate is in set_S, so nothing was sampled; torch.cat would
        # raise on the empty list. The sampled-variance contribution is zero.
        arr_Var_of_E_part1 = torch.zeros_like(arr_E_of_Var)

    # Var = E_of_Var + Var_of_E = E_of_Var + Var_of_E_part1 + Var_of_E_part2
    arr_Var = arr_E_of_Var + arr_Var_of_E_part1
    if not IGNORE_CORR:
        arr_Var = arr_Var + arr_Var_of_E_part2

    # A negative variance indicates a numerical/sampling problem; report it.
    if arr_Var.min()<0:
        print(arr_Var.min())
    return arr_E, arr_Var

def make_gt_model(dataset_, model_trained):
    """Return a copy of ``model_trained`` whose ``infer_global_importance`` uses
    the dataset's ground-truth functions.

    Args:
        dataset_: Object exposing ``gt_f_mu``, ``gt_f_var`` and ``gt_g``.
        model_trained: Model to copy; it is deep-copied and left unmodified.

    Returns:
        The deep copy, with ``infer_global_importance`` bound to
        ``gt_global_importance`` and the ground-truth functions.
    """
    from types import MethodType

    dict_gt_func = {
        "f_mu": dataset_.gt_f_mu,
        "f_var": dataset_.gt_f_var,
        "g": dataset_.gt_g,
    }

    # Start from a copy of the original model.
    model = deepcopy(model_trained)

    # Replace the method with the toy (ground-truth) function.
    # A forwarding closure is used instead of partial(..., dict_gt_func=...):
    # with the partial, a positional call such as
    # model.infer_global_importance(model_input, set_S) bound model_input to
    # the dict_gt_func slot and raised "multiple values for argument".
    def _gt_infer_global_importance(self, *args, **kwargs):
        # Still allow an explicit dict_gt_func keyword override, as the
        # partial did.
        gt_funcs = kwargs.pop("dict_gt_func", dict_gt_func)
        return gt_global_importance(self, gt_funcs, *args, **kwargs)

    model.infer_global_importance = MethodType(_gt_infer_global_importance, model)

    return model

if __name__ == "__main__":
    # Manual smoke test: load a trained comparison model from an experiment
    # specification and a checkpoint name.
    arg_parser = argparse.ArgumentParser(description="Train a LucidAtlas")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        default='/playpen-raid/Author/LucidAtlas/configs/airways/comp/airway_ebm.json',
        help="The experiment directory. This directory should include "
             + "experiment specifications in 'specs.json', and logging will be "
             + "done in this directory as well.",
    )
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="model_final",
        help="The checkpoint file name (without the '.pth' extension) to load; "
        + "defaults to 'model_final'.",
    )

    args = arg_parser.parse_args()

    # load_comp_trained_model takes the parsed specification mapping, not a
    # file name; the old call passed an unknown `specs_filename` keyword and
    # raised TypeError.
    # NOTE(review): ws.load_experiment_specifications is taken from the
    # commented-out call in load_comp_trained_model - confirm it accepts this path.
    specs = ws.load_experiment_specifications(args.experiment_directory)
    # NOTE(review): "LANAM" networks additionally need ds_train, which this
    # entry point does not build.
    model = load_comp_trained_model(specs=specs,
                                    filename_checkpoint=args.checkpoint
                                    )
    print('trained model loaded')