# %%
import lime
import os
from lime import lime_presampled
import pandas as pd

from tqdm.auto import tqdm
import torch
import numpy as np
from lime.lime_presampled import LimePreSampledExplainer

import spacy
# Load spaCy's large English pipeline once, at import time.
# NOTE(review): `nlp` is not referenced anywhere in the visible part of this
# file -- confirm it is still needed before paying the model-load cost.
nlp = spacy.load('en_core_web_lg')


# model_name = 'LLM-Research/Meta-Llama-3.1-8B-Instruct'
# model_short_name = 'llama_8B'
# samples_df = pd.read_csv(f'./SST2/sst_test_samples_{model_short_name}.csv', sep='\t',index_col=0,keep_default_na=False)



# %%
# Fit one LIME surrogate per sentence for every pre-sampled SST-2 file found in
# ``lime_samples`` and write the per-feature weights to ``lime_res``.
#
# Each sample file is a tab-separated table with (at least) the columns
# ``sentence_index``, ``sentence``, ``logits_positive``, ``logits_negative`` and
# ``binary_representation`` (a string of 0/1 characters, one per feature).
# Results are resumable: if a result CSV already exists, processing continues
# from the sentence index equal to the number of rows already stored.
for file in os.listdir('lime_samples'):
    # Only primary sample files are processed: the name must carry the expected
    # prefix and must not be a "backup" or "counted" variant.  (The previous
    # `and` chain only skipped files that lacked the prefix *and* both markers,
    # so backup/counted files were processed as well.)
    if not file.startswith('sst_test_samples_') or 'backup' in file or 'counted' in file:
        continue
    print(f"Filename: {file}")
    model_short_name = os.path.splitext(file)[0][file.rfind('samples')+len('samples')+1:]
    print(model_short_name)

    result_path = f'./lime_res/sst_test_{model_short_name}_exp.csv'

    # Resume from a previous run if a result file is already present.
    if os.path.exists(result_path):
        previous = pd.read_csv(result_path, sep=',', index_col=0, keep_default_na=False)
        # NOTE(security): eval() executes whatever text is stored in the CSV.
        # This is only acceptable because the file is produced by this script;
        # never point it at a result file from an untrusted source.
        previous['weight'] = previous['weight'].apply(lambda x: eval(x))
        exp_records = previous.to_dict(orient='records')
    else:
        exp_records = []

    # Force the bit strings to stay strings: without an explicit dtype pandas
    # may infer an integer column, which drops leading zeros and makes the
    # per-character parsing below fail.
    samples_df = pd.read_csv(
        f'./lime_samples/{file}',
        sep='\t',
        index_col=0,
        keep_default_na=False,
        dtype={'binary_representation': str},
    )
    if samples_df.empty:
        # Nothing to explain; max() below would fail on an empty column.
        continue
    max_idx = max(samples_df['sentence_index'])

    # Start at len(exp_records) so already-explained sentences are skipped.
    pbar = tqdm(range(len(exp_records), max_idx+1), dynamic_ncols=True)

    for idx in pbar:
        # Assign the re-indexed frame instead of mutating a filtered slice
        # in place (avoids pandas' chained-assignment warning).
        local_df = samples_df[samples_df['sentence_index'] == idx].reset_index(drop=True)
        if local_df.empty:
            # Skipping would desynchronise the resume logic, which equates the
            # number of stored rows with the next sentence index.
            raise ValueError(f"No samples found for sentence_index {idx} in {file}")

        text = local_df.loc[0, 'sentence']
        logits = local_df[['logits_positive', 'logits_negative']].values

        # NOTE(review): logits <= -1000 look like a placeholder for sentences
        # that have not been scored yet -- confirm with the sampling script.
        # Processing of this file stops here; results so far are still saved.
        if min(logits[0]) <= -1000.0:
            break
        probs = torch.softmax(torch.tensor(logits), dim=1).numpy()

        # Parse each bit string into a row of 0/1 integers.
        data = np.array(
            [[int(ch) for ch in bits] for bits in local_df['binary_representation'].values],
            dtype=int,
        )

        # Row 0 must be the all-ones row: move the first such row (and its
        # probabilities) to the front.
        full_rows = np.flatnonzero(data.all(axis=1))
        if full_rows.size == 0:
            print(f'Error: {idx}')
            raise ValueError("Each data array should have a all 1 row")
        first_full = full_rows[0]
        if first_full != 0:
            # Fancy indexing on the right-hand side copies, so the swap is safe.
            data[[0, first_full]] = data[[first_full, 0]]
            probs[[0, first_full]] = probs[[first_full, 0]]

        explainer = LimePreSampledExplainer(class_names=['positive', 'negative'])
        explanation = explainer.explain_instance(data, probs, num_features=len(data[0]), labels=(0,))
        # Order the weights by feature key so position i corresponds to feature i.
        # NOTE(review): if as_list() returns string keys, this sort is
        # lexicographic ('10' < '2') -- confirm the key type.
        weight = [x[1] for x in sorted(explanation.as_list(0), key=lambda x: x[0])]
        intercept = explanation.intercept[0]
        exp_records.append({
            "sentence": text,
            # Index of the most probable class for the all-ones row
            # (0 = positive, 1 = negative, matching the logit column order).
            "output": int(probs[0].argmax()),
            "weight": weight,
            "intercept": intercept,
        })
    if not exp_records:
        continue
    exp_df = pd.DataFrame(exp_records)
    # to_csv does not create missing directories.
    os.makedirs('./lime_res', exist_ok=True)
    exp_df.to_csv(result_path)
    # "output" is 1 for the negative class, so the sum counts negative predictions.
    print(f'Negative inputs: {exp_df["output"].sum()}')



# %%
