# %% [markdown]
# ### load data

# %% [markdown]
# ### all params

# %%
# Experiment grid: attribute labels with a saved classifier, regions targeted
# for unlearning, and the unlearning-epoch checkpoints that were saved.
label_lst = ['Male', 'Big_Nose', 'Pointy_Nose', 'Eyeglasses', 'Narrow_Eyes']
where_to_unl_lst = ['nose', 'eye', 'noseeye']
# epoch 1, then every 10 epochs up to and including 100
unl_epochs_lst = [1] + list(range(10, 101, 10))
# training-epoch counts of the original / retrained (RT) models; they select
# which saved checkpoints and metric files are loaded in the cells below
ori_training_epochs = 100
RT_training_epochs = 100
# number of repeated unlearning runs (unlearn_times_idx) per label/region/epoch
unlearn_times = 3

# %% [markdown]
# ### load model

# %%
import torch
import torch.nn as nn

from utils import ViTBackBone, ViTGenderClassifier
from config import get_ori_model_save_path, get_our_model_save_path, get_BL2_model_save_path


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Shared ViT backbone; its embedding size parameterizes every classifier head loaded below.
vit_backbone = ViTBackBone().to(device)
embedding_dim = vit_backbone.backbone.embed_dim


def _load_classifier(checkpoint_path) -> nn.Module:
    """Create a ViTGenderClassifier on `device` and fill it with the weights saved at `checkpoint_path`.

    The module is moved to `device` *before* loading: `load_state_dict` copies the
    saved tensors into the module's existing parameters, so `map_location` alone
    would leave a CPU-constructed module on the CPU.
    """
    classifier = ViTGenderClassifier(embedding_dim=embedding_dim).to(device)
    classifier.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return classifier


# original-model classifier, one per label
ori_model_dict = {
    label: _load_classifier(get_ori_model_save_path(ori_training_epochs=ori_training_epochs, label=label))
    for label in label_lst
}

# unlearned classifiers (ours and baseline 2), indexed as
# [label][where_to_unl][unl_epoch] -> list of `unlearn_times` models (one per run index)
our_model_dict, BL2_model_dict = {}, {}
for label in label_lst:
    our_model_dict[label], BL2_model_dict[label] = {}, {}
    for where_to_unl in where_to_unl_lst:
        our_model_dict[label][where_to_unl], BL2_model_dict[label][where_to_unl] = {}, {}
        for unl_epoch in unl_epochs_lst:
            our_model_dict[label][where_to_unl][unl_epoch] = [
                _load_classifier(get_our_model_save_path(ori_training_epochs=ori_training_epochs, unlearning_epochs=unl_epoch, unlearn_times_idx=unlearn_times_idx, where_to_unl=where_to_unl, label=label))
                for unlearn_times_idx in range(unlearn_times)
            ]
            BL2_model_dict[label][where_to_unl][unl_epoch] = [
                _load_classifier(get_BL2_model_save_path(ori_training_epochs=ori_training_epochs, BL2_unlearning_epochs=unl_epoch, unlearn_times_idx=unlearn_times_idx, where_to_unl=where_to_unl, label=label))
                for unlearn_times_idx in range(unlearn_times)
            ]

# %% [markdown]
# ### load accuracy and running time

# %%
import numpy as np

# our model
from config import get_our_model_acc_save_path, get_our_model_running_time_save_path
# baseline 2
from config import get_BL2_model_acc_save_path, get_BL2_model_running_time_save_path
# retrained model
from config import get_RT_model_acc_save_path, get_RT_model_running_time_save_path

# Every file is read with np.loadtxt; the evaluation cells below unpack each
# loaded array into an (avg, std) pair.
# Layout: RT dicts are [label][where_to_unl]; ours/BL2 dicts are [label][where_to_unl][unl_epoch].
our_acc_dict, our_running_time_dict = {}, {}
RT_acc_dict, RT_running_time_dict = {}, {}
BL2_acc_dict, BL2_running_time_dict = {}, {}

for label in label_lst:
    our_acc_by_region, our_time_by_region = {}, {}
    RT_acc_by_region, RT_time_by_region = {}, {}
    BL2_acc_by_region, BL2_time_by_region = {}, {}

    for where_to_unl in where_to_unl_lst:
        RT_kwargs = dict(RT_training_epochs=RT_training_epochs, where_to_unl=where_to_unl, label=label)
        RT_acc_by_region[where_to_unl] = np.loadtxt(get_RT_model_acc_save_path(**RT_kwargs))
        RT_time_by_region[where_to_unl] = np.loadtxt(get_RT_model_running_time_save_path(**RT_kwargs))

        our_acc_by_epoch, our_time_by_epoch = {}, {}
        BL2_acc_by_epoch, BL2_time_by_epoch = {}, {}
        for unl_epoch in unl_epochs_lst:
            our_kwargs = dict(ori_training_epochs=ori_training_epochs, unlearning_epochs=unl_epoch, where_to_unl=where_to_unl, label=label)
            BL2_kwargs = dict(ori_training_epochs=ori_training_epochs, BL2_unlearning_epochs=unl_epoch, where_to_unl=where_to_unl, label=label)
            our_acc_by_epoch[unl_epoch] = np.loadtxt(get_our_model_acc_save_path(**our_kwargs))
            our_time_by_epoch[unl_epoch] = np.loadtxt(get_our_model_running_time_save_path(**our_kwargs))
            BL2_acc_by_epoch[unl_epoch] = np.loadtxt(get_BL2_model_acc_save_path(**BL2_kwargs))
            BL2_time_by_epoch[unl_epoch] = np.loadtxt(get_BL2_model_running_time_save_path(**BL2_kwargs))

        our_acc_by_region[where_to_unl], our_time_by_region[where_to_unl] = our_acc_by_epoch, our_time_by_epoch
        BL2_acc_by_region[where_to_unl], BL2_time_by_region[where_to_unl] = BL2_acc_by_epoch, BL2_time_by_epoch

    our_acc_dict[label], our_running_time_dict[label] = our_acc_by_region, our_time_by_region
    RT_acc_dict[label], RT_running_time_dict[label] = RT_acc_by_region, RT_time_by_region
    BL2_acc_dict[label], BL2_running_time_dict[label] = BL2_acc_by_region, BL2_time_by_region
            



# %% [markdown]
# ### load saved predictions (y_pred)

# %%
from config import get_our_model_y_pred_save_path, get_BL2_model_y_pred_save_path

# Saved predictions of every unlearned model, laid out like the model dicts:
# [label][where_to_unl][unl_epoch] -> list with one loaded array per run index.
our_y_pred_dict, BL2_y_pred_dict = {}, {}
for label in label_lst:
    our_y_pred_dict[label], BL2_y_pred_dict[label] = {}, {}
    for where_to_unl in where_to_unl_lst:
        our_by_epoch, BL2_by_epoch = {}, {}
        for unl_epoch in unl_epochs_lst:
            our_by_epoch[unl_epoch] = [
                np.loadtxt(get_our_model_y_pred_save_path(ori_training_epochs=ori_training_epochs, unlearning_epochs=unl_epoch, unlearn_times_idx=run_idx, where_to_unl=where_to_unl, label=label))
                for run_idx in range(unlearn_times)
            ]
            BL2_by_epoch[unl_epoch] = [
                np.loadtxt(get_BL2_model_y_pred_save_path(ori_training_epochs=ori_training_epochs, BL2_unlearning_epochs=unl_epoch, unlearn_times_idx=run_idx, where_to_unl=where_to_unl, label=label))
                for run_idx in range(unlearn_times)
            ]
        our_y_pred_dict[label][where_to_unl] = our_by_epoch
        BL2_y_pred_dict[label][where_to_unl] = BL2_by_epoch


# %% [markdown]
# ### evaluation utility function

# %%
import pandas as pd 
import matplotlib.pyplot as plt 
import os 
from typing import Tuple 
from copy import deepcopy

def plot_evaluation(evaluation_metric: str, our_eval_res_lst: list, BL2_eval_res_lst: list, where_to_unl: str, label: str,
                    unl_epochs: list = None, save_dir: str = 'imgs', ei_skip_first: int = 2) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Plot one evaluation metric against unlearning epochs and save the figure as a PDF.

    Args:
        evaluation_metric: metric name, used as the y-axis label and in the output filename.
            'EI', 'TRI', 'RASI' get a reference line at 1; 'SRI', 'SDI' get one at 0.
            For 'EI' the first `ei_skip_first` epochs are left out of the plot and no
            std band is drawn.
        our_eval_res_lst: one [mean, std] pair per unlearning epoch for our method.
        BL2_eval_res_lst: one [mean, std] pair per unlearning epoch for baseline 2
            (shown in the legend and result tables under the name "BL1").
        where_to_unl: unlearned region; used only in the output filename.
        label: attribute label; used only in the output filename.
        unl_epochs: epoch values matching the result lists; defaults to the
            notebook-level `unl_epochs_lst`.
        save_dir: directory the PDF is written to (created if missing).
        ei_skip_first: number of leading epochs hidden from the EI plot.

    Returns:
        (df_means, df_stds): DataFrames indexed by model name ("Ours", "BL1") with one
        column per epoch (columns named "unl_epoch"). These are always the full tables,
        including any EI epochs hidden from the plot.
    """
    epochs = unl_epochs_lst if unl_epochs is None else unl_epochs
    epoch_labels = [str(epoch) for epoch in epochs]

    # rows = model, columns = unlearning epoch
    df_means = pd.DataFrame({
        "Ours": [x[0] for x in our_eval_res_lst],
        "BL1": [x[0] for x in BL2_eval_res_lst],
    }, index=epoch_labels).T
    df_stds = pd.DataFrame({
        "Ours": [x[1] for x in our_eval_res_lst],
        "BL1": [x[1] for x in BL2_eval_res_lst],
    }, index=epoch_labels).T
    df_means.columns.name = "unl_epoch"
    df_stds.columns.name = "unl_epoch"

    is_EI = evaluation_metric == 'EI'
    # EI hides its first few epochs from the plot only; the returned tables stay complete
    plot_means = df_means.iloc[:, ei_skip_first:] if is_EI else df_means
    plot_stds = df_stds.iloc[:, ei_skip_first:] if is_EI else df_stds

    # horizontal reference value drawn for each known metric
    reference_lines = {'EI': 1, 'TRI': 1, 'RASI': 1, 'SRI': 0, 'SDI': 0}

    # style must be applied before the figure is created
    plt.style.use('seaborn-v0_8-darkgrid')
    markers = ['o', 's', 'D', '^']
    linestyles = ['-', '--', '-.', ':']
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']  # Custom color palette

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        for model, marker, linestyle, color in zip(plot_means.index, markers, linestyles, colors):
            means = plot_means.loc[model]
            stds = plot_stds.loc[model]
            ax.plot(plot_means.columns, means, marker=marker, linestyle=linestyle, color=color,
                    linewidth=2.5, label=model)
            if not is_EI:
                # shaded mean ± std band
                ax.fill_between(plot_means.columns, means - stds, means + stds, color=color, alpha=0.2)

        if evaluation_metric in reference_lines:
            y_ref = reference_lines[evaluation_metric]
            ax.axhline(y=y_ref, label=f'{evaluation_metric}={y_ref}', linewidth=2)

        ax.set_xlabel("Unlearning Epochs", fontsize=14, fontweight='bold')
        ax.set_ylabel(evaluation_metric, fontsize=14, fontweight='bold')
        ax.grid(True, linestyle='--', linewidth=0.6, alpha=0.7)
        ax.legend(fontsize=14, loc='best', frameon=True, shadow=True)
        ax.set_xticks(plot_means.columns)
        ax.tick_params(axis='x', labelsize=10)
        ax.tick_params(axis='y', labelsize=12)
        fig.tight_layout()

        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, f'wheretounl{where_to_unl}_label{label}_{evaluation_metric}.pdf'),
                    dpi=300, bbox_inches='tight')
    finally:
        # always release the figure, even if plotting/saving fails, so loops don't accumulate figures
        plt.close(fig)

    return df_means, df_stds

# %% [markdown]
# ### evaluation 1: Test Retention Index (TRI)

# %%

import os
import pickle

import pandas as pd


# unlearning evaluation metrics
def cal_test_retention(UL_acc: float, RT_acc: float) -> float:
    """Ratio of the unlearned model's accuracy to the retrained (RT) model's accuracy.

    1.0 means the unlearned model matches the retrained model.
    """
    return UL_acc / RT_acc


# Labels evaluated only for one unlearned region; labels not listed (e.g. 'Male') use every region.
_LABEL_REGION = {'Big_Nose': 'nose', 'Pointy_Nose': 'nose', 'Eyeglasses': 'eye', 'Narrow_Eyes': 'eye'}

# TR_dict[label][where_to_unl] = [df_means, df_stds] as returned by plot_evaluation
TR_dict = {}
for label in label_lst:
    TR_dict[label] = {}
    for where_to_unl in where_to_unl_lst:
        if _LABEL_REGION.get(label, where_to_unl) != where_to_unl:
            continue

        # RT_acc_std is unused: RT accuracy is treated as a constant when scaling the std
        RT_acc_avg, RT_acc_std = RT_acc_dict[label][where_to_unl]
        our_TR_lst, BL2_TR_lst = [], []
        for unl_epoch in unl_epochs_lst:
            our_acc_avg, our_acc_std = our_acc_dict[label][where_to_unl][unl_epoch]
            BL2_acc_avg, BL2_acc_std = BL2_acc_dict[label][where_to_unl][unl_epoch]
            # [mean, std] — std scaled by the RT mean, matching the ratio's units
            our_TR_lst.append([cal_test_retention(UL_acc=our_acc_avg, RT_acc=RT_acc_avg), our_acc_std / RT_acc_avg])
            BL2_TR_lst.append([cal_test_retention(UL_acc=BL2_acc_avg, RT_acc=RT_acc_avg), BL2_acc_std / RT_acc_avg])

        df_means, df_stds = plot_evaluation(evaluation_metric='TRI', our_eval_res_lst=our_TR_lst, BL2_eval_res_lst=BL2_TR_lst, where_to_unl=where_to_unl, label=label)
        TR_dict[label][where_to_unl] = [df_means, df_stds]

# the output directory may not exist on a fresh checkout
os.makedirs('results_processed', exist_ok=True)
with open('results_processed/TRI.dict', 'wb') as f:
    pickle.dump(TR_dict, f)

# %% [markdown]
# ### evaluation 2: Efficiency Index (EI)

# %%
import os
import pickle


# test efficiency
def cal_efficiency(RT_training_time: float, model_training_time: float) -> float:
    """Retraining (RT) running time divided by a method's running time.

    Values above 1 mean the method ran faster than retraining.
    """
    return RT_training_time / model_training_time


def _efficiency_std(RT_training_time: float, model_training_time: float, model_time_std: float) -> float:
    """First-order std of RT_time / model_time, treating the RT time as a constant.

    d(R/t)/dt = -R/t**2, so std ≈ R * std_t / t**2. This keeps the reported std in the
    same (dimensionless) units as the EI mean, consistent with how TRI scales its std.
    """
    return RT_training_time * model_time_std / model_training_time ** 2


# Labels evaluated only for one unlearned region; labels not listed (e.g. 'Male') use every region.
_LABEL_REGION = {'Big_Nose': 'nose', 'Pointy_Nose': 'nose', 'Eyeglasses': 'eye', 'Narrow_Eyes': 'eye'}

# eff_dict[label][where_to_unl] = [df_means, df_stds] as returned by plot_evaluation
eff_dict = {}
for label in label_lst:
    eff_dict[label] = {}
    for where_to_unl in where_to_unl_lst:
        if _LABEL_REGION.get(label, where_to_unl) != where_to_unl:
            continue

        # RT_running_time_std is unused: RT time is treated as a constant in the std propagation
        RT_running_time_avg, RT_running_time_std = RT_running_time_dict[label][where_to_unl]
        our_EI_lst, BL2_EI_lst = [], []
        for unl_epoch in unl_epochs_lst:
            our_time_avg, our_time_std = our_running_time_dict[label][where_to_unl][unl_epoch]
            BL2_time_avg, BL2_time_std = BL2_running_time_dict[label][where_to_unl][unl_epoch]
            our_EI_lst.append([
                cal_efficiency(RT_training_time=RT_running_time_avg, model_training_time=our_time_avg),
                _efficiency_std(RT_running_time_avg, our_time_avg, our_time_std),
            ])
            BL2_EI_lst.append([
                cal_efficiency(RT_training_time=RT_running_time_avg, model_training_time=BL2_time_avg),
                _efficiency_std(RT_running_time_avg, BL2_time_avg, BL2_time_std),
            ])

        df_means, df_stds = plot_evaluation(evaluation_metric='EI', our_eval_res_lst=our_EI_lst, BL2_eval_res_lst=BL2_EI_lst, where_to_unl=where_to_unl, label=label)
        eff_dict[label][where_to_unl] = [df_means, df_stds]

# the output directory may not exist on a fresh checkout
os.makedirs('results_processed', exist_ok=True)
with open('results_processed/EI.dict', 'wb') as f:
    pickle.dump(eff_dict, f)

# %% [markdown]
# ### evaluation 3: Robustness Against Shuffling Index (RASI)

# %%
import os
import pickle

import numpy as np
from sklearn.metrics import accuracy_score
from utils import prep_data, evaluate_shuffle_classifier, evaluate_BL2_RASI_classifier


def cal_against_shuffle_retention(y_pred: list, y_pred_with_shuffle: list) -> float:
    """Agreement between a model's saved predictions and its predictions under shuffling.

    Computed as `accuracy_score` with the saved predictions in the role of ground truth,
    i.e. the fraction of samples whose predicted label is unchanged (1.0 = no change).

    Args:
        y_pred: the model's saved reference predictions.
        y_pred_with_shuffle: predictions returned by the shuffle-evaluation helper.
    """
    return accuracy_score(y_true=y_pred, y_pred=y_pred_with_shuffle)


# number of shuffled evaluations per model; their scores are averaged into one value
# NOTE(review): no seed is set here, so RASI may vary between runs — confirm whether
# the utils evaluation helpers seed their shuffling internally.
shuffle_times = 5

# Labels evaluated only for one unlearned region; labels not listed (e.g. 'Male') use every region.
_LABEL_REGION = {'Big_Nose': 'nose', 'Pointy_Nose': 'nose', 'Eyeglasses': 'eye', 'Narrow_Eyes': 'eye'}

# RASI_dict[label][where_to_unl] = [df_means, df_stds] as returned by plot_evaluation
RASI_dict = {}
for label in label_lst:
    RASI_dict[label] = {}
    for where_to_unl in where_to_unl_lst:
        if _LABEL_REGION.get(label, where_to_unl) != where_to_unl:
            continue

        # only the test split is used below
        train_attr_dict, test_attr_dict, train_area_dict, test_area_dict = prep_data(label=label, where_to_unl=where_to_unl, retrain_or_shuffle='shuffle')

        agg_our_model_RASI_lst, agg_BL2_model_RASI_lst = [], []
        for unl_epoch in unl_epochs_lst:
            our_model_URAS_lst, BL2_model_URAS_lst = [], []
            for unlearn_times_idx in range(unlearn_times):
                our_model_classifier = our_model_dict[label][where_to_unl][unl_epoch][unlearn_times_idx]
                BL2_model_classifier = BL2_model_dict[label][where_to_unl][unl_epoch][unlearn_times_idx]
                our_y_pred = our_y_pred_dict[label][where_to_unl][unl_epoch][unlearn_times_idx]
                BL2_y_pred = BL2_y_pred_dict[label][where_to_unl][unl_epoch][unlearn_times_idx]

                our_scores, BL2_scores = [], []
                for _ in range(shuffle_times):
                    _, our_y_pred_with_shuffle, _ = evaluate_shuffle_classifier(backbone=vit_backbone, classifier=our_model_classifier, test_attr_dict=test_attr_dict, test_area_dict=test_area_dict, where_to_unl=where_to_unl, device=device)
                    _, BL2_y_pred_with_shuffle, _ = evaluate_BL2_RASI_classifier(backbone=vit_backbone, classifier=BL2_model_classifier, test_attr_dict=test_attr_dict, test_area_dict=test_area_dict, where_to_unl=where_to_unl, device=device)
                    our_scores.append(cal_against_shuffle_retention(y_pred=our_y_pred, y_pred_with_shuffle=our_y_pred_with_shuffle))
                    BL2_scores.append(cal_against_shuffle_retention(y_pred=BL2_y_pred, y_pred_with_shuffle=BL2_y_pred_with_shuffle))

                # average over the `shuffle_times` shuffled evaluations of this single model
                our_model_URAS_lst.append(np.mean(our_scores))
                BL2_model_URAS_lst.append(np.mean(BL2_scores))

            # mean / std across the `unlearn_times` unlearned models for this epoch
            agg_our_model_RASI_lst.append([np.mean(our_model_URAS_lst), np.std(our_model_URAS_lst)])
            agg_BL2_model_RASI_lst.append([np.mean(BL2_model_URAS_lst), np.std(BL2_model_URAS_lst)])

        df_means, df_stds = plot_evaluation(evaluation_metric='RASI', our_eval_res_lst=agg_our_model_RASI_lst, BL2_eval_res_lst=agg_BL2_model_RASI_lst, where_to_unl=where_to_unl, label=label)
        RASI_dict[label][where_to_unl] = [df_means, df_stds]

# the output directory may not exist on a fresh checkout
os.makedirs('results_processed', exist_ok=True)
with open('results_processed/RASI.dict', 'wb') as f:
    pickle.dump(RASI_dict, f)

# %% [markdown]
# ### XAI via SHAP

# %% [markdown]
# ### evaluation 4 & 5: SHAP Retention Index (SRI) & SHAP Distance-to-zero Index (SDI)

# %%
# import numpy as np 
# import shap 
# import pandas as pd 
# from typing import Tuple
# from torch.utils.data import Dataset, DataLoader


# from utils import CelebADataset, BL2_CelebADataset, Shuffle_CelebADataset

# def model_wrapper_for_XAI(inputs: np.ndarray, device: torch.device, classifier: nn.Module, backbone: ViTBackBone) -> np.ndarray:
#     backbone.eval()
#     classifier.eval()
    
#     with torch.no_grad():
#         inputs = torch.FloatTensor(inputs).to(device)
#         features = backbone(inputs)
#         outputs = classifier(features).argmax(dim=-1)
    
#     return outputs.cpu().numpy()



# def cal_SHAP(where_to_unl, device, classifier, backbone, train_attr_dict, test_attr_dict, train_area_dict, test_area_dict, model_type):
#     # Create a model that returns the output we want to explain
#     def model_to_explain(x):
#         x_tensor = torch.FloatTensor(x).to(device)
#         with torch.no_grad():
#             features = backbone(x_tensor)
#             outputs = classifier(features)  # Return logits, not argmax
#         return outputs.cpu().numpy()
    
#     # Load data
#     if model_type == 'ours':
#         train_dataset = Shuffle_CelebADataset(attr_dict=train_attr_dict, area_dict=train_area_dict, where_to_unl=where_to_unl, device=device)
#         test_dataset = Shuffle_CelebADataset(attr_dict=test_attr_dict, area_dict=test_area_dict, where_to_unl=where_to_unl, device=device)
#     elif model_type == 'BL2':
#         train_dataset=  BL2_CelebADataset(attr_dict=train_attr_dict, where_to_unl=where_to_unl, device=device)
#         test_dataset = BL2_CelebADataset(attr_dict=test_attr_dict, where_to_unl=where_to_unl, device=device)
#     elif model_type == 'ori':
#         train_dataset = CelebADataset(attr_dict=train_attr_dict, device=device)
#         test_dataset = CelebADataset(attr_dict=test_attr_dict, device=device)
    
#     # Get background and test data
#     train_dataloader = DataLoader(dataset=train_dataset, batch_size=len(train_dataset), shuffle=False)
#     test_dataloader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=False)
    
#     background_arr = next(iter(train_dataloader))[0].cpu().numpy()
#     test_arr = next(iter(test_dataloader))[0].cpu().numpy()
    
#     # Create explainer
#     # Create a partition explainer with superpixels
#     explainer = shap.PartitionExplainer(model_to_explain, shap.maskers.Partition(background_arr), 
#                                         image_data=True, n_segments=50)
#     shap_values = explainer(test_arr)
#     # explainer = shap.GradientExplainer(model_to_explain, background_arr)
#     # shap_values = explainer.shap_values(test_arr)
    
#     # Process SHAP values - for image data, shap_values is a list of arrays
#     # Each array corresponds to a class, so we take the predicted class or average
#     shap_values = np.array(shap_values)
    
#     # Calculate feature importance for regions of interest
#     fea_imp_val = 0
#     for idx, (img_name, area) in enumerate(test_area_dict.items()):
#         if len(area) != 4:
#             top_n, bottom_n, left_n, right_n, top, bottom, left, right = area
#             # If multi-class, sum across all classes or take specific class
#             fea_imp_val += np.abs(shap_values[:, idx, :, top:bottom, left:right]).mean(axis=(0,1)).sum()
#             fea_imp_val += np.abs(shap_values[:, idx, :, top_n:bottom_n, left_n:right_n]).mean(axis=(0,1)).sum()
#         else:
#             top, bottom, left, right = area
#             fea_imp_val += np.abs(shap_values[:, idx, :, top:bottom, left:right]).mean(axis=(0,1)).sum()
    
#     fea_imp_val /= len(test_area_dict)
    
#     return fea_imp_val, explainer



# # # evaluate unlearned feature's feature importance via SHAP
# # def cal_SHAP(where_to_unl: str, device: torch.device, classifier: nn.Module, backbone: ViTBackBone, train_attr_dict: dict, test_attr_dict: dict, train_area_dict: dict, test_area_dict: dict, model_type: str) -> Tuple[pd.DataFrame, np.ndarray, shap.Explainer]:
# #     def model_wrapper_for_SHAP(X):
# #         return model_wrapper_for_XAI(inputs=X, device=device, classifier=classifier, backbone=backbone)
    
    
# #     if model_type == 'ours':
# #         train_dataset = Shuffle_CelebADataset(attr_dict=train_attr_dict, area_dict=train_area_dict, where_to_unl=where_to_unl, device=device)
# #         test_dataset = Shuffle_CelebADataset(attr_dict=test_attr_dict, area_dict=test_area_dict, where_to_unl=where_to_unl, device=device)
# #     elif model_type == 'BL2':
# #         train_dataset=  BL2_CelebADataset(attr_dict=train_attr_dict, where_to_unl=where_to_unl, device=device)
# #         test_dataset = BL2_CelebADataset(attr_dict=test_attr_dict, where_to_unl=where_to_unl, device=device)
# #     elif model_type == 'ori':
# #         train_dataset = CelebADataset(attr_dict=train_attr_dict, device=device)
# #         test_dataset = CelebADataset(attr_dict=test_attr_dict, device=device)
    
# #     train_dataloader = DataLoader(dataset=train_dataset, batch_size=len(train_dataset), shuffle=False)
# #     test_dataloader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=False)
# #     background_arr = []
# #     test_arr = []
# #     for inputs, _ in train_dataloader:
# #         background_arr.append(inputs.detach().cpu().numpy())
# #     for inputs, _ in test_dataloader:
# #         test_arr.append(inputs.detach().cpu().numpy())
# #     background_arr = np.concatenate(background_arr, axis=0)
# #     test_arr = np.concatenate(test_arr, axis=0)
    
# #     explainer = shap.Explainer(model_wrapper_for_SHAP, background_arr)
# #     shap_values = explainer(test_arr)
# #     shap_values = np.abs(shap_values.values)
# #     # Calculate mean absolute SHAP values for feature importance
# #     # shap_values_array = np.abs(shap_values.values).mean(axis=0)
    
    
# #     fea_imp_val = 0
# #     for idx, (img_name, area) in enumerate(test_area_dict.items()):
# #         if len(area) != 4:
# #             top_n, bottom_n, left_n, right_n, top, bottom, left, right = area
# #             fea_imp_val += shap_values[idx, :, top:bottom, left:right].mean(axis=0).sum()
# #             fea_imp_val += shap_values[idx, :, top_n:bottom_n, left_n:right_n].mean(axis=0).sum()
# #         else:
# #             top, bottom, left, right = area
# #             fea_imp_val += shap_values[idx, :, top:bottom, left:right].mean(axis=0).sum()
# #     fea_imp_val /= len(test_area_dict)
    
    
# #     return fea_imp_val, explainer


# from utils import prep_data

# XAI_background_data_size = 1000
# XAI_test_data_size = 100
# SRI_dict, SDI_dict, XAI_df_dict = {}, {}, {}
# for label in label_lst:
#     SRI_dict[label], SDI_dict[label], XAI_df_dict[label] = {}, {}, {}
#     for where_to_unl in where_to_unl_lst:
        
#         if label == 'Big_Nose' and where_to_unl != 'nose': continue 
#         if label == 'Eyeglasses' and where_to_unl != 'eye': continue
#         if label == 'Narrow_Eyes' and where_to_unl != 'eye': continue
#         if label == 'Pointy_Nose' and where_to_unl != 'nose': continue
        
#         train_attr_dict, test_attr_dict, train_area_dict, test_area_dict = prep_data(label=label, where_to_unl=where_to_unl, retrain_or_shuffle='shuffle')
#         train_keys = list(train_attr_dict.keys())[:XAI_background_data_size]
#         test_keys = list(test_attr_dict.keys())[:XAI_test_data_size]
#         new_train_attr_dict = {k: train_attr_dict[k] for k in train_keys}
#         new_train_area_dict = {k: train_area_dict[k] for k in train_keys}
#         new_test_attr_dict = {k: test_attr_dict[k] for k in test_keys}
#         new_test_area_dict = {k: test_area_dict[k] for k in test_keys}
        
#         shap_before_unl, ori_shap_exper = cal_SHAP(where_to_unl=where_to_unl, device=device, classifier=ori_model_dict[label], backbone=vit_backbone, train_attr_dict=new_train_attr_dict, test_attr_dict=new_test_attr_dict, train_area_dict=new_train_area_dict, test_area_dict=new_test_area_dict, model_type='ori')
        
#         our_model_SRI_lst, BL2_model_SRI_lst = [], []
#         our_model_SDI_lst, BL2_model_SDI_lst = [], []
#         for unl_epoch in unl_epochs_lst:
#             avg_our_shap_lst, avg_BL2_shap_lst = [], []
#             for unlearn_times_idx in range(unlearn_times):
#                 our_model_classifier = our_model_dict[label][where_to_unl][unl_epoch][unlearn_times_idx]
#                 BL2_model_classifier = BL2_model_dict[label][where_to_unl][unl_epoch][unlearn_times_idx]
#                 our_shap_vals, our_shap_exper = cal_SHAP(where_to_unl=where_to_unl, device=device, classifier=our_model_classifier, backbone=vit_backbone, train_attr_dict=new_train_attr_dict, test_attr_dict=new_test_attr_dict, train_area_dict=new_train_area_dict, test_area_dict=new_test_area_dict, model_type='ours')
#                 BL2_shap_vals, BL2_shap_exper = cal_SHAP(where_to_unl=where_to_unl, device=device, classifier=BL2_model_classifier, backbone=vit_backbone, train_attr_dict=new_train_attr_dict, test_attr_dict=new_test_attr_dict, train_area_dict=new_train_area_dict, test_area_dict=new_test_area_dict, model_type='BL2')
#                 avg_our_shap_lst.append(our_shap_vals)
#                 avg_BL2_shap_lst.append(BL2_shap_vals)
                
#             our_shap_mean, our_shap_std = np.mean(avg_our_shap_lst), np.std(avg_our_shap_lst)
#             BL2_shap_mean, BL2_shap_std = np.mean(avg_BL2_shap_lst), np.std(avg_BL2_shap_lst)
#             our_model_SRI_lst.append([our_shap_mean/shap_before_unl, our_shap_std/shap_before_unl])
#             BL2_model_SRI_lst.append([BL2_shap_mean/shap_before_unl, BL2_shap_std/shap_before_unl])
#             our_model_SDI_lst.append([our_shap_mean, our_shap_std])
#             BL2_model_SDI_lst.append([BL2_shap_mean, BL2_shap_std])
            
#         SRI_df_means, SRI_df_stds = plot_evaluation(evaluation_metric='SRI', our_eval_res_lst=our_model_SRI_lst, BL2_eval_res_lst=BL2_model_SRI_lst, where_to_unl=where_to_unl, label=label)
#         SDI_df_means, SDI_df_stds = plot_evaluation(evaluation_metric='SDI', our_eval_res_lst=our_model_SDI_lst, BL2_eval_res_lst=BL2_model_SDI_lst, where_to_unl=where_to_unl, label=label)
#         SRI_dict[label][where_to_unl] = [SRI_df_means, SRI_df_stds]
#         SDI_dict[label][where_to_unl] = [SDI_df_means, SDI_df_stds]



# %% [markdown]
# 


