import pickle
import numpy as np

import matplotlib.pyplot as plt
import numpy as np
import matplotlib
# Default text style for everything drawn below: matplotlib.rc('font', ...)
# updates the rcParams "font.weight" and "font.size" settings.
font = dict(weight='bold', size=16)
matplotlib.rc('font', **font)

def _unpickle(path):
    """Open *path* in binary mode and return the single object pickled in it.

    NOTE: pickle.load can execute arbitrary code, so only point this at
    files produced by a trusted run.
    """
    with open(path, 'rb') as stream:
        return pickle.load(stream)


# Recorded loss runs plus the x-axis values used for the curves below.
record = _unpickle('loss_curve.pickle')
x1 = _unpickle('x1.pickle')
x2 = _unpickle('x2.pickle')


def moving_average(a, n=3):
    """Smooth a 1-D sequence with an n-point trailing moving average.

    The result starts with the raw first sample ``a[0]``, followed by the
    mean of every full window ``a[i:i + n]``. The leading raw sample
    presumably keeps the smoothed curve starting at the original value.

    Parameters
    ----------
    a : array_like
        One-dimensional sequence of numbers.
    n : int, optional
        Window length (default 3). ``n == 1`` disables smoothing.

    Returns
    -------
    numpy.ndarray or the input object
        For ``n == 1``, ``a`` itself is returned unchanged (not copied or
        converted). Otherwise a float array of length
        ``max(len(a) - n + 2, 1)``, or an empty float array when ``a`` is
        empty.

    Raises
    ------
    ValueError
        If ``n < 1``.
    """
    # Guard against n <= 0. A negative n would make the slices below count
    # from the wrong end and silently return nonsense.
    if n < 1:
        raise ValueError(f"window length n must be >= 1, got {n}")
    if n == 1:
        # Bypass the general path, which would otherwise duplicate a[0].
        return a
    if len(a) == 0:
        # There is no a[0] to prepend; return an empty float result instead.
        return np.asarray(a, dtype=float)

    csum = np.cumsum(a, dtype=float)
    # Differences of the cumulative sum give each window's total in one pass.
    csum[n:] = csum[n:] - csum[:-n]
    window_means = csum[n - 1:] / n
    return np.insert(window_means, 0, a[0])


line_width = 4
# The shaded band spans the (100 - percentile)-th to percentile-th
# percentile across runs, i.e. 5th-95th here.
percentile = 95


def _plot_mean_with_band(x, runs, color, label):
    """Plot the across-run mean of *runs* and shade its percentile band.

    *runs* is averaged and bounded along axis 0, so rows are presumably
    individual runs and columns successive evaluation points (TODO confirm
    against whatever writes loss_curve.pickle). *x* is truncated to the
    curve length. Returns the plotted mean curve.
    """
    # n=1 means moving_average leaves the mean unsmoothed.
    mean_curve = moving_average(np.mean(runs, axis=0), n=1)
    plt.plot(x[:len(mean_curve)], mean_curve, c=color, linewidth=line_width, label=label)
    upper = np.percentile(runs, percentile, axis=0)
    lower = np.percentile(runs, 100 - percentile, axis=0)
    plt.fill_between(x[:len(lower)], lower, upper, color=color, alpha=0.3)
    return mean_curve


yy = _plot_mean_with_band(x1, record["normal"], "r", "SGD")
yy2 = _plot_mean_with_band(x2, record["sub-sampling"], "b", "SGD with subsampling")
# NOTE(review): mini-batch reuses x2 (the sub-sampling x-axis); confirm intended.
yy3 = _plot_mean_with_band(x2, record["mini-batch"], "g", "mini-batch SGD")

labels = [0, r"$2{\times}10^4$", r"$4{\times}10^4$", r"$6{\times}10^4$", r"$8{\times}10^4$", r"$10{\times}10^4$"]
plt.xticks([0, 20000, 40000, 60000, 80000, 100000], labels)
plt.xlim(-100, 100000)
plt.ylim(-0.005, 0.25)
plt.legend()
plt.xlabel("Number of Observations", weight="bold")
plt.ylabel("Loss", weight="bold")
# Lay out only after all text (including the axis labels) exists. Otherwise
# tight_layout cannot reserve room for the labels and they may be clipped.
plt.tight_layout()
plt.savefig("loss_curve.png", dpi=300)
plt.show()