import matplotlib.pyplot as plt
import sys
import numpy as np
import math

# Result files to plot, one curve per file; labels[i] is the legend entry for
# files[i]. The first entry is the un-attacked Exp3 baseline.
files = ['results_EXP3_no_attack_hard.txt', 'results_eps=0.1.txt', 'results_eps=0.25.txt', 'results_eps=0.4.txt']
labels = ['Exp3 (No Attack)', r'Exp3 $(\epsilon=0.1)$', r'Exp3 $(\epsilon=0.25)$', r'Exp3 $(\epsilon=0.4)$']

# Figure 2 (attack cost) does not draw the first file, so its colour cycle
# starts one step later than the default. That way each epsilon curve gets the
# same colour it has in figure 1.
# Axes.set_color_cycle was deprecated in matplotlib 1.5 and later removed;
# set_prop_cycle is the supported replacement.
plt.figure(2)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
plt.gca().set_prop_cycle(color=colors[1:])

def _read_row(f):
    """Read the next comma-separated line of *f* and return it as a float array."""
    return np.array([float(x) for x in f.readline().strip().split(',')])


# Axis labels are identical on every iteration, so set them once up front.
plt.figure(1)
plt.xlabel(r'$\log T$', fontsize=15)
plt.ylabel(r'$\log (T-N_T(a^\dagger))$', fontsize=15)
plt.figure(2)
plt.xlabel(r'$\log T$', fontsize=15)
plt.ylabel(r'$\log C_T$', fontsize=15)

for i, (fname, label) in enumerate(zip(files, labels)):
    path = 'results/' + fname

    # File layout, as parsed here:
    #   line 1: comma-separated horizons T
    #   line 2: number of trials
    #   then two rows per trial: N_T(a^dagger) for each T, followed by the
    #   attack cost for each T.
    # The `with` block guarantees the file is closed, even on a parse error.
    with open(path, 'r') as f:
        Ts = _read_row(f)
        Ntrial = int(f.readline())
        if Ntrial < 1:
            # Without this check, the mean would be NaN and the std would be
            # divided by sqrt(0).
            raise ValueError('%s: expected at least one trial, got %d' % (path, Ntrial))
        N_dagger_rows = []
        cost_rows = []
        for _ in range(Ntrial):
            N_dagger_rows.append(_read_row(f))
            cost_rows.append(_read_row(f))

    # One row per trial. Columns are expected to line up with Ts; ragged rows
    # make vstack raise.
    N_dagger = np.vstack(N_dagger_rows)
    attack_costs = np.vstack(cost_rows)

    # Per-horizon mean across trials. The *_std values are the standard error
    # of the mean: np.std (ddof=0) divided by sqrt(Ntrial).
    N_dagger_mean = N_dagger.mean(axis=0)
    N_dagger_std = N_dagger.std(axis=0) / np.sqrt(Ntrial)
    attack_costs_mean = attack_costs.mean(axis=0)
    attack_costs_std = attack_costs.std(axis=0) / np.sqrt(Ntrial)

    print(N_dagger_mean, N_dagger_std)
    print(attack_costs_mean, attack_costs_std)

    plt.figure(1)
    plt.plot(np.log(Ts), np.log(Ts - N_dagger_mean), marker='o', label=label)

    # The first (no-attack) file is drawn only on figure 1, not on the
    # attack-cost figure 2.
    if i > 0:
        plt.figure(2)
        plt.plot(np.log(Ts), np.log(attack_costs_mean), marker='o', label=label)

def _finish_loglog_figure(num):
    """Add a y = x reference line, grid and legend to figure *num*, then lay it out.

    Uses the module-level ``Ts`` left over from the last results file read
    above. This assumes every results file uses the same horizons T
    (NOTE(review): confirm).
    """
    plt.figure(num)
    log_T = np.log(Ts)
    plt.plot(log_T, log_T, '--', label=r'$y=x$', color='k')
    plt.grid()
    plt.legend(fontsize=15)
    # Enlarge the tick labels *before* tight_layout, so the computed layout
    # accounts for their final size and they are not clipped.
    plt.tick_params(axis='both', labelsize=15)
    plt.tight_layout()


for fig_num in (1, 2):
    _finish_loglog_figure(fig_num)

plt.show()