import numpy as np
import pandas as pd
import click
import shap
from preprocessing import load_data, OHEWrapper
import random
import os

@click.command()
@click.option('--dataset', default='adult', help='Name of the dataset to use')
@click.option('--seed', default=9, help='Seed for reproducibility')
@click.option('--n_samples', default=100, help='Number of coalition samples (nsamples) KernelExplainer evaluates per explained row')
@click.option('--n_background', default=100, help='Number of training rows used as the KernelExplainer background')
def run_experiment(dataset, seed, n_samples, n_background=100):
    """Explain held-out rows with KernelSHAP and append the values to a CSV.

    Loads ``dataset``, fits an SVM through ``OHEWrapper`` and, for every row
    of the explanation split (at most 2000, randomly subsampled), writes one
    CSV row ``[id, seed, <one SHAP value per feature>]`` to
    ``./explanations/<dataset>_stability.csv``. The header is written only
    when the file is first created; later runs append to it.

    Args:
        dataset: dataset name forwarded to ``load_data``.
        seed: seed forwarded to ``OHEWrapper`` and ``shap.KernelExplainer``
            and recorded in the ``seed`` column.
        n_samples: ``nsamples`` passed to ``explainer.shap_values``.
        n_background: number of training rows sampled as the explainer
            background (capped at the training-set size).
    """
    # Global RNGs and load_data's second argument (apparently its seed) are
    # pinned to 9 regardless of --seed, so the explained rows and background
    # are the same across runs and only the --seed-dependent parts vary.
    # NOTE(review): confirm this is intentional.
    data_seed = 9
    max_explained = 2000
    np.random.seed(data_seed)
    random.seed(data_seed)
    Xtrain, Xexpl, ytrain, _yexpl, categorical, cat_idx, task = load_data(dataset, data_seed)
    if len(Xexpl) > max_explained:
        # Bound the runtime by explaining a random subset of rows. The
        # explanation labels are not used below, so only features are kept.
        index = np.random.choice(len(Xexpl), max_explained, replace=False)
        Xexpl = Xexpl.iloc[index]

    wrapper = OHEWrapper(categorical, task, seed, model='svm')
    wrapper.fit(Xtrain, ytrain)

    # Cap the background size so small training sets do not make
    # choice(..., replace=False) raise.
    idx_back = np.random.choice(Xtrain.index, min(n_background, len(Xtrain)), replace=False)
    background = Xtrain.loc[idx_back]
    # L1 feature selection keeps half the features, at most 10 and at least
    # 1 (num_features(0) would be meaningless for single-feature data).
    nf = max(1, min(Xexpl.shape[1] // 2, 10))
    explainer = shap.KernelExplainer(model=wrapper.predict, data=background, seed=seed)

    out_dir = './explanations'
    out_path = f'{out_dir}/{dataset}_stability.csv'
    os.makedirs(out_dir, exist_ok=True)
    columns = np.hstack([["id", "seed"], Xtrain.columns])
    write_header = not os.path.exists(out_path)
    for idx in Xexpl.index:
        shaps = explainer.shap_values(X=Xexpl.loc[idx], nsamples=n_samples, l1_reg=f"num_features({nf})")
        if shaps.ndim == 2:
            # Assumes shape (n_features, n_outputs); keep output 1
            # (presumably the positive class) — TODO confirm.
            shaps = shaps[:, 1]
        values = np.hstack([idx, seed, shaps]).reshape(1, -1)
        df = pd.DataFrame(data=values, columns=columns)
        # One row per explanation keeps finished results on disk if the run
        # is interrupted.
        df.to_csv(out_path, mode='w' if write_header else 'a', index=False, header=write_header)
        write_header = False
            
if __name__ == "__main__":
    # Script entry point: invoke the click command defined above, which
    # reads its options from the command line.
    run_experiment()