import matplotlib.pyplot as plt
import pickle as pkl
import numpy as np

# Compare tractrix vs. sphere training runs: plot the across-run mean (+/- 1
# std band) of the number of linear regions and of the distance to the nearest
# boundary, as a function of epoch number.

N_RUNS = 20  # number of pickled runs per geometry (data2/<prefix>_0..N_RUNS-1.pkl)


def _load_runs(prefix, n_runs=N_RUNS):
    """Load ``n_runs`` pickled results for one geometry and stack them.

    Each ``data2/<prefix>_<i>.pkl`` file is unpickled and indexed as
    ``[0]`` (epoch numbers), ``[1]`` (linear-region counts) and ``[2]``
    (distances to the nearest boundary). The per-run series are stacked
    with ``np.array`` so that axis 0 indexes the run.

    Args:
        prefix: file-name prefix, e.g. ``"tractrix"`` or ``"sphere"``.
        n_runs: number of run files to read.

    Returns:
        ``(epochs, regions, distances)``: the epoch numbers stored in the
        last file read (assumed identical across runs of one geometry --
        TODO confirm), and the stacked region-count and distance arrays.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point
    this at files produced by our own training script.
    """
    epochs = []
    regions = []
    distances = []
    for run in range(n_runs):
        with open("data2/%s_%d.pkl" % (prefix, run), "rb") as f_data:
            record = pkl.load(f_data)
        epochs = record[0]
        regions.append(record[1])
        distances.append(record[2])
    return epochs, np.array(regions), np.array(distances)


def _plot_mean_std(x, runs, color, label):
    """Plot the across-run mean of ``runs`` against ``x`` with a +/- 1 std band.

    The line and its band share ``color`` so each series is visually paired.
    Returns ``(mean, std)`` computed over axis 0 (the run axis).
    """
    mean = np.mean(runs, axis=0)
    std = np.std(runs, axis=0)
    plt.plot(x, mean, color=color, label=label)
    plt.fill_between(x, mean - std, mean + std, color=color, alpha=0.2)
    return mean, std


# Each geometry keeps its own epoch axis. Previously the sphere loop overwrote
# the tractrix ``epoch_numbers`` and ``sphere_epoch_numbers`` stayed empty.
epoch_numbers, num_linear_regions, dists = _load_runs("tractrix")
sphere_epoch_numbers, sphere_num_linear_regions, sphere_dists = _load_runs("sphere")

# Sanity check of the stacked tractrix array; axis 0 is the run index.
print(num_linear_regions.shape)

# Figure 1: number of linear regions per epoch.
linear_regions_trac_mean, linear_regions_trac_std = _plot_mean_std(
    epoch_numbers, num_linear_regions, "blue", "Tractrix"
)
sphere_linear_regions_mean, sphere_linear_regions_std = _plot_mean_std(
    sphere_epoch_numbers, sphere_num_linear_regions, "green", "Sphere"
)
plt.xlabel("Epoch Number")
plt.ylabel("Number of Linear Regions")
plt.grid()
plt.legend()
plt.savefig("num_linear_sl.png")

plt.clf()

# Figure 2: distance to the nearest boundary per epoch.
dists_trac_mean, dists_trac_std = _plot_mean_std(
    epoch_numbers, dists, "blue", "Tractrix"
)
sphere_dists_mean, sphere_dists_std = _plot_mean_std(
    sphere_epoch_numbers, sphere_dists, "green", "Sphere"
)
plt.xlabel("Epoch Number")
plt.ylabel("Distance to nearest boundary")
plt.grid()
plt.legend()
plt.savefig("dists_sl.png")
