import numpy as np
import argparse
import matplotlib
from matplotlib import pyplot as plt
import math
import os

# Plot, for each "Bound" row of an experiment grid, the mean accuracy across
# trials together with a shaded confidence band, reading one .npy result file
# per (trial, row, column, dataset).
#
# Usage: python <this script> --dataset {cifar10,fmnist,mnist,svhn_ext}
# Output: two PDFs under ./Plots (best accuracy and final accuracy).

# One entry per row index of the result grid; used for the legend labels and to
# decide how many rows are read.
Bounds = [0.953, 0.881, 0.731]

# Line/band colour for each row, in the same order as Bounds.
colours = ['darkblue', 'blue', 'dodgerblue']

# x-axis positions; result-grid column c is plotted at x_vals[c], so the number
# of columns read is len(x_vals).
x_vals = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# Trial ids whose results are averaged (tuple: safe to use as a default arg).
TRIALS = (1, 2, 3, 4, 5)

# Directory holding the per-trial .npy result files.
RESULTS_DIR = "./Accuracy_Results"


def load_row_accuracies(row, dataset, trials=TRIALS, n_cols=None,
                        results_dir=RESULTS_DIR):
    """Load best/final accuracies for one grid row.

    Args:
        row: Row index of the result grid (selects the file name).
        dataset: Dataset name embedded in the file name.
        trials: Trial ids to read, one output row each.
        n_cols: Number of grid columns to read; defaults to len(x_vals).
        results_dir: Directory containing the .npy files.

    Returns:
        (best, final): float arrays of shape (len(trials), n_cols). Cell
        [i, c] is element 0 (best) / element 1 (final) of the file for
        trials[i], column c, or NaN when that file does not exist.
        A "missing <path>" line is printed for every absent file.
    """
    n_cols = len(x_vals) if n_cols is None else n_cols
    best = np.full((len(trials), n_cols), np.nan)
    final = np.full((len(trials), n_cols), np.nan)
    for t_idx, trial in enumerate(trials):
        for col in range(n_cols):
            path = (f"{results_dir}/trial={trial}_row={row}"
                    f"_col={col}_dataset={dataset}.npy")
            if not os.path.exists(path):
                # Leave NaN so every trial keeps the same length; the original
                # dropped the entry, which produced ragged arrays and crashed.
                print("missing " + path)
                continue
            # Assumes each file stores at least [best_acc, final_acc] in its
            # first two elements — NOTE(review): confirm against the writer.
            data = np.load(path)
            best[t_idx, col] = data[0]
            final[t_idx, col] = data[1]
    return best, final


def mean_ci(values, z=1.96):
    """Return per-column (mean, CI half-width) of a 2-D array, ignoring NaN.

    The half-width is z * population std / sqrt(number of non-NaN values in
    the column). Columns with no data yield NaN for both outputs.

    NOTE(review): this keeps the original formula (ddof=0, z=1.96). With only
    ~5 trials, a t-based interval (ddof=1, t_0.975,4 ~= 2.776) would be wider.
    """
    counts = np.count_nonzero(~np.isnan(values), axis=0)
    mean = np.nanmean(values, axis=0)
    std = np.nanstd(values, axis=0)
    # max(count, 1) avoids dividing by zero; an all-NaN column already has a
    # NaN std, so its half-width stays NaN.
    ci = (z / np.sqrt(np.maximum(counts, 1))) * std
    return mean, ci


def plot_rows(means, cis, ylabel, out_path):
    """Draw one line plus shaded +/- CI band per row and save to out_path."""
    fig, ax = plt.subplots()
    for mean, ci, bound, colour in zip(means, cis, Bounds, colours):
        ax.plot(x_vals, mean, color=colour, label=f"Bound = {bound}")
        ax.fill_between(x_vals, mean - ci, mean + ci, color=colour, alpha=0.2)
    ax.set_xlabel("$P_{x^*}(1)$ Value")
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)


def main(argv=None):
    """Parse CLI args, aggregate every row, and write both accuracy plots."""
    parser = argparse.ArgumentParser(description='Settings')
    parser.add_argument('--dataset',
                        choices=['cifar10', 'fmnist', 'mnist', 'svhn_ext'],
                        required=True)
    args = parser.parse_args(argv)
    dataset = args.dataset

    matplotlib.rcParams.update({'font.size': 18})

    best_acc_path = f"./Plots/mean_ci_extra_trials_dataset={dataset}_best_acc.pdf"
    final_acc_path = f"./Plots/mean_ci_extra_trials_dataset={dataset}_final_acc.pdf"

    best_means, best_cis = [], []
    final_means, final_cis = [], []
    for row in range(len(Bounds)):
        best, final = load_row_accuracies(row, dataset)
        b_mean, b_ci = mean_ci(best)
        f_mean, f_ci = mean_ci(final)
        best_means.append(b_mean)
        best_cis.append(b_ci)
        final_means.append(f_mean)
        final_cis.append(f_ci)

    plot_rows(best_means, best_cis, "Best Accuracy", best_acc_path)
    plot_rows(final_means, final_cis, "Final Accuracy", final_acc_path)


if __name__ == "__main__":
    main()
