import pickle

import numpy as np

from generator.data_stream import *
from optimization.quadratic import *

dim = 3

# Construct a small reference problem purely to obtain its quadratic matrix A.
# The same A is then reused by every training run below, so all runs optimize
# the same objective (only A is read from this reference problem).
reference_problem = QuadraticProblem(
    MCMC(r=1, batch_size=5, dim=dim),
    learning_rate=0.01,
)
A = reference_problem.A

# Mini-batch sizes to compare against each other.
batch_sizes = [1, 10, 100, 1000]
# Number of samples drawn per trajectory; also used as the SGD step count below.
dataset_size = 10 ** 6
# Independent repetitions (fresh dataset each time) per batch size.
num_trajectories = 30
# (unused) learning_rate = 0.0001

# One independent list of loss histories per batch size.
record = {size: [] for size in batch_sizes}

# Each trajectory draws one fresh dataset and trains one model per batch size.
# Every batch size is handed the same `dataset` list within a trajectory.
for _ in range(num_trajectories):
    sampler = MCMC(r=10, batch_size=dataset_size, dim=dim)
    # presumably pop() returns a NumPy array of samples — converted to a list here
    dataset = sampler.pop().tolist()

    for batch_size in batch_sizes:
        # Fixed step size for all batch sizes; a batch-size-dependent scaling
        # (0.01 * batch_size / np.sqrt(dataset_size)) exists but is disabled.
        learning_rate = 0.01
        stream = DataStream(dataset, batch_size=batch_size)
        problem = QuadraticProblem(stream, learning_rate=learning_rate, A=A)

        # NOTE(review): the step count is dataset_size for every batch size, so
        # larger batches draw batch_size * dataset_size samples in total —
        # confirm DataStream cycles/reshuffles rather than running out of data.
        for _step in range(dataset_size):
            problem.step()

        record[batch_size].append(problem.loss_hist)

print("Training completed.")

# Persist every loss history, keyed by batch size, using the newest pickle protocol.
with open('batchsize.pickle', 'wb') as out_file:
    pickle.dump(record, out_file, protocol=pickle.HIGHEST_PROTOCOL)

# To reload the results later:
#     with open('batchsize.pickle', 'rb') as handle:
#         record = pickle.load(handle)

print("Data saved.")