import pickle

from generator.data_stream import *
from optimization.quadratic import *

dim = 3

# Build a throwaway generator/problem pair solely to read off the quadratic
# problem's matrix A. A is read once here and then passed explicitly as
# `A=A` to every QuadraticProblem constructed in the experiment loop, so all
# runs share the same matrix (G and P themselves are rebound in that loop).
G = MCMC(r=1, batch_size=5, dim=dim)
P = QuadraticProblem(G, learning_rate=0.01)
A = P.A

# Sweep configuration.
batch_sizes = [1, 10, 100, 1000]          # batch sizes compared for each r
r_list = [0.25, 0.5, 1, 2, 4, 8, 16]      # values of the MCMC parameter r
dataset_size = 1000000                    # samples per dataset and steps per run
num_trajectories = 30                     # independent repetitions per r

# Experiment sweep.
# For each value of r, run `num_trajectories` independent trajectories. Each
# trajectory generates a fresh dataset from MCMC(r=r, batch_size=dataset_size)
# (presumably `dataset_size` samples — confirm against MCMC.pop). Then, for every
# batch size, it steps a QuadraticProblem (sharing the matrix A) `dataset_size`
# times and keeps its `loss_hist`.
# Per r, the results are pickled to "record_r<r>.pickle" as a dict:
#     batch_size -> [loss_hist of trajectory 0, loss_hist of trajectory 1, ...]

# Fixed step size, identical for every batch size. A batch-size-scaled variant,
# 0.01 * batch_size / np.sqrt(dataset_size), was considered previously.
learning_rate = 0.01

for r in r_list:
    print("Recording parameter r =", r)
    record = {batch_size: [] for batch_size in batch_sizes}

    for _ in range(num_trajectories):
        G = MCMC(r=r, batch_size=dataset_size, dim=dim)
        dataset = G.pop().tolist()

        for batch_size in batch_sizes:
            # NOTE(review): the same `dataset` list is handed to every
            # DataStream; if DataStream mutates it (e.g. shuffles in place),
            # later batch sizes see a reordered dataset — confirm.
            sgd = DataStream(dataset, batch_size=batch_size)
            P = QuadraticProblem(sgd, learning_rate=learning_rate, A=A)
            for _step in range(dataset_size):
                P.step()
            record[batch_size].append(P.loss_hist)

    # f"{r}" formats exactly like str(r), e.g. record_r0.25.pickle, record_r1.pickle.
    with open(f"record_r{r}.pickle", "wb") as handle:
        pickle.dump(record, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print("Data saved.")
print("Mission completed.")