# Imports grouped at the top of the script (previously interleaved with
# statements).
import numpy as np
import pandas as pd
import spacy
from tqdm.auto import tqdm

# Input data: a tab-separated file; the code below reads its 'sentence'
# column, one sentence per row.
sst_test = pd.read_csv('SST2/sst_test.csv', sep='\t')

# spaCy pipeline used below to split each sentence into tokens.  Loading the
# model is slow, so it is done exactly once here.
nlp = spacy.load('en_core_web_lg')

# NOTE: the former placeholder `samples` DataFrame, the first `ssamples = []`
# and the bare `len(sst_test)` expression were removed from this section:
# `ssamples` and `samples` are both (re)assigned further down before they are
# read, and the bare expression had no effect outside an interactive session.

# ---------------------------------------------------------------------------
# Build masked variants ("samples") of every sentence in `sst_test`.
#
# Each sentence is tokenized and a set of binary masks over its tokens is
# produced.  For every mask, the tokens whose bit is '1' are concatenated
# (keeping their original trailing whitespace) into a `sample_sentence`.
#   * Sentences with at most MAX_EXHAUSTIVE_TOKENS tokens: every one of the
#     2**n masks is enumerated.
#   * Longer sentences: the all-ones mask, the all-zeros mask and
#     NUM_RANDOM_MASKS randomly drawn masks (duplicates are possible).
# All rows are collected in `ssamples`, turned into the `samples` DataFrame
# and written to a TSV and a JSON file.
# ---------------------------------------------------------------------------
import os

MAX_EXHAUSTIVE_TOKENS = 10   # enumerate all masks up to this token count
NUM_RANDOM_MASKS = 1000      # random masks drawn per longer sentence
MASK_SEED = 42               # seed applied before each sentence
SAMPLES_CSV_PATH = 'shap_samples/sst_test_samples.csv'
SAMPLES_JSON_PATH = 'shap_samples/sst_test_samples.json'
SAMPLE_COLUMNS = ['sentence_index', 'sentence', 'binary_representation', 'sample_sentence']


def _apply_mask(tokens, mask):
    """Return the text of the tokens selected by *mask*.

    Args:
        tokens: sequence of spaCy tokens (anything exposing ``.text`` and
            ``.whitespace_``).
        mask: string of '0'/'1' characters; character ``j`` decides whether
            ``tokens[j]`` is kept.

    Returns:
        The kept tokens joined with their own trailing whitespace, with
        leading/trailing whitespace stripped.  Empty string if nothing is kept.
    """
    kept = ''.join(
        tok.text + tok.whitespace_
        for tok, bit in zip(tokens, mask)
        if bit == '1'
    )
    return kept.strip()


def _exhaustive_masks(num_tokens):
    """Yield all ``2**num_tokens`` masks, in increasing integer order.

    The binary representation is reversed so that character ``j`` of the
    mask corresponds to bit ``j`` of the integer, i.e. to token ``j``.
    """
    for subset in range(1 << num_tokens):
        yield np.binary_repr(subset, width=num_tokens)[::-1]


def _sampled_masks(num_tokens, num_samples):
    """Yield the full mask, the empty mask, then *num_samples* random masks.

    For each random mask the number of kept tokens ``k`` (1 .. n-1) is drawn
    with probability proportional to ``(n - 1) / (k * (n - k))``, and then
    ``k`` distinct positions are chosen uniformly.  This appears to follow
    the Kernel SHAP subset-size weighting -- TODO confirm.

    Uses the global NumPy RNG; the caller is responsible for seeding it.
    """
    yield '1' * num_tokens
    yield '0' * num_tokens

    # Subset sizes 1..n-1 only: sizes 0 and n would give a zero denominator
    # and are already covered by the two masks yielded above.
    sizes = np.arange(1, num_tokens, dtype=float)
    weights = (num_tokens - 1) / (sizes * (num_tokens - sizes))
    size_probs = weights / np.sum(weights)  # loop-invariant, computed once

    for _ in range(num_samples):
        # Keep these two draws in this order: the RNG stream (and therefore
        # the generated samples) depends on it.
        ones = np.random.choice(num_tokens - 1, p=size_probs) + 1
        mask = np.zeros(num_tokens, dtype=int)
        mask[np.random.choice(num_tokens, ones, replace=False)] = 1
        yield ''.join(str(bit) for bit in mask)


# Create the output directory up front so that a missing/unwritable directory
# is reported before the expensive tokenization loop rather than after it.
os.makedirs(os.path.dirname(SAMPLES_CSV_PATH), exist_ok=True)

ssamples = []
for i, text in enumerate(tqdm(sst_test['sentence'])):
    # NOTE(review): the global RNG is reseeded for every sentence, so two
    # sentences with the same token count draw the same mask sequence.
    # Presumably intended to make each sentence's samples reproducible
    # independently of the others -- confirm.
    np.random.seed(MASK_SEED)
    tokens = list(nlp(text))
    num_tokens = len(tokens)

    if num_tokens <= MAX_EXHAUSTIVE_TOKENS:
        masks = _exhaustive_masks(num_tokens)
    else:
        masks = _sampled_masks(num_tokens, NUM_RANDOM_MASKS)

    ssamples.extend(
        {
            'sentence_index': i,
            'sentence': text,
            'binary_representation': mask,
            'sample_sentence': _apply_mask(tokens, mask),
        }
        for mask in masks
    )

# Passing `columns` keeps the header stable even if no rows were produced.
samples = pd.DataFrame(ssamples, columns=SAMPLE_COLUMNS)
samples.to_csv(SAMPLES_CSV_PATH, sep='\t')
samples.to_json(SAMPLES_JSON_PATH, orient='records')
