import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler, SequentialSampler
from scipy.stats import entropy
import matplotlib.pyplot as plt
import seaborn as sns
from shift_evaluate import * # shift evaluation
from sklearn.linear_model import Lasso, LogisticRegression
from scipy.stats import pearsonr
from scipy.special import binom
from shift_generate import *

def shapley_correl_intersec(shapley1, shapley2):
    """Compare two sets of Shapley values.

    Returns a 3-tuple: the Pearson correlation between both sets (rounded to
    two decimals), the associated p-value, and the number of labels shared by
    the five largest values of each set.
    """

    def top_five_labels(values):
        # ascending sort, so the five largest values sit at the tail
        return set(values.sort_values()[-5:].index)

    coefficient, p_value = pearsonr(shapley1, shapley2)
    shared_labels = top_five_labels(shapley1).intersection(top_five_labels(shapley2))
    return round(coefficient, 2), p_value, len(shared_labels)

def shapley_sparsity_plus(shapley):
    """Share of the non-negative Shapley mass carried by its five largest values.

    Keeps the values that are >= 0, ranks them from largest to smallest and
    returns (sum of the first five) / (sum of all of them), rounded to two
    decimals.
    """
    non_negative = shapley[shapley >= 0]
    ranked = non_negative.sort_values(ascending=False)
    top_five_total = ranked[:5].sum()
    overall_total = ranked.sum()
    return round(top_five_total / overall_total, 2)

def shap_kernel(z, p):
    """Kernel SHAP weight of a coalition with z unmasked features out of p.

    Evaluates (p - 1) / (C(p, z) * z * (p - z)).
    """
    coalition_count = binom(p, z)
    size_term = z * (p - z)
    return (p - 1) / (coalition_count * size_term)

def kernel_shap_consistency(perturb_df, x_df, shapley_result, seed):
    """
    Compute the feature contributions to uncertainty with Kernel SHAP.

    Each perturbation row is turned into a {0,1} coalition vector (1 where the
    perturbed feature equals the original instance), weighted with the Kernel
    SHAP kernel, and a weighted Lasso regression of each target on the
    coalition vectors is fitted. The regression coefficients are stored as the
    Kernel SHAP contributions.

    Parameters
    ----------
    perturb_df : DataFrame of perturbations whose last two columns are the
        targets ("EC", "DENS"); all preceding columns are features.
        NOTE: de-duplicated and re-indexed IN PLACE (kept for compatibility
        with existing callers).
    x_df : one-row DataFrame with the original instance, laid out like
        ``perturb_df`` (same column labels in the same order).
    shapley_result : DataFrame with one row per feature, in the same order as
        the feature columns; receives the new columns "Kernel Shap (EC)" and
        "Kernel Shap (DENS)" (modified in place and returned).
    seed : passed to ``Lasso(random_state=...)``.

    Returns
    -------
    ``shapley_result`` with the two Kernel SHAP columns added.
    """
    target_columns = ("EC", "DENS")
    # A third target ("ac", with alpha=1e-4) was sketched in earlier revisions
    # of this function but is disabled.

    # remove duplicates from the perturbation dataset (in place, see docstring)
    perturb_df.drop_duplicates(inplace=True)
    perturb_df.reset_index(drop=True, inplace=True)
    # repeat the original instance once per perturbation row
    x_df_repeat = pd.concat([x_df] * len(perturb_df)).reset_index(drop=True)
    # coalition {0,1} dataset: 1 where the perturbation kept the original value
    similar_df = (x_df_repeat.iloc[:, :-2] == perturb_df.iloc[:, :-2]).astype(int)
    # list of features
    feature_list = similar_df.columns.tolist()
    n_features = len(feature_list)
    # concatenate with target variables
    kernel_shap_dataset = pd.concat((similar_df, perturb_df[list(target_columns)]), axis=1)
    # remove rows with only 0s or only 1s (otherwise shap kernel goes to infinity)
    coalition_size = similar_df.sum(axis=1)
    indices_to_drop = similar_df[(coalition_size == 0) | (coalition_size == n_features)].index
    kernel_shap_dataset.drop(indices_to_drop, inplace=True)
    # compute the kernel weights for all rows at once (shap_kernel is elementwise)
    kernel_shap_dataset["z"] = kernel_shap_dataset[feature_list].sum(axis=1)
    kernel_shap_dataset["weight"] = shap_kernel(kernel_shap_dataset["z"].to_numpy(), n_features)

    # fit one weighted linear regression with Lasso regularization per target
    for target in target_columns:
        lasso = Lasso(alpha=1e-3, fit_intercept=True, precompute=False, copy_X=True, max_iter=1000, tol=0.0001, warm_start=False, positive=False, random_state=seed, selection='cyclic')
        lasso.fit(kernel_shap_dataset[feature_list], kernel_shap_dataset[target], sample_weight=kernel_shap_dataset["weight"])
        shapley_result["Kernel Shap (" + target + ")"] = lasso.coef_

    return shapley_result
    
def plot_shapley(shapley_result, type="EC", adjusted = False):
    """Plot the ten Shapley values with the largest magnitude as a bar chart.

    Parameters
    ----------
    shapley_result : DataFrame indexed by feature name (strings) with the
        columns "feature value", "feature type" and
        "Shapley value (<type>)" (or "Shapley value (<type>)-adjusted").
        NOTE: an "Abs Shapley value" column is written onto this frame, as in
        earlier versions of this function; callers see that extra column.
    type : suffix selecting the Shapley column, e.g. "EC" or "DENS".
    adjusted : when True, the "-adjusted" variant of the column is plotted.

    Text features are labelled with the quoted word, other features as
    "name = value". Positive values are drawn in 'tomato', the rest in
    'cornflowerblue', and every bar is annotated with its value rounded to
    three decimals. The figure is shown with ``plt.show()``; nothing is
    returned.
    """
    var = "Shapley value" + " (" + type + ")" + ("-adjusted" if adjusted else "")

    # rank by magnitude and keep the ten largest; the explicit copy makes the
    # column assignments below write to an independent frame, not to a slice
    shapley_result["Abs Shapley value"] = np.abs(shapley_result[var])
    shapley_top_10 = shapley_result.sort_values(["Abs Shapley value"], ascending = False)[:10].copy()

    # axis labels
    labels = []
    for name in shapley_top_10.index:
        if shapley_top_10.loc[name, "feature type"] == "text":
            labels.append('"' + name + '"')
        else:
            labels.append(name + " = " + str(shapley_top_10.loc[name, "feature value"]))
    shapley_top_10["Features"] = labels

    # plot color
    shapley_top_10["positive"] = shapley_top_10[var] > 0
    shapley_top_10["color"] = shapley_top_10.positive.map({True:'tomato', False:'cornflowerblue'}).values.tolist()

    # final plot
    sns.set_style('darkgrid',{'grid.color':".6", "grid.linestyle":":"})
    plt.figure(figsize = (6,4))
    sns.set(font_scale = 1.4)
    sns.set_color_codes("bright")
    plt.title("Feature contributions")
    plot = sns.barplot(x = var,
                      y = "Features",
                      data = shapley_top_10,
                      color = "cornflowerblue",
                      palette = shapley_top_10["color"].values.tolist()
                      )

    # plot annotation: same placement for positive and negative values
    for position, value in enumerate(shapley_top_10[var].to_numpy()):
        plot.annotate(str(round(value, 3)), xy =(0, 0.2 + position), fontsize = 14)

    plt.show()

def compute_Shapley(instance_to_explain,
                    base_dataset,
                    model,
                    model_type,
                    value_function,
                    token_var,
                    cat_vars,
                    num_vars,
                    words,
                    M,
                    T,
                    conv_eps,
                    seed,
                    device,
                    display):


    # compute Shapley

    """Sampling-based approximation of Shapley values.

    For every tabular feature and every token before the first padding id (0)
    of ``instance_to_explain``, ``T`` Monte Carlo draws are made per round.
    Each draw samples one row ``z`` from ``base_dataset`` and a random
    coalition of the other features, then builds two perturbed inputs:
    ``x_plus`` keeps the studied feature from the instance, ``x_minus``
    replaces it (tabular: by the value in ``z``; text: by the mask id 103
    unless the token id also occurs in ``z``'s token sequence). Both inputs
    are run through ``model`` and scored with two value functions; the
    difference plus - minus is one marginal contribution. The Shapley estimate
    of a feature is the mean of its marginal contributions.

    Up to ``M`` rounds are run. From the second round on, the estimates are
    pooled with the previous rounds as a running average, and the loop stops
    early once the largest absolute change of any pooled estimate (EC or
    DENS) is <= ``conv_eps``.

    Parameters
    ----------
    instance_to_explain : label-indexed record holding ``cat_vars``,
        ``num_vars``, ``token_var``, 'Y', "EC" and "DENS"
        (presumably a pandas Series — TODO confirm).
    base_dataset : object providing ``.sample(1, random_state=...)`` and
        column selection by label list (presumably a pandas DataFrame).
    model : callable as ``model(text, mask, categorical, numerical)``; its
        output is indexed ``[0]`` (logits), ``[1]`` and, for "LateFuseBERT",
        ``[2]`` (embeddings).
    model_type : "LateFuseBERT" or "AllTextBERT". Any other value leaves the
        embedding variables unset, so their first use raises ``NameError``.
    value_function : dict with "EC" (object exposing ``predict_proba``) and
        "DENS" (list of callable networks exposing ``.eval()``).
    token_var : label of the token-id sequence.
    cat_vars, num_vars : lists of categorical / numerical feature labels.
    words : lookup from token id to word, indexed as ``words[token_id]``.
    M : maximum number of rounds.
    T : number of Monte Carlo draws per feature and per round.
    conv_eps : convergence threshold on the max absolute change.
    seed : seed for the NumPy generator that produces the per-draw seeds.
    device : torch device the input tensors are moved to.
    display : when truthy, prints convergence progress and plots the result.

    Returns
    -------
    (shapley_result, (x_df, perturb_df), iter_conv) where ``shapley_result``
    is indexed by feature name with the columns "feature value",
    "feature type", "Shapley value (EC)" and "Shapley value (DENS)";
    ``x_df`` is the original instance as a one-row frame; ``perturb_df``
    holds every x_plus / x_minus perturbation with its "EC" and "DENS"
    scores; and ``iter_conv`` is the number of rounds actually run.
    """

    # NOTE(review): `pd`, `np` and `random` are not imported explicitly in the
    # visible part of this file — presumably they come from the star imports
    # at the top; confirm.

    # result and storage dataframes
    shapley_result_prev = pd.DataFrame() # shapley values (memory for each iteration)
    perturb_df = pd.DataFrame() # used to store pertubations x_plus and x_minus
    iter_perturb = 0 # index for perturb_df (perturbation samples); also indexes seed_list

    # input with only model features, ordered as cat_vars + num_vars + [token_var]
    # NOTE(review): x is accessed below with integer keys (x[i], x[n_features_tab]);
    # this presumably relies on integer keys being treated as positions on a
    # label-indexed Series — confirm against the pandas version in use.
    x = instance_to_explain[cat_vars + num_vars + [token_var] ]
    y = instance_to_explain['Y'] # not used in the rest of this function

    # number of tabular and numerical features
    n_features_tab = len(cat_vars + num_vars)
    n_features_cat = len(cat_vars) # not used in the rest of this function
    n_features_num = len(num_vars)

    # seed (for reproducibility and variety of sampling):
    # one pre-drawn seed per perturbation slot, indexed by iter_perturb
    np.random.seed(seed)
    seed_list = np.random.randint(0,len(base_dataset),2*M*T*(n_features_tab + len(instance_to_explain[token_var])))

    # one pass of this loop = one round of T draws for every feature
    for iter_conv in range(1, M+1):

        # features
        feature_idxs_tab = list(range(n_features_tab))  # we use indices for tabular features
        if 0 not in x[n_features_tab].tolist():
            pad_id = x[n_features_tab].shape[0]
        else:
            pad_id = x[n_features_tab].tolist().index(0) # for the token index, we search for the first [PAD] (index should stop there)
        feature_idxs_txt = list(range(pad_id)) # we use indices for text features (tokens)

        shapley_result = pd.DataFrame() # shapley values of this round

        # original feature names, used as columns in perturb_df
        token_var_list = ["token_"+str(i) for i in range(x[n_features_tab].shape[0])][:pad_id]# one "token_<k>" name per position before the first padding id

        # used to count redundant tokens
        token_count = dict() # key is token and value is its count in the text
        word_list = [] # display name of every token feature, in order

        # compute contribution of each feature
        # i is the running feature position, j the index inside its own group
        # (tabular index for i < n_features_tab, token position otherwise)
        for i, j in enumerate(feature_idxs_tab + feature_idxs_txt):

            # reinitialize feature lists (the studied feature is removed from them below)
            feature_idxs_tab = list(range(n_features_tab))  # we use indices for tabular features
            if 0 not in x[n_features_tab].tolist():
                pad_id = x[n_features_tab].shape[0]
            else:
                pad_id = x[n_features_tab].tolist().index(0) # for the token index, we search for the first [PAD] (index should stop there)
            feature_idxs_txt = list(range(pad_id)) # we use indices for text features (tokens)

            # feature to study is tabular
            if i < n_features_tab:
              j_tab = [j]; j_txt = []
              feature_idxs_tab.remove(j) # remove feature with index j
              # NOTE(review): name and value are read by position from
              # instance_to_explain, not from x — this assumes its first
              # entries are ordered like cat_vars + num_vars; confirm.
              feature_name = instance_to_explain.index[j]
              feature_value = instance_to_explain.values[j]
              feature_type = "tabular"

            # feature to study is textual
            else:
              j_tab = []; j_txt = [j]
              feature_idxs_txt.remove(j) # remove token j
              token_id = x[n_features_tab][j]# find corresponding token ID
              feature_name = words[token_id] # corresponding word
              # repeated words get a "_<count>" suffix so that names stay unique
              if feature_name in token_count:
                  token_count[feature_name]+=1
                  feature_name = feature_name+"_"+str(token_count[feature_name])
              else:
                  token_count[feature_name]=1
              word_list.append(feature_name)

              feature_value = 1. # presence of word
              feature_type = "text"

            # store feature contribution (one entry per Monte Carlo draw)
            ec_marginal_contributions = []
            deep_ens_marginal_contributions = []
            # ac_marginal_contributions = []

            # Shapley computation via Monte Carlo iterations
            for iter in range(T):

                # draw random sample z from the base dataset
                z_full = base_dataset.sample(1, random_state = seed_list[iter_perturb])
                z = z_full[cat_vars + num_vars + [token_var]].values[0]

                # 1. tabular case
                # pick a random subset of features
                # NOTE(review): seed_list holds NumPy integers; random.seed
                # documents only None/int/float/str/bytes/bytearray as seed
                # types from Python 3.11 on — confirm on the target interpreter.
                random.seed(seed_list[iter_perturb])
                x_idx_tab = random.sample(feature_idxs_tab, random.randint(0, len(feature_idxs_tab)))
                z_idx_tab = [idx for idx in feature_idxs_tab if idx not in x_idx_tab]

                # construct two new instances:
                # x_plus takes the coalition AND feature j from x, the rest from z;
                # x_minus takes the non-coalition features AND feature j from z
                x_plus_j_tab = np.array([x[i] if i in x_idx_tab+j_tab else z[i] for i in range(n_features_tab)])
                x_minus_j_tab = np.array([z[i] if i in z_idx_tab+j_tab else x[i] for i in range(n_features_tab)])

                # 2. text case
                # pick a random subset of tokens (the generator is re-seeded with the same seed as above)
                random.seed(seed_list[iter_perturb])
                n_features_txt = len(feature_idxs_txt) # not used in the rest of this function
                x_idx_txt = random.sample(feature_idxs_txt, random.randint(0, len(feature_idxs_txt)))
                z_idx_txt = [idx for idx in feature_idxs_txt if idx not in x_idx_txt]

                # construct two new instances
                # a) where all tokens in x with index in z_idx_txt are replaced by [MASK] if the token value is not in z
                x_plus_j_txt = x[n_features_tab].copy()
                for txt_idx in z_idx_txt:
                    if x[n_features_tab][txt_idx] not in z[n_features_tab]:
                        x_plus_j_txt[txt_idx] = 103 # '[MASK]' token


                # b) where all tokens in x with index in z_idx_txt + j_txt are replaced by [MASK] if the token value is not in z
                x_minus_j_txt = x[n_features_tab].copy()
                for txt_idx in z_idx_txt+j_txt:
                    if x[n_features_tab][txt_idx] not in z[n_features_tab]:
                        x_minus_j_txt[txt_idx] = 103 # '[MASK]' token

                ## compute x+ and x- (perturbations)
                # x_plus: convert to batch-of-one tensors; the attention mask is 1 wherever the token id is not 0
                # NOTE(review): the [:-n_features_num] / [-n_features_num:] split
                # assumes at least one numerical feature; with n_features_num == 0
                # the categorical slice is empty and the numerical slice holds
                # every tabular value.
                text_plus = torch.tensor([x_plus_j_txt]).to(device)
                mask_plus = (text_plus!=0).long().to(device)
                categorical_plus = torch.tensor([x_plus_j_tab[:-n_features_num]]).int().to(device)
                numerical_plus = torch.tensor([x_plus_j_tab[-n_features_num:]]).float().to(device)
                #  x_minus: same conversion
                text_minus = torch.tensor([x_minus_j_txt]).to(device)
                mask_minus = (text_minus!=0).long().to(device)
                categorical_minus = torch.tensor([x_minus_j_tab[:-n_features_num]]).int().to(device)
                numerical_minus = torch.tensor([x_minus_j_tab[-n_features_num:]]).float().to(device)

                # model predictions to get CLS and softmax
                model.eval()
                # CLS and softmax for x_plus
                with torch.no_grad():
                    model_output_plus = model(text_plus, mask_plus, categorical_plus, numerical_plus.float())
                # LateFuseBERT: outputs [1] and [2] are concatenated into one embedding
                if model_type == "LateFuseBERT":
                    pred_plus = model_output_plus[0]
                    p_hat_plus = F.softmax(pred_plus, dim=1)
                    text_cls_plus = model_output_plus[1]
                    tabular_cls_plus = model_output_plus[2]
                    text_tab_cls_plus = torch.cat((text_cls_plus, tabular_cls_plus), dim=1)
                    text_tab_cls_plus_np = text_tab_cls_plus.cpu().numpy()
                    text_tab_cls_softmax_plus = torch.cat((text_tab_cls_plus, p_hat_plus), dim=1)
                    text_tab_cls_softmax_plus_np = text_tab_cls_softmax_plus.cpu().numpy()
                # AllTextBERT: output [1] is used directly as the embedding
                if model_type == "AllTextBERT":
                    pred_plus = model_output_plus[0]
                    p_hat_plus = F.softmax(pred_plus, dim=1)
                    text_tab_cls_plus = model_output_plus[1]
                    text_tab_cls_plus_np = text_tab_cls_plus.cpu().numpy()
                    text_tab_cls_softmax_plus = torch.cat((text_tab_cls_plus, p_hat_plus), dim=1)
                    text_tab_cls_softmax_plus_np = text_tab_cls_softmax_plus.cpu().numpy()

                # CLS and softmax for x_minus (same processing as for x_plus)
                with torch.no_grad():
                    model_output_minus = model(text_minus, mask_minus, categorical_minus, numerical_minus.float())
                if model_type == "LateFuseBERT":
                    pred_minus = model_output_minus[0]
                    p_hat_minus = F.softmax(pred_minus, dim=1)
                    text_cls_minus = model_output_minus[1]
                    tabular_cls_minus = model_output_minus[2]
                    text_tab_cls_minus = torch.cat((text_cls_minus, tabular_cls_minus), dim=1)
                    text_tab_cls_minus_np = text_tab_cls_minus.cpu().numpy()
                    text_tab_cls_softmax_minus = torch.cat((text_tab_cls_minus, p_hat_minus), dim=1)
                    text_tab_cls_softmax_minus_np = text_tab_cls_softmax_minus.cpu().numpy()
                if model_type == "AllTextBERT":
                    pred_minus = model_output_minus[0]
                    p_hat_minus = F.softmax(pred_minus, dim=1)
                    text_tab_cls_minus = model_output_minus[1]
                    text_tab_cls_minus_np = text_tab_cls_minus.cpu().numpy()
                    text_tab_cls_softmax_minus = torch.cat((text_tab_cls_minus, p_hat_minus), dim=1)
                    text_tab_cls_softmax_minus_np = text_tab_cls_softmax_minus.cpu().numpy()

                # compute contribution with value functions
                # a. Error Classifier: class-1 probability on [embedding, softmax]
                ec_plus = value_function["EC"].predict_proba(text_tab_cls_softmax_plus_np)[:,1][0]
                ec_minus = value_function["EC"].predict_proba(text_tab_cls_softmax_minus_np)[:,1][0]
                ec_marginal_contributions.append(ec_plus - ec_minus)
                # b. Deep Ensembles: entropy of the averaged softmax of all networks
                # NOTE(review): these forward passes run outside torch.no_grad();
                # the results are detached before being converted to NumPy.
                nn_model_list = value_function["DENS"]
                M_NN = len(nn_model_list)
                for m in range(M_NN):
                    # predict with each neural network
                    nn_model_list[m].eval()
                    de_pred_plus = nn_model_list[m](text_tab_cls_plus)
                    # compute softmax probabilities
                    de_p_hat_plus = F.softmax(de_pred_plus, dim=1)
                    # sum the softmax probabilities over the ensemble members
                    if m == 0:
                      softmax_list_sum_plus = de_p_hat_plus
                    else:
                      softmax_list_sum_plus += de_p_hat_plus
                # compute average over the ensemble members
                softmax_list_avg_plus = softmax_list_sum_plus/M_NN
                # compute Shannon entropy (base 2)
                shannon_entropy_plus = entropy(softmax_list_avg_plus.cpu().detach().numpy(), base=2, axis=1)
                deep_ens_plus = shannon_entropy_plus[0]
                # b.1 minus
                for m in range(M_NN):
                    # predict with each neural network
                    nn_model_list[m].eval()
                    de_pred_minus = nn_model_list[m](text_tab_cls_minus)
                    # compute softmax probabilities
                    de_p_hat_minus = F.softmax(de_pred_minus, dim=1)
                    # sum the softmax probabilities over the ensemble members
                    if m == 0:
                      softmax_list_sum_minus = de_p_hat_minus
                    else:
                      softmax_list_sum_minus += de_p_hat_minus
                # compute average over the ensemble members
                softmax_list_avg_minus = softmax_list_sum_minus/M_NN
                # compute Shannon entropy (base 2)
                shannon_entropy_minus = entropy(softmax_list_avg_minus.cpu().detach().numpy(), base=2, axis=1)
                deep_ens_minus = shannon_entropy_minus[0]
                deep_ens_marginal_contributions.append(deep_ens_plus - deep_ens_minus)
                # c. average confidence
                # ac_plus = 1-torch.max(p_hat_plus).item()
                # ac_minus = 1-torch.max(p_hat_minus).item()
                # ac_marginal_contributions.append(ac_plus - ac_minus)

                # store perturbations (used to compute Kernel SHAP):
                # one row for x_plus, then one row for x_minus
                perturb_df.loc[iter_perturb,cat_vars + num_vars] = x_plus_j_tab
                perturb_df.loc[iter_perturb,token_var_list] = x_plus_j_txt[:pad_id]
                perturb_df.loc[iter_perturb,"EC"] = ec_plus
                perturb_df.loc[iter_perturb,"DENS"] = deep_ens_plus
                # perturb_df.loc[iter_perturb,"ac"] = ac_plus
                iter_perturb+=1

                perturb_df.loc[iter_perturb,cat_vars + num_vars] = x_minus_j_tab
                perturb_df.loc[iter_perturb,token_var_list] = x_minus_j_txt[:pad_id]
                perturb_df.loc[iter_perturb,"EC"] = ec_minus
                perturb_df.loc[iter_perturb,"DENS"] = deep_ens_minus
                # perturb_df.loc[iter_perturb,"ac"] = ac_minus
                iter_perturb+=1

            # compute Shapley value with MC approach: average over the iterations
            ec_shapley_value = sum(ec_marginal_contributions)/len(ec_marginal_contributions)
            deep_ens_shapley_value = sum(deep_ens_marginal_contributions)/len(deep_ens_marginal_contributions)
            # ac_shapley_value = sum(ac_marginal_contributions)/len(ac_marginal_contributions)

            # result table (one row per feature, indexed by feature name)
            shapley_result.loc[feature_name, "feature value"] = feature_value
            shapley_result.loc[feature_name, "feature type"] = feature_type
            shapley_result.loc[feature_name, "Shapley value (EC)"] = ec_shapley_value
            shapley_result.loc[feature_name, "Shapley value (DENS)"] = deep_ens_shapley_value
            # shapley_result.loc[feature_name, "Shapley value (AC)"] = ac_shapley_value

        # convergence check on the EC and DENS Shapley values
        if len(shapley_result_prev)>0: # first round :no need to check for convergence
            # pool this round (T draws) with the previous pooled estimate ((iter_conv-1)*T draws)
            shapley_result["Shapley value (EC)"] = (shapley_result["Shapley value (EC)"] * T + shapley_result_prev["Shapley value (EC)"] * (iter_conv*T - T)) / (iter_conv*T)
            shapley_result["Shapley value (DENS)"] = (shapley_result["Shapley value (DENS)"] * T + shapley_result_prev["Shapley value (DENS)"] * (iter_conv*T - T)) / (iter_conv*T)
            # shapley_result["Shapley value (AC)"] = (shapley_result["Shapley value (AC)"] * T + shapley_result_prev["Shapley value (AC)"] * (iter_conv*T - T)) / (iter_conv*T)

            # largest absolute change of any feature's pooled estimate
            shapley_maxad_ec = np.max(np.abs(shapley_result["Shapley value (EC)"] - shapley_result_prev["Shapley value (EC)"]))
            shapley_maxad_deep_ens = np.max(np.abs(shapley_result["Shapley value (DENS)"] - shapley_result_prev["Shapley value (DENS)"]))
            # shapley_maxad_ac = np.max(np.abs(shapley_result["Shapley value (AC)"] - shapley_result_prev["Shapley value (AC)"]))

            shapley_maxad = max(shapley_maxad_ec, shapley_maxad_deep_ens) # shapley_maxad_ac
            if display:
                # the printed value is the max over EC and DENS, although the label only mentions ec
                print("iteration:", iter_conv*T,"-", "shapley (ec) max abs difference:",round(shapley_maxad,3))
            if shapley_maxad <= conv_eps:
                if display:
                    print("break")
                break
        shapley_result_prev = shapley_result.copy()

    # NOTE(review): shapley_result, token_var_list, word_list, pad_id and
    # iter_conv used below are only bound inside the loop above, so M must be
    # at least 1.

    # plot Shapley value bar charts
    if display:
        plot_shapley(shapley_result, type="EC")
        plot_shapley(shapley_result, type="DENS")
        # plot_shapley(shapley_result, type="ac")

    ## format perturb_df
    # set as type int
    perturb_df[cat_vars + token_var_list ] = perturb_df[cat_vars + token_var_list].astype(int)
    # rename tokens as words ("token_<k>" -> display name of the k-th token)
    mapper1 = dict([(key, value) for i, (key, value) in enumerate(zip(token_var_list, word_list))])
    perturb_df.rename(columns = mapper1, inplace=True)

    ## prepare x_df (original x instance to compare with perturbations)
    x_df = pd.DataFrame(instance_to_explain).T[cat_vars + num_vars + [token_var, "EC", "DENS"]] # ac
    # one column per word, filled with the token ids before the first padding id
    x_df[word_list] = x[token_var][:pad_id]
    x_df.drop(token_var, axis=1, inplace=True)
    # same column order as perturb_df: features first, targets last
    x_df = x_df[cat_vars + num_vars + word_list + ["EC", "DENS"]] # ac

    return shapley_result, (x_df, perturb_df), iter_conv