# %%
import lime
import os
from lime import lime_presampled
import pandas as pd

from tqdm.auto import tqdm
import torch
import numpy as np
from lime.lime_presampled import LimePreSampledExplainer
from lime.shap_presampled import SHAPPresampledExplainer

import spacy
import argparse
def shap_kernel(data):
    """Return one regression weight per row of a binary mask matrix.

    A row whose entries are all non-zero, or all zero, gets the weight
    ``1e6``; every other row gets the weight ``1``.

    The result is always ``float64``. The previous implementation used
    ``np.apply_along_axis``, whose output dtype follows the first row's
    result and which rejects inputs with zero rows.

    Args:
        data: 2-D array-like of masks, one perturbation per row.

    Returns:
        numpy.ndarray: float64 weights, one per row of ``data``.
    """
    data = np.asarray(data)
    all_on = np.all(data, axis=1)
    all_off = np.all(data == 0, axis=1)
    return np.where(all_on | all_off, 1e6, 1.0)

# Explanation-method tag. It is interpolated into the input and output
# directory names used below ('{exp_method}_samples', '{exp_method}_res').
exp_method = 'shap'
def get_args():
    """Parse the command-line options of this script.

    ``--subjects`` accepts one or more subject names. When the option is
    omitted, the entries of the ``samples_pools`` directory are used
    instead. The directory is listed only in that case, so passing
    ``--subjects`` explicitly works even when ``samples_pools`` is absent.

    Unrecognized arguments are ignored (``parse_known_args``).

    Returns:
        argparse.Namespace: namespace whose ``subjects`` attribute is a
        list of subject names.

    Raises:
        FileNotFoundError: if ``--subjects`` is omitted and
            ``samples_pools`` does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--subjects', nargs='+', default=None)
    args, _ = parser.parse_known_args()
    if args.subjects is None:
        # Resolve the default lazily; listing the directory eagerly would
        # fail before the command line is even looked at.
        args.subjects = os.listdir('samples_pools')
    return args

# Parse the command line once; only the subject list is read from it.
args = get_args()
# exp_method = args.exp_method
subjects = args.subjects


# NOTE(review): `nlp` is not referenced anywhere else in the visible part of
# this file -- confirm the spaCy model still needs to be loaded here.
nlp = spacy.load('en_core_web_lg')


# model_name = 'LLM-Research/Meta-Llama-3.1-8B-Instruc'
# model_short_name = 'llama_8B'
# samples_df = pd.read_csv(f'./SST2/sst_test_samples_{model_short_name}.csv', sep='\t',index_col=0,keep_default_na=False)

# subjects = os.listdir('samples_pools')

# subjects = ['high_school_world_history']
# %%

# For every subject and every perturbation-sample file, fit one explanation
# per question and write the results to
# ./{exp_method}_res/{subject}/{model_short_name}.csv.
# An existing result file is loaded first and extended, so a rerun continues
# at the first question index that has no stored result yet.
LABELS = (0, 1, 2, 3)
LOGIT_COLUMNS = ['logits_A', 'logits_B', 'logits_C', 'logits_D']

for subject in tqdm(sorted(subjects)):
    sample_dir = os.path.join(f'{exp_method}_samples/{subject}')
    res_dir = f'./{exp_method}_res/{subject}'
    os.makedirs(res_dir, exist_ok=True)
    for file in tqdm(os.listdir(sample_dir)):
        if '0001' in file:
            continue
        print(file)
        # NOTE(review): as written, a file is skipped only when it neither
        # starts with the perturb prefix nor contains 'backup'/'counted'.
        # Confirm whether backup/counted files are really meant to be
        # processed; the condition is kept unchanged here.
        if not file.startswith(f'{subject}_perturb_') and 'backup' not in file and 'counted' not in file:
            continue

        # The model name is what follows the last "perturb" (plus one
        # separator character) in the file name, extension removed.
        model_short_name = os.path.splitext(file)[0][file.rfind('perturb')+len('perturb')+1:]
        res_path = f'{res_dir}/{model_short_name}.csv'

        if os.path.exists(res_path):
            previous = pd.read_csv(res_path, sep=',', index_col=0, keep_default_na=False)
            # SECURITY NOTE: eval() executes whatever text is stored in the
            # 'weight' column, so result files must come from a trusted
            # source. It is kept (inside a module-level lambda, so names
            # such as `np` resolve) because the stored text is the str() of
            # nested lists and may not be parseable by ast.literal_eval.
            previous['weight'] = previous['weight'].apply(lambda x: eval(x))
            records = previous.to_dict(orient='records')
        else:
            records = []

        samples_df = pd.read_csv(
            os.path.join(sample_dir, file),
            sep='\t',
            index_col=None,
            keep_default_na=False,
            # Keep the masks as text. Without this, an all-digit column can
            # be parsed as integers, which drops leading zeros and breaks
            # the per-character parsing below.
            dtype={'binary_representation': str},
        )
        if samples_df.empty:
            print(f"{subject} {model_short_name} has no samples")
            continue
        max_idx = max(samples_df['question_index'])

        # records[i] belongs to question_index i, so processing continues at
        # len(records).
        for idx in range(len(records), max_idx + 1):
            local_df = samples_df[samples_df['question_index'] == idx].reset_index(drop=True)
            if local_df.empty:
                # Skipping would shift every later record against its
                # question index, so stop with a clear message instead.
                raise ValueError(
                    f"{subject} {model_short_name}: no samples for question_index {idx}")

            text = local_df.loc[0, 'question']
            logits = local_df[LOGIT_COLUMNS].values

            # A logit at or below -1000 in the first row is treated as the
            # marker of a question whose samples are not complete yet.
            if min(logits[0]) <= -1000.0:
                print(f"{subject} {model_short_name} {idx} has not been fully processed")
                break

            probs = torch.softmax(torch.tensor(logits), dim=1).numpy()
            # One row per perturbation, one 0/1 column per mask character.
            data = np.array(
                [[int(ch) for ch in mask] for mask in local_df['binary_representation'].values]
            ).astype(int)

            # Move the first all-ones row (and its probabilities) to
            # position 0; probs[0] is used as the model output below.
            full_rows = np.flatnonzero(data.all(axis=1))
            if full_rows.size == 0:
                print(f'Error: {idx}')
                raise ValueError("Each data array should have a all 1 row")
            first_full = full_rows[0]
            if first_full != 0:
                data[[0, first_full]] = data[[first_full, 0]]
                probs[[0, first_full]] = probs[[first_full, 0]]

            # Validate before fitting so bad input fails fast.
            if not (data == 0).all(axis=1).any():
                raise ValueError("Each data array should have a all 0 row")

            explainer = SHAPPresampledExplainer(class_names=['A', 'B', 'C', 'D'])
            if len(data) == 1 << data.shape[-1]:
                # Row count equals 2**n_features: no custom kernel is passed.
                explanation = explainer.explain_instance(data, probs, labels=LABELS)
            else:
                explanation = explainer.explain_instance(
                    data, probs, shap_kernel=shap_kernel, labels=LABELS)

            # Per label: the second element of each as_list() entry, with
            # entries ordered by their first element.
            weight = [
                [item[1] for item in sorted(explanation.as_list(label), key=lambda item: item[0])]
                for label in LABELS
            ]
            intercept = [explanation.intercept[label] for label in LABELS]
            records.append({
                "sentence": text,
                "output": probs[0].argmax(),
                "weight": weight,
                "intercept": intercept,
            })

        if not records:
            continue
        exp_df = pd.DataFrame(records)
        exp_df.to_csv(res_path)
        # print(f'Negative inputs: {exp_df["output"].sum()}')
        # print(f'Negative inputs: {exp_df["output"].sum()}')



# # %%

# %%
