import numpy as np

from generator.data_stream import *
from optimization.quadratic import *

# Dimension of the chain state / data vectors used throughout the script.
dim = 3

# Throwaway generator and problem instance: in this script P is only read
# for its matrix A, which defines the quadratic form data.dot(A).dot(data.T)
# evaluated in the experiment loop. G is rebound inside that loop.
# NOTE(review): presumably QuadraticProblem creates A itself when none is
# passed -- confirm in optimization.quadratic.
G = MCMC(r=1, batch_size=5, dim=dim)
P = QuadraticProblem(G, learning_rate= 0.01)
A = P.A

# Larger parameter sweep, currently disabled:
# batch_sizes = [1,  10, 100, 1000  ]
# r_list = [0.25, 0.5, 1, 2, 4, 8, 16]

# \tau from 1 to 1000
# (the experiment loop iterates tau over range(1, dataset_size // batch_size),
# so the number of recorded points shrinks as batch_size grows)
dataset_size = 1000

# Use Monte Carlo to estimate the expectation \EE[ F(\theta) | s_0 = (0,0,0)];
# we set \theta = (0,0,0).
# NOTE(review): the experiment loop actually starts every chain at
# np.ones(dim)/10.0, not at (0,0,0) -- confirm which initial state is intended.
num_trajectories = 10000  # number of independent chains averaged per estimate
# learning_rate = 0.0001

# Values of r passed to MCMC and the batch sizes to sweep over.
r_list = [0.25, 0.5, 1, 2]
batch_sizes = [1,  10, 100  ]
# Maps (r, batch_size) -> list of absolute errors, one entry per tau.
record = {}
# For every (r, batch_size) pair, track how far the chain-based Monte Carlo
# estimate of the quadratic form is from a reference estimate as the chains
# advance, and store the absolute gap per step in record[(r, batch_size)].
for r in r_list:
    for batch_size in batch_sizes:
        record[(r, batch_size)] = []

        # One independent chain per trajectory, all started from the same
        # state (0.1, ..., 0.1).
        generator_list = []
        for _ in range(num_trajectories):
            G = MCMC(r=r, batch_size=batch_size, dim=dim)
            G.current_state = np.ones(dim) / 10.0
            generator_list.append(G)

        for tau in range(1, dataset_size // batch_size):
            # Advance every chain by one pop() and average data A data^T.
            # The [0, 0] entry is the quadratic form of the first row only.
            # Assumes G.pop() returns a 2-D (batch_size, dim) array -- TODO confirm.
            mc_est = 0.0
            for G in generator_list:
                data = G.pop()
                mc_est += data.dot(A).dot(data.transpose())[0, 0]
            mc_est /= num_trajectories

            # Reference value: the same quadratic form averaged over fresh
            # i.i.d. Uniform(0, 1)^dim draws.
            # NOTE(review): this reference is re-estimated for every tau, so
            # its own Monte Carlo noise enters each recorded gap; left as is
            # to keep the random stream and results unchanged.
            tmp = 0.0
            for _ in range(num_trajectories):
                data = np.random.uniform(0, 1, dim)
                tmp += data.dot(A).dot(data.transpose())
            f_value = tmp / num_trajectories

            record[(r, batch_size)].append(np.abs(mc_est - f_value))

import pickle

print("Training completed.")

# Persist the per-(r, batch_size) error curves to tau.pickle.
with open("tau.pickle", "wb") as out_file:
    pickle.dump(record, out_file, protocol=pickle.HIGHEST_PROTOCOL)

import matplotlib.pyplot as plt

# Figure: error versus tau for batch_size=1, one curve per value of r.
for r_value, colour in ((0.25, "y"), (0.5, "r"), (1, "b"), (2, "g")):
    plt.plot(record[(r_value, 1)], label=f"r={r_value}", c=colour)
plt.legend()
plt.xlabel("tau")
plt.ylabel("bias")
plt.show()

# Figure: error versus tau for batch_size=10, one curve per value of r.
for r_value, colour in ((0.25, "y"), (0.5, "r"), (1, "b"), (2, "g")):
    plt.plot(record[(r_value, 10)], label=f"r={r_value}", c=colour)
plt.legend()
plt.xlabel("tau")
plt.ylabel("bias")
plt.show()


# Figure: error versus tau for r=2, one curve per batch size.
plt.plot(record[(2, 1)], label="batch_size=1", c="b")
plt.plot(record[(2, 10)], label="batch_size=10", c="r")
plt.plot(record[(2, 100)], label="batch_size=100", c="y")
# The title must name the r whose runs are plotted: the record keys above
# are all r=2 (r=4 is not part of r_list).
plt.title("r=2")
plt.legend()
plt.xlabel("tau")
plt.ylabel("bias")
plt.show()

# Figure: error versus tau for batch_size=100, one curve per value of r.
# Each curve carries a label so that plt.legend() has entries to show;
# without labels the legend is empty and matplotlib emits a warning.
plt.plot(record[(0.25, 100)], label="r=0.25", c="y")
plt.plot(record[(0.5, 100)], label="r=0.5", c="r")
plt.plot(record[(1, 100)], label="r=1", c="b")
plt.plot(record[(2, 100)], label="r=2", c="g")
plt.legend()
plt.xlabel("tau")
plt.ylabel("bias")
plt.show()
# The experiment ends here. Exit explicitly: a bare `raise` with no active
# exception only stops the script by failing with
# "RuntimeError: No active exception to reraise" and a traceback.
raise SystemExit(0)

# Unreachable debugging probe kept from an earlier experiment: prints the
# shape of one pop() from the most recently created generator.
dataset = G.pop()
print(dataset.shape)
raise SystemExit(0)


# NOTE(review): this loop is never reached in the current file -- the
# unconditional `raise` statements above stop the script first. It looks
# like a leftover from an earlier SGD experiment; kept unchanged.
for r in r_list:
    print("Recording parameter r =", r)
    # Rebinds the module-level `record`, so it is reset for every r and keyed
    # by batch_size only.
    record = {}
    for batch_size in batch_sizes:
        record[batch_size] = []

    for _ in range(num_trajectories):
        # A single pop() with batch_size=dataset_size yields one whole
        # trajectory, started from the zero state.
        G = MCMC(r=r, batch_size=dataset_size, dim=dim)
        G.current_state = np.zeros(dim)
        dataset = G.pop() # obtain a trajectory
        print(dataset.shape)
        # Debugging stop: a bare `raise` with no active exception fails with
        # RuntimeError on the first trajectory, so the sweep below never runs.
        raise

        for batch_size in batch_sizes:
            learning_rate = 0.01 # * batch_size/np.sqrt(dataset_size)
            # presumably DataStream serves `dataset` in batches of
            # batch_size -- verify in generator.data_stream
            sgd = DataStream(dataset, batch_size=batch_size)
            # Reuses the module-level A so every run uses the same matrix.
            P = QuadraticProblem(sgd, learning_rate=learning_rate, A = A)
            # Takes dataset_size steps regardless of batch_size.
            for i in range(dataset_size):
                P.step()
            record[batch_size].append(P.loss_hist)
