import numpy as np
import argparse

from utils.explanations import calculate_correlations

if __name__ == '__main__':
    # Measure the stability of saved SHAP explanations across repeated runs.
    #
    # Loads one attribution matrix per run from
    # `<weights_dir>/shap_<datatype>_<run>.gz` (run = 0 .. run_times-1),
    # passes the list to `calculate_correlations` with
    # explanation_type='attribution', and prints the returned mean and std.

    parser = argparse.ArgumentParser(
        description='Measure the stability of saved SHAP explanations across runs.')
    parser.add_argument('--datatype', type=str,
                        choices=['orange_skin', 'XOR', 'nonlinear_additive', 'switch'], default='switch',
                        help='Dataset whose saved explanations are evaluated.')
    parser.add_argument('--run_times', type=int, default=10,
                        help='Number of saved runs to load (files indexed 0..run_times-1).')
    parser.add_argument('--weights_dir', type=str, default='explained_weights/shap',
                        help='Directory containing the shap_<datatype>_<run>.gz files.')
    args = parser.parse_args()
    if args.run_times < 1:
        # Without this guard an empty list would be handed to calculate_correlations.
        parser.error('--run_times must be at least 1')

    # np.loadtxt transparently decompresses files whose name ends in '.gz'.
    explanations = [
        np.loadtxt(f'{args.weights_dir}/shap_{args.datatype}_{run}.gz', delimiter=',')
        for run in range(args.run_times)
    ]
    stability_mean, stability_std = calculate_correlations(explanations, explanation_type='attribution')
    # `!s` keeps the output identical to str() for numpy scalar types.
    print(f'Stability Mean: {stability_mean!s} Stability Std: {stability_std!s}')