## imports
import numpy as np
import pandas as pd
import random
from Faithfulness import Faithfulness
import scipy
import os
import click
from preprocessing import load_data, OHEWrapper
import warnings
warnings.filterwarnings("ignore")


def _subsample_rows(frame, max_rows):
    """Return at most ``max_rows`` rows of ``frame``, drawn without replacement.

    Frames with ``max_rows`` rows or fewer are returned unchanged. The draw
    uses the global NumPy RNG (``np.random.choice``), so the selection depends
    on whatever seed the caller set beforehand.
    """
    if len(frame) <= max_rows:
        return frame
    positions = np.random.choice(len(frame), max_rows, replace=False)
    return frame.iloc[positions]


def _faithfulness_score(suff, nec):
    """Combine sufficiency and necessity into one faithfulness score.

    Returns the harmonic mean (``scipy.stats.hmean`` over ``axis=0``) of
    ``1 - exp(-nec)`` and ``exp(-suff)``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.stats`` is loaded on older SciPy releases.
    from scipy.stats import hmean
    return hmean([1 - np.exp(-nec), np.exp(-suff)], axis=0)


@click.command()
@click.option('--dataset', default='adult', help='Name of the dataset to use')
@click.option('--explainer', default='KernelSHAP', help='Explanation to use')
@click.option('--seed', default=9, help='Seed for reproducibility')
def run_experiment(dataset, seed, explainer):
    """Score precomputed explanations of *dataset* and append the results to CSV.

    Steps:
    1. Load the dataset and keep at most 2000 explanation rows.
    2. Fit an ``OHEWrapper`` SVM.
    3. Read ``./explanations/{dataset}.csv``.
    4. For every id present in both the explanation set and that file, compute
       sufficiency, necessity and their combined faithfulness score.
    5. Append one row per id to ``./explanations/{dataset}_faithfulness.csv``.

    Args:
        dataset: name passed to ``load_data``. It also selects the input and
            output CSV paths above.
        seed: forwarded to ``OHEWrapper`` and ``Faithfulness`` and recorded in
            each output row. The global RNG seeds and ``load_data`` use a
            fixed 9 regardless of this value.
        explainer: accepted on the command line but not used by this function.
    """
    # NOTE(review): the global RNGs, the data split and the 2000-row subsample
    # are pinned to 9 independently of --seed. Presumably this keeps the ids
    # aligned with ./explanations/{dataset}.csv across runs — confirm.
    np.random.seed(9)
    random.seed(9)
    Xtrain, Xexpl, ytrain, _, categorical, _, task = load_data(dataset, 9)
    Xexpl = _subsample_rows(Xexpl, 2000)

    wrapper = OHEWrapper(categorical, task, seed, model='svm')
    wrapper.fit(Xtrain, ytrain)
    function = wrapper.predict

    numerical = np.setdiff1d(Xexpl.columns.values, categorical)
    df_info = {"categorical": categorical, "numerical": numerical}

    explanations = pd.read_csv(f'./explanations/{dataset}.csv')
    explanations.set_index("id", inplace=True)
    # Drop the first non-id column. Presumably it is metadata rather than an
    # attribution value — TODO confirm against the script writing this file.
    Z = explanations.iloc[:, 1:]
    # Sorted, de-duplicated ids present in both the data and the explanations.
    indexes = np.intersect1d(Xexpl.index, Z.index)

    Xtmp = Xexpl.loc[indexes]
    predictions = function(Xtmp)
    # For classification, keep only column 1 of the model output.
    ytmp = predictions[:, 1] if task == 'classification' else predictions
    # NOTE(review): if Z contains duplicate ids, Z.loc[indexes] has more rows
    # than Xtmp and the positional pairing below misaligns — confirm ids are
    # unique in the explanations file.
    Ztmp = Z.loc[indexes]

    faith = Faithfulness(function, Xtrain, df_info, seed=seed)
    out_path = f'./explanations/{dataset}_faithfulness.csv'
    # Emit the header only if the file did not exist before this run. Append
    # mode creates a missing file, so one flag replaces a per-row existence check.
    write_header = not os.path.exists(out_path)
    for idx, row_id in enumerate(Xtmp.index):
        # Single-row slices keep the 2-D DataFrame shape.
        x = Xtmp.iloc[idx:idx + 1]
        z = Ztmp.iloc[idx:idx + 1]
        y = ytmp[idx]
        suff, dist = faith.compute_sufficiency(x, y, z)
        nec, _ = faith.compute_necessity(x, y, z, dist)
        faithfulness = _faithfulness_score(suff, nec)

        result = pd.DataFrame(
            data=[[row_id, seed, faithfulness, suff, nec]],
            columns=["id", "seed", "faithfulness", "sufficiency", "necessity"],
        )
        # Written row by row, so results computed so far stay on disk if the
        # run is interrupted.
        result.to_csv(out_path, mode='a', index=False, header=write_header)
        write_header = False
    
    

if __name__ == "__main__":
    # Entry point: click parses the command-line options and invokes the command.
    run_experiment()
    