import numpy as np
import matplotlib.pyplot as plt
import pickle
import argparse
from scipy.spatial import KDTree
from tqdm import tqdm

from utils.explanations import calculate_gt_astuteness

# Command-line configuration for the ground-truth astuteness script.
parser = argparse.ArgumentParser()
parser.add_argument('--datatype', type=str,
                    choices=['orange_skin', 'XOR', 'nonlinear_additive', 'switch'], default='switch')
# NOTE(review): --run_times is parsed but not read anywhere in the visible script.
parser.add_argument('--run_times', type=int, default=10)
# The two ranges are sequences of numbers. Without type/nargs argparse would hand
# back the raw string whenever the option is given on the command line, and the
# rest of the script would then iterate over its characters. Usage:
#   --radius_range 0 1 2 --epsilon_range 0.0 0.25 0.5
parser.add_argument('--radius_range', type=float, nargs='+', default=np.arange(0, 5, 1))
parser.add_argument('--epsilon_range', type=float, nargs='+', default=np.arange(0, 1, 0.1))
parser.add_argument('--prop_points', type=float, default=1)
# --calculate recomputes the astuteness grid; --no-calculate (the default)
# reuses a previously saved result.
parser.add_argument('--calculate', dest='calculate', action='store_true')
parser.add_argument('--no-calculate', dest='calculate', action='store_false')
parser.set_defaults(calculate=False)

args = parser.parse_args()

# Normalise to arrays so defaults and user-supplied lists behave the same downstream.
args.radius_range = np.asarray(args.radius_range)
args.epsilon_range = np.asarray(args.epsilon_range)

# Per-dataset constant; NOTE(review): not read anywhere in the visible script.
ks = {'orange_skin': 4, 'XOR': 2, 'nonlinear_additive': 4, 'switch': 5}

# Load the pickled dataset for the selected datatype. The context manager
# guarantees the file handle is closed even if unpickling fails.
# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# only point this at data files you generated yourself.
with open('data/' + args.datatype + 'gt.pk', 'rb') as data_file:
    data_dict = pickle.load(data_file)

# The label entries ('y_train', 'y_val') are not used by this script.
x_train = data_dict['x_train']
x_val = data_dict['x_val']
datatype_val = data_dict['datatype_val']
datatype_train = data_dict['datatype_train']
input_shape = data_dict['input_shape']

# Work on a random subset of at most 5000 training points. Capping at
# len(x_train) avoids the ValueError np.random.choice raises when asked for
# more unique indices than are available.
max_points = 5000
range_indices = np.random.choice(len(x_train), min(max_points, len(x_train)), replace=False)
x_train = x_train[range_indices]

# Location of the cached astuteness grid for this datatype.
save_astuteness_file = f'plots/gt_{args.datatype}_astuteness.pk'

# Build one ground-truth explanation (a 10-entry 0/1 vector) per sampled
# training point. NOTE(review): presumably a 1 marks a feature that is
# relevant for that sample -- confirm against the data generator.
if args.datatype == 'switch':
    # For 'switch' the vector depends on each sample's underlying datatype,
    # so the per-sample datatype labels are subsampled with the same indices.
    datatype_train = datatype_train[range_indices]
    gt_explanations_types = {'orange_skin': [1, 1, 1, 1, 0, 0, 0, 0, 0, 1],
                             'nonlinear_additive': [0, 0, 0, 0, 1, 1, 1, 1, 0, 1]}
    gt_explanations = np.array([gt_explanations_types[sample_type]
                                for sample_type in datatype_train])
else:
    # Every sample shares the same vector for the remaining datatypes.
    if args.datatype in ['orange_skin', 'nonlinear_additive']:
        gt_explanations_types = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
    else:
        gt_explanations_types = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
    gt_explanations = np.array([gt_explanations_types] * len(x_train))
# Spatial index over the sampled training points; handed to the astuteness helper.
kdtree = KDTree(x_train)
if args.calculate:
    # One astuteness value per (radius, epsilon) pair.
    total_astuteness = np.zeros(shape=(len(args.radius_range), len(args.epsilon_range)))
    num_points = int(args.prop_points * len(x_train))  # loop-invariant
    for i, ball_r in enumerate(args.radius_range):
        for j, epsilon in enumerate(tqdm(args.epsilon_range)):
            # The helper returns a pair; only its second element is stored.
            _, total_astuteness[i, j] = calculate_gt_astuteness(x_train, gt_explanations,
                                                                num_points=num_points,
                                                                ball_r=ball_r,
                                                                epsilon=epsilon,
                                                                kdtree=kdtree)
    # Cache the grid; the context manager flushes and closes the file deterministically.
    with open(save_astuteness_file, 'wb') as out_file:
        pickle.dump(total_astuteness, out_file)
else:
    # Reuse a previously computed grid (requires an earlier --calculate run).
    # NOTE(security): pickle.load must only be used on trusted files.
    with open(save_astuteness_file, 'rb') as in_file:
        total_astuteness = pickle.load(in_file)

# Draw one astuteness-vs-epsilon curve per radius and save the figure.
image_name = f'plots/gt_{args.datatype}_astuteness.PNG'
for row, radius in enumerate(args.radius_range):
    curve_label = f'radius: {radius!s}'
    # yerr=0: no error bars are drawn.
    plt.errorbar(x=args.epsilon_range, y=total_astuteness[row, :], yerr=0,
                 label=curve_label)
plt.legend()
plt.savefig(image_name)
# NOTE(review): r is not used anywhere in the visible script.
r = 3