import numpy as np
import argparse
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import shap

from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.callbacks import ModelCheckpoint

from utils.explanations import calculate_stability, calculate_robust_astute_sampled

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--datatype', type=str,
                        choices=['rice'], default='rice')
    parser.add_argument('--run_times', type=int, default=10)
    parser.add_argument('--radius_range', default=np.arange(2, 20, 2))
    parser.add_argument('--epsilon_range', default=np.arange(0.01, 0.11, 0.01))
    parser.add_argument('--num_samples', type=int, default=1000)
    parser.add_argument('--prop_points', type=float, default=1)
    parser.add_argument('--calculate', dest='calculate', action='store_true')
    parser.add_argument('--no-calculate', dest='calculate', action='store_false')
    parser.set_defaults(calculate=True)

    args = parser.parse_args()

    blackbox_path = 'models/' + args.datatype + '_blackbox.hdf5'
    rice_pd = pd.read_excel('data/Rice_Osmancik_Cammeo_Dataset.xlsx')
    data = rice_pd.values[:, :-1]
    labels = rice_pd.values[:, -1]
    labels[labels == 'Cammeo'] = 0
    labels[labels == 'Osmancik'] = 1
    x_train, x_val, y_train, y_val = train_test_split(data, labels, test_size=0.33, random_state=42)
    x_train = StandardScaler().fit_transform(x_train)
    x_val = StandardScaler().fit_transform(x_val)
    input_shape = x_train.shape[-1]

    save_astuteness_file = 'plots/shap_' + args.datatype + '_astuteness.pk'
    save_dist_exp_file = 'plots/shap_' + args.datatype + '_expdistance.pk'

    if args.calculate:
        total_astuteness = np.zeros(shape=(args.run_times, len(args.radius_range), len(args.epsilon_range)))
        total_exp_distance = np.zeros(shape=(args.run_times, len(args.radius_range),
                                             len(args.epsilon_range), args.num_samples))
        activation = 'relu'

        model_input = Input(shape=(input_shape,), dtype='float32')

        net = Dense(32, activation=activation, name='dense1',
                    kernel_regularizer=regularizers.l2(1e-3))(model_input)

        bbox_preds = Dense(1, activation='sigmoid', name='dense3',
                           kernel_regularizer=regularizers.l2(1e-3))(net)
        bbox_model = Model(model_input, bbox_preds)
        bbox_model.load_weights(blackbox_path,
                                by_name=True)

        background = x_train[np.random.choice(len(x_train), 100, replace=False)]
        explainer = shap.GradientExplainer(bbox_model, background)
        for i in range(args.run_times):
            fname = 'explained_weights/shap/' + 'shap_' + args.datatype + '_' + str(i) + '.gz'
            explanations = np.loadtxt(fname, delimiter=',')
            print('Completing Run ' + str(i + 1) + ' of ' + str(args.run_times))
            for j in range(len(args.radius_range)):
                for k in tqdm(range(len(args.epsilon_range))):
                    _, total_astuteness[i, j, k], _, total_exp_distance[i, j, k, :] = \
                        calculate_robust_astute_sampled(data=x_val,
                                                       explainer=explainer,
                                                       explainer_type='shap',
                                                       explanation_type='attribution',
                                                       data_explanation=explanations,
                                                       ball_r=args.radius_range[j],
                                                       epsilon=args.epsilon_range[k],
                                                       num_points=int(
                                                           args.prop_points * len(x_val)),
                                                       num_samples=args.num_samples)
        pickle.dump(total_astuteness, open(save_astuteness_file, 'wb'))
        pickle.dump(total_exp_distance, open(save_dist_exp_file, 'wb'))
    else:
        total_astuteness = pickle.load(open(save_astuteness_file, 'rb'))
    astuteness_mean = total_astuteness.mean(axis=0)
    astuteness_std = total_astuteness.std(axis=0)
    image_name = 'plots/shap_' + args.datatype + '_astuteness.PNG'
    fig, ax = plt.subplots()
    for i in range(len(args.radius_range)):
        ax.errorbar(x=args.epsilon_range, y=astuteness_mean[i, :], yerr=astuteness_std[i, :],
                    label='radius: ' + str(args.radius_range[i]))
    plt.legend()
    plt.savefig(image_name)
    r = 3
