# Run command: python3 main_demo_2d.py

import numpy as np
import sklearn
from sklearn.gaussian_process.kernels import RBF
import matplotlib.pyplot as plt
import time

from function import f_demo_2d
from algorithms import GP_UCB

## A two-dimensional example ##


# --- Problem setup -----------------------------------------------------------

# Discretised input grid: the same 100 evenly spaced points on [0, 1.5*pi]
# serve for both axes (yD is the very same array object as xD).
xD = np.linspace(0, 1.5 * np.pi, 100)
yD = xD

# Prior mean and standard deviation handed to GP_UCB.
mu_prior, sigma_prior = 0, 1

# Standard deviation of the observation noise, epsilon ~ N(0, sigma_noise**2).
sigma_noise = 0.5

# Covariance: a constant factor of 1.0 times an RBF kernel with initial
# length scale 0.3 and length-scale bounds [1e-2, 1e2].
kernel = 1.0 * RBF(length_scale=0.3, length_scale_bounds=(1e-2, 1e2))

# Number of iterations between progress messages; the run lasts 12 of them.
nstep = 20
n_iter = nstep * 12
beta = 50  # presumably the UCB exploration weight; passed straight to GP_UCB

gp_ucb = GP_UCB(
    f_demo_2d,
    kernel=kernel,
    mu=mu_prior,
    sigma=sigma_prior,
    sigma_noise=sigma_noise,
    beta=beta,
    coords=(xD, yD),
)

# Ground truth by exhaustive search: evaluate the objective at every point of
# the input space and keep the (point, value) pair with the largest value
# (max() keeps the first one on ties).
outputs = [(x, gp_ucb.f(x)) for x in gp_ucb.input_space]
xstar, f_xstar = max(outputs, key=lambda pair: pair[1])

# --- Run GP-UCB --------------------------------------------------------------

# Wall-clock start of the optimisation loop (despite the name, time.time()
# measures elapsed wall time, not CPU time).
cpu_time = time.time()
print('\n niter: ', n_iter)
for i in range(n_iter):
    if i % nstep == 0:
        print('\n ########## Iteration: ', i, ' ##########')

    # epsilon ~ N(0, sigma_noise^2). Note: np.random.rand() would draw from
    # the uniform distribution on [0, 1), giving biased, non-Gaussian noise.
    epsilon = np.random.normal(0.0, sigma_noise)

    # x_n = argmax mu + beta*sigma
    gp_ucb.argmax_ucb()

    # y_n = f(x_n) + epsilon
    gp_ucb.sample_y(epsilon)

    # Bayesian update on mu and sigma
    gp_ucb.bayesian_update()

    # Save data for plot
    gp_ucb.save_plot()

print('\n UCB-algorithm terminated')

# Take the timing before plt.show(): show() blocks until the figure windows
# are closed, so measuring after it would include time spent viewing plots.
elapsed = time.time() - cpu_time
print('\n Elapsed time: ', elapsed)
print('\n xstar: ', xstar, '\t f_xstar: ', f_xstar)
print('\n x_n: ', gp_ucb.x_n, '\t f(x_n): ', gp_ucb.f(gp_ucb.x_n))

plt.show()


# --- 3D plot -----------------------------------------------------------------

# Importing mplot3d registers the '3d' projection on older Matplotlib
# versions; Axes3D itself is no longer used directly.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

model = gp_ucb
last = n_iter - 1  # index of the final iteration's saved data
grid_x, grid_y = model.meshgrid[0], model.meshgrid[1]
grid_shape = grid_x.shape  # same as model.meshgrid.shape[1:]

fig = plt.figure()
# Axes3D(fig) no longer adds itself to the figure (deprecated in Matplotlib
# 3.4, changed in 3.6), which leaves an empty window; create the 3D axes
# through the figure instead.
ax = fig.add_subplot(projection='3d')

# Objective f evaluated over the whole input grid (blue wireframe).
f_output = np.array([model.f(e) for e in model.input_space]).reshape(grid_shape)
ax.plot_wireframe(grid_x, grid_y, f_output, alpha=0.5, color='b')

# mu saved at the last iteration (green wireframe) -- presumably the GP
# posterior mean.
ax.plot_wireframe(grid_x, grid_y, model.mu_plot[last].reshape(grid_shape),
                  alpha=0.5, color='g')

# Sampled points X against their recorded values Y (red dots).
sampled = model.X[:last + 1]
ax.scatter([x[0] for x in sampled], [x[1] for x in sampled], model.Y[:last + 1],
           c='r', marker='o', alpha=1.0)

plt.show()
