import pickle

import numpy as np

from generator.data_stream import *
from optimization.quadratic import *

# ---- Experiment configuration and shared state ---------------------------
dim = 3

# Build a small throwaway problem only to read its matrix `A`; the trajectory
# loop passes this same `A` to every QuadraticProblem it constructs.
G = MCMC(r=1, batch_size=5, dim=dim)
P = QuadraticProblem(G, learning_rate=0.01)
A = P.A

batch_size = 100          # batch size handed to the "mini-batch" DataStream
dataset_size = 1000000    # number of samples requested from MCMC per dataset
num_trajectories = 500    # number of runs; their loss curves are averaged

# `r` argument of the MCMC generator for the experiment-scale datasets.
r = 2
# No dataset is generated here: the trajectory loop draws a fresh one at the
# start of every trajectory, so a dataset built at this point would be
# overwritten before anything could read it.
opt_list = ["sub-sampling", "normal", "mini-batch"]

# One list per optimizer; each trajectory appends its loss history to it.
record = {optimizer: [] for optimizer in opt_list}

x1 = []     # x-axis values for the "normal" curve, filled on trajectory 0
x2 = []     # x-axis values for the sub-sampled curves, filled on trajectory 0
skip = 100  # the "sub-sampling" optimizer steps once every `skip` indices
# Run every optimizer once per trajectory on a freshly generated dataset and
# collect each run's loss history in `record[optimizer]`.  The x-axis lists
# `x1` / `x2` are filled only during the first trajectory.
#
# NOTE(review): `record` keeps every full loss history in memory
# (num_trajectories histories per optimizer) -- confirm this fits in RAM.
for trajectory in range(num_trajectories):
    G = MCMC(r=r, batch_size=dataset_size, dim=dim)
    dataset = G.pop().tolist()
    print(trajectory)  # progress indicator
    learning_rate = 0.001
    for optimizer in opt_list:
        if optimizer == "sub-sampling":
            sgd = DataStream(dataset, batch_size=1)
            P = QuadraticProblem(sgd, learning_rate=learning_rate, A=A)
            # One step for every index that is a multiple of `skip`.
            # NOTE(review): the stream itself is not told to skip samples;
            # whether this uses every `skip`-th sample or simply the first
            # ones depends on DataStream -- confirm.
            step_indices = range(0, dataset_size, skip)
            for _ in step_indices:
                P.step()
            if trajectory == 0:
                x2.extend(step_indices)
                x2.append(dataset_size)
            record[optimizer].append(P.loss_hist)

        elif optimizer == "normal":
            sgd = DataStream(dataset, batch_size=1)
            P = QuadraticProblem(sgd, learning_rate=learning_rate, A=A)
            for _ in range(dataset_size):
                P.step()
            if trajectory == 0:
                # 0 .. dataset_size inclusive: one x value per step plus one.
                x1.extend(range(dataset_size + 1))
            record[optimizer].append(P.loss_hist)

        elif optimizer == "mini-batch":
            sgd = DataStream(dataset, batch_size=batch_size)
            P = QuadraticProblem(
                sgd, learning_rate=learning_rate * batch_size ** 0.25, A=A
            )
            # The stream is created with `batch_size` samples per batch, so a
            # single pass over the dataset is dataset_size // batch_size
            # steps.  Stepping dataset_size times would request far more
            # batches than the dataset holds and would yield a loss history
            # that cannot be plotted against `x2` (which has one entry per
            # `skip` samples; this lines up while skip == batch_size,
            # assuming loss_hist grows by one entry per step).
            for _ in range(dataset_size // batch_size):
                P.step()
            record[optimizer].append(P.loss_hist)


# Persist the raw results: the per-optimizer loss histories first, then the
# two x-axis lists, each in its own pickle file.
pickle_targets = (
    ('loss_curve.pickle', record),
    ('x1.pickle', x1),
    ('x2.pickle', x2),
)
for filename, payload in pickle_targets:
    with open(filename, 'wb') as sink:
        pickle.dump(payload, sink, protocol=pickle.HIGHEST_PROTOCOL)


import matplotlib.pyplot as plt

# Plot the loss curve of each optimizer, averaged over all trajectories
# (axis 0 of the stacked histories).  Each entry is
# (x values, key in `record` which is also the legend label, line colour).
curve_specs = (
    (x1, "normal", "r"),
    (x2, "sub-sampling", "b"),
    (x2, "mini-batch", "g"),
)
for x_values, name, colour in curve_specs:
    mean_loss = np.mean(record[name], axis=0)
    plt.plot(x_values, mean_loss, c=colour, label=name)
# Optional zoom, left disabled:
# plt.xlim(-100, 20200)
# plt.ylim(-0.01,0.05)
plt.legend()
plt.show()