# %%
import lime
import os
# from lime import lime_presampled
import pandas as pd

from tqdm.auto import tqdm
import torch
import numpy as np
from lime.shap_presampled import SHAPPresampledExplainer

# import spacy
# nlp = spacy.load('en_core_web_lg')



def shap_kernel(data):
    """Return one sample weight per row of a binary mask matrix.

    A row whose entries are all non-zero, or all zero, gets the weight 1e6;
    every other row gets the weight 1.
    NOTE(review): the large weight presumably stands in for the infinite
    Shapley-kernel weight of the full and empty coalitions -- confirm against
    the explainer that consumes these weights.

    Args:
        data: array-like with at least two dimensions; axis 1 is the axis
            that is reduced (one weight per row for a 2-D input).

    Returns:
        A float64 ``numpy.ndarray`` of weights. The dtype is always float,
        independent of which row happens to come first, and an input with
        zero rows yields an empty array instead of raising.
    """
    data = np.asarray(data)
    all_on = np.all(data, axis=1)
    all_off = np.all(data == 0, axis=1)
    # Vectorised selection keeps the result dtype fixed at float64.
    return np.where(all_on | all_off, 1e6, 1.0)
# model_name = 'LLM-Research/Meta-Llama-3.1-8B-Instruc'
# model_short_name = 'llama_8B'
# samples_df = pd.read_csv(f'./SST2/sst_test_samples_{model_short_name}.csv', sep='\t',index_col=0,keep_default_na=False)

# %%
# Build one explanation table per pre-sampled file found in ./shap_samples and
# write it to ./shap_res/sst_test_<model>_exp.csv.
# sorted() makes the processing order reproducible; os.listdir order is arbitrary.
for file in sorted(os.listdir('shap_samples')):
    # Only raw sample files are processed; backup and "counted" variants are skipped.
    if (not file.startswith('sst_test_samples_')) or 'backup' in file or 'counted' in file:
        continue
    print(f"Filename: {file}")
    # Everything after "samples_" in the file stem is used as the model's short name.
    model_short_name = os.path.splitext(file)[0][file.rfind('samples')+len('samples')+1:]
    print(model_short_name)

    exp_df = []
    # if os.path.exists(f'./shap_res/sst_test_{model_short_name}_exp.csv'):
    #     exp_df = pd.read_csv(f'./shap_res/sst_test_{model_short_name}_exp.csv', sep=',',index_col=0,keep_default_na=False)
    #     exp_df['weight'] = exp_df['weight'].apply(lambda x: eval(x))
    #     exp_df = exp_df.to_dict(orient='records')
    # else:
    #     exp_df = []
    samples_df = pd.read_csv(
        f'./shap_samples/{file}',
        sep='\t',
        index_col=0,
        keep_default_na=False,
        # Keep the bit strings as text: a digit-only column would otherwise be
        # inferred as numeric, dropping leading zeros and breaking the
        # per-character parsing below.
        dtype={'binary_representation': str},
    )
    if samples_df['logits_positive'].min() <= -1000:
        print(f"File {file} has logits <= -1000, be careful")
        # continue
    max_idx = int(samples_df['sentence_index'].max())
    # Starting at len(exp_df) lets a pre-filled exp_df (see the commented-out
    # block above) continue where it stopped.
    pbar = tqdm(range(len(exp_df), max_idx + 1), dynamic_ncols=True)
    # Probabilities of the row labelled 0 in the file; they are used below as
    # the model output for the all-zero mask whenever a sentence has no such row.
    # NOTE(review): presumably row 0 is a baseline/empty input -- confirm.
    base_logits = samples_df.loc[0, ['logits_positive', 'logits_negative']].values.astype(float)
    base_probs = torch.softmax(torch.tensor(base_logits), dim=0).numpy()

    for idx in pbar:
        # Rows belonging to sentence `idx`, re-indexed from 0 (no in-place
        # mutation of a filtered slice).
        local_df = samples_df[samples_df['sentence_index'] == idx].reset_index(drop=True)

        text = local_df.loc[0, 'sentence']
        # Cast to float, consistent with how base_logits is read above.
        logits = local_df[['logits_positive', 'logits_negative']].values.astype(float)

        # A logit of -1000 or below stops processing of this file; explanations
        # collected so far are still written after the loop.
        # NOTE(review): presumably a sentinel for missing model output -- confirm.
        if np.min(logits) <= -1000.0:
            break

        probs = torch.softmax(torch.tensor(logits), dim=1).numpy()
        # Each bit string becomes one row of 0/1 integers.
        data = np.array(
            [[int(bit) for bit in bits] for bits in local_df['binary_representation'].values],
            dtype=int,
        )

        # Move the first all-ones mask (and its probabilities) to position 0.
        full_rows = np.flatnonzero(data.all(axis=1))
        if full_rows.size == 0:
            print(f'Error: {idx}')
            raise ValueError("Each data array should have a all 1 row")
        first_full = full_rows[0]
        if first_full != 0:
            # Fancy indexing copies the right-hand side, so this is a safe swap.
            data[[0, first_full]] = data[[first_full, 0]]
            probs[[0, first_full]] = probs[[first_full, 0]]

        # Append an all-zero mask with the base probabilities when none was sampled.
        if not np.any(np.all(data == 0, axis=1)):
            data = np.concatenate((data, np.zeros((1, data.shape[1]), dtype=int)), axis=0)
            probs = np.concatenate((probs, base_probs.reshape(1, -1)), axis=0)

        explainer = SHAPPresampledExplainer(class_names=['positive', 'negative'])

        # With exactly 2**n_features rows the explainer is called without a
        # custom kernel; otherwise shap_kernel supplies the sample weights.
        # Only the row count is compared; uniqueness of the masks is not checked.
        if len(data) == 1 << data.shape[-1]:
            explanation = explainer.explain_instance(data, probs)
        else:
            explanation = explainer.explain_instance(data, probs, shap_kernel=shap_kernel)

        # Label-0 weights ordered by the first element of each (feature, weight) pair.
        weight = [x[1] for x in sorted(explanation.as_list(0), key=lambda x: x[0])]
        intercept = explanation.intercept[0]
        exp_df.append({
            "sentence": text,
            # Argmax over [positive, negative] for the all-ones mask: 1 means negative.
            "output": probs[0].argmax(),
            "weight": weight,
            "intercept": intercept,
        })
    if not exp_df:
        continue
    exp_df = pd.DataFrame(exp_df)
    # Create the output directory on first use so to_csv cannot fail on a missing folder.
    os.makedirs('./shap_res', exist_ok=True)
    exp_df.to_csv(f'./shap_res/sst_test_{model_short_name}_exp.csv')
    print(f'Negative inputs: {exp_df["output"].sum()}')



# %%
