import numpy as np

from generator.data_stream import *
from optimization.quadratic import *

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Dimension of the state vector handled by every MCMC generator.
dim = 3

# Total sample budget per chain; tau runs over
# range(1, dataset_size // step), where step is the skip or batch size.
dataset_size = 1000

# Use Monte Carlo to estimate the expectation \EE[ F(\theta) | s_0 = (0,0,0) ];
# we set \theta = (0,0,0).
# NOTE(review): the chains created further down are started from
# np.ones(dim) / 10.0 rather than (0,0,0) -- confirm which start state is meant.
num_trajectories = 10000

# Values swept by the experiment.
r_list = [0.25, 0.5, 1, 2]
batch_sizes = [1, 10, 100]

# Alternative settings kept for reference (currently disabled):
# batch_sizes = [1,  10, 100, 1000  ]
# r_list = [0.25, 0.5, 1, 2, 4, 8, 16]
# learning_rate = 0.0001

# Result containers, both keyed by a (r, size) tuple and holding one list of
# absolute differences per key:
#   record  -- filled by the loop that pops `num_skip` times per step
#   recordb -- filled by the loop that pops once per step with a batch size
record, recordb = {}, {}

# A throw-away generator/problem pair, built only to read the problem's
# matrix ``A`` that is used in the quadratic form data . A . data^T.
G = MCMC(r=1, batch_size=5, dim=dim)
P = QuadraticProblem(G, learning_rate=0.01)
A = P.A
def _spawn_chains(r, batch_size, count, dim):
    """Return ``count`` fresh MCMC generators, each started at ones(dim) / 10."""
    chains = []
    for _ in range(count):
        chain = MCMC(r=r, batch_size=batch_size, dim=dim)
        chain.current_state = np.ones(dim) / 10.0
        chains.append(chain)
    return chains


def _quadratic_form(x, A):
    """Return x . A . x^T, evaluated as ``x.dot(A).dot(x.transpose())``."""
    return x.dot(A).dot(x.transpose())


def _uniform_reference(A, dim, count):
    """Average the quadratic form over ``count`` draws from Uniform(0, 1)^dim."""
    total = 0.0
    for _ in range(count):
        total += _quadratic_form(np.random.uniform(0, 1, dim), A)
    return total / count


# For every r, run two sweeps and store |chain estimate - uniform reference|
# for each step tau:
#   * record[(r, num_skip)]    -- chains with batch_size=1, popped num_skip
#                                 times per step (only the last pop is used);
#   * recordb[(r, batch_size)] -- chains created with the given batch_size,
#                                 popped once per step.
for r in r_list:
    for num_skip in batch_sizes:
        record[(r, num_skip)] = []
        chains = _spawn_chains(r, 1, num_trajectories, dim)

        for tau in range(1, dataset_size // num_skip):
            mc_est = 0.0
            for chain in chains:
                # Advance the chain num_skip steps and keep only the last draw.
                for _ in range(num_skip):
                    data = chain.pop()
                mc_est += _quadratic_form(data, A)[0, 0]
            mc_est /= num_trajectories

            # The reference is re-estimated at every step, after the chains
            # have been advanced, so the random draws happen in that order.
            f_value = _uniform_reference(A, dim, num_trajectories)
            record[(r, num_skip)].append(np.abs(mc_est - f_value))

    for batch_size in batch_sizes:
        recordb[(r, batch_size)] = []
        chains = _spawn_chains(r, batch_size, num_trajectories, dim)

        for tau in range(1, dataset_size // batch_size):
            mc_est = 0.0
            for chain in chains:
                data = chain.pop()
                # A bare ``np.mean(data, axis=0)`` used to sit here. np.mean
                # returns a new array and leaves its argument untouched, so
                # the discarded call had no effect and has been dropped.
                # NOTE(review): if pop() returns an array of shape
                # (batch_size, dim) -- TODO confirm -- then [0, 0] below uses
                # only the first sample of the batch; averaging over the
                # batch may have been the intent.
                mc_est += _quadratic_form(data, A)[0, 0]
            mc_est /= num_trajectories

            f_value = _uniform_reference(A, dim, num_trajectories)
            recordb[(r, batch_size)].append(np.abs(mc_est - f_value))

import pickle

print("Training completed.")

# Persist both result dictionaries, one pickle file each.
for path, payload in (("tau_skip.pickle", record), ("tau_bs.pickle", recordb)):
    with open(path, "wb") as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)

import matplotlib.pyplot as plt

# (r, legend label, colour) for each curve in the per-r figures.
_R_STYLES = (
    (0.25, "r=0.25", "y"),
    (0.5, "r=0.5", "r"),
    (1, "r=1", "b"),
    (2, "r=2", "g"),
)


def _show_bias_figure(curves, title=None):
    """Plot bias-vs-tau curves in one figure and display it.

    ``curves`` is an iterable of ``(series, label, colour)`` tuples;
    ``title`` is applied only when given.
    """
    for series, label, colour in curves:
        plt.plot(series, label=label, c=colour)
    if title is not None:
        plt.title(title)
    plt.legend()
    plt.xlabel("tau")
    plt.ylabel("bias")
    plt.show()


def _curves_for_skip(num_skip):
    """Return one curve per r from ``record`` for the given skip value."""
    return [
        (record[(r, num_skip)], label, colour)
        for r, label, colour in _R_STYLES
    ]


# One figure per skip value, comparing the different r.
_show_bias_figure(_curves_for_skip(1))
_show_bias_figure(_curves_for_skip(10))

# Fixed r = 2, comparing the three sizes. The title previously read "r=4",
# which is neither the plotted series nor a value present in r_list.
# NOTE(review): the labels say "batch_size" but the series come from `record`
# (the skip sweep); `recordb` is never plotted -- confirm which was intended.
_show_bias_figure(
    [
        (record[(2, 1)], "batch_size=1", "b"),
        (record[(2, 10)], "batch_size=10", "r"),
        (record[(2, 100)], "batch_size=100", "y"),
    ],
    title="r=2",
)

# Skip value 100. The curves now carry labels so the legend is not empty.
_show_bias_figure(_curves_for_skip(100))
