import os
import os
data_path = "mmlu/data"
# A subject exists for every file in <data_path>/test named "<subject>_test.csv";
# keep the part of the file name before "_test.csv", alphabetically ordered.
# NOTE: `subjects` is reassigned further down to a hand-picked subset.
subjects = sorted(
    fname.partition("_test.csv")[0]
    for fname in os.listdir(os.path.join(data_path, "test"))
    if "_test.csv" in fname
)

import yaml
import json
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import spacy
data_path = "mmlu/data"
# Large English spaCy pipeline; perturb() only reads each token's text and
# trailing whitespace (plus the token count) from its output.
nlp = spacy.load('en_core_web_lg')
def perturb(text, *, max_exhaustive_tokens=10, n_draws=1000, seed=42, nlp_pipeline=None):
    """Build token-masked variants of ``text`` for a KernelSHAP-style explanation.

    The text is tokenized (with the module-level spaCy ``nlp`` unless
    ``nlp_pipeline`` is given). Each variant keeps a subset of the tokens,
    concatenated with their original trailing whitespace and then stripped.

    * If the text has at most ``max_exhaustive_tokens`` tokens (or fewer than
      two, which the sampler cannot handle), every one of the ``2**n`` subsets
      is returned, ordered by the subset's integer mask.
    * Otherwise the all-tokens and no-tokens variants come first, followed by
      ``n_draws`` random draws. Each draw picks a subset size ``k`` in
      ``[1, n-1]`` with probability proportional to ``(n-1) / (k * (n-k))``
      (the KernelSHAP subset-size weighting). It then keeps ``k`` distinct
      token positions chosen uniformly. Duplicate masks are dropped, so fewer
      than ``n_draws + 2`` samples may be returned.

    Args:
        text: Input string to perturb.
        max_exhaustive_tokens: Largest token count enumerated exhaustively.
        n_draws: Number of random subset draws in the sampling regime.
        seed: Seed for a private ``np.random.RandomState``. The draws are the
            same as the legacy global generator gives after
            ``np.random.seed(seed)``, but the global NumPy RNG is not touched.
        nlp_pipeline: Callable returning a sequence of tokens exposing
            ``.text`` and ``.whitespace_``; defaults to the module-level ``nlp``.

    Returns:
        List of ``(mask, sample)`` tuples. ``mask`` is a '0'/'1' string whose
        character ``j`` says whether token ``j`` is kept; ``sample`` is the
        resulting text.
    """
    tokenize = nlp if nlp_pipeline is None else nlp_pipeline
    doc = tokenize(text)
    # Surface form of each token followed by the whitespace that trailed it.
    pieces = [token.text + token.whitespace_ for token in doc]
    n_tokens = len(pieces)

    def render(mask):
        """Join the pieces whose mask character is '1' and strip the ends."""
        return ''.join(piece for piece, bit in zip(pieces, mask) if bit == '1').strip()

    samples = []

    # The sampler draws sizes from 1..n-1, which is empty/degenerate for n < 2,
    # so tiny inputs are always enumerated regardless of the threshold.
    if n_tokens <= max_exhaustive_tokens or n_tokens < 2:
        for subset in range(1 << n_tokens):
            # binary_repr is MSB-first; reverse it so character j <-> token j.
            mask = np.binary_repr(subset, width=n_tokens)[::-1]
            samples.append((mask, render(mask)))
        return samples

    # Private generator: same stream as np.random.seed(seed) + np.random.*,
    # without clobbering the process-wide RNG state.
    rng = np.random.RandomState(seed)

    sizes = np.arange(n_tokens, dtype=float)
    # Weight for size k is (n-1) / (k * (n-k)); index 0 (k = 0) is sliced off
    # before dividing because its denominator is zero.
    weights = (n_tokens - 1) / (sizes * (n_tokens - sizes))[1:]
    size_probs = weights / np.sum(weights)  # loop-invariant, computed once

    seen = set()  # O(1) duplicate check (was a list scan per draw)

    def add(mask_bits):
        """Record a 0/1 int mask as a sample unless it was already produced."""
        mask = ''.join(str(bit) for bit in mask_bits)
        if mask in seen:
            return
        seen.add(mask)
        samples.append((mask, render(mask)))

    add(np.ones(n_tokens, dtype=int))   # every token kept
    add(np.zeros(n_tokens, dtype=int))  # no token kept
    for _ in range(n_draws):
        k = rng.choice(n_tokens - 1, p=size_probs) + 1
        bits = np.zeros(n_tokens, dtype=int)
        bits[rng.choice(n_tokens, k, replace=False)] = 1
        add(bits)
    return samples
    
import os
from tqdm.auto import tqdm
import numpy as np
subjects = [
    "high_school_world_history",
    "high_school_computer_science",
    "high_school_chemistry",
    "high_school_microeconomics",
    "high_school_psychology",
]
# Column order of every *_perturb.csv file. Passed explicitly so that a subject
# with no questions still produces a header row instead of an empty file.
perturb_columns = ["question_index", "question", "binary_representation", "sample_question"]

# For each subject: perturb every test question and write one row per variant
# to shap_samples/<subject>/<subject>_perturb.csv (tab-separated despite the
# .csv extension; kept as-is for downstream readers).
for subject in subjects:
    results_path = os.path.join("shap_samples", subject)
    os.makedirs(results_path, exist_ok=True)
    # Passing names= makes pandas read the first line as data, i.e. the files
    # are assumed to have no header row.
    test_df = pd.read_csv(
        os.path.join(data_path, "test", subject + "_test.csv"),
        names=("Question", "A", "B", "C", "D", "Answer"),
    )
    rows = []
    # Iterate the column directly instead of re-indexing the frame per row.
    for question_index, question in enumerate(tqdm(test_df["Question"], desc=subject)):
        for binary_representation, sample_question in perturb(question):
            rows.append({
                "question_index": question_index,
                "question": question,
                "binary_representation": binary_representation,
                "sample_question": sample_question,
            })
    samples = pd.DataFrame(rows, columns=perturb_columns)
    samples.to_csv(os.path.join(results_path, f"{subject}_perturb.csv"), index=False, sep='\t')
