import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
    
import sklearn
from sklearn.gaussian_process import GaussianProcessRegressor


# Base class #
class UCB(object):
    """Common machinery for the grid-based GP bandit optimisers in this module.

    Stores the objective ``f``, a discretised input space built with
    ``np.meshgrid`` from ``coords``, a scikit-learn ``GaussianProcessRegressor``
    and the current posterior ``mu`` / ``sigma`` at every grid point.
    Subclasses supply the acquisition, sampling and Bayesian-update steps;
    this class provides bookkeeping, a text summary and plotting helpers.
    """

    def __init__(self, *args, **kwargs):
        r"""
        args: (f,)
        kwargs: (kernel, mu, sigma, sigma_noise, beta, coords)

        f           -- objective, called on points of the input space.
        kernel      -- kernel handed to GaussianProcessRegressor.
        mu, sigma   -- prior values, replicated once per grid point.
        sigma_noise -- observation-noise std; the GP gets alpha = sigma_noise ** 2.
        beta        -- exploration parameter, stored as-is for subclasses.
        coords      -- 1-D coordinate arrays forwarded to np.meshgrid (required).

        Raises:
            TypeError: if ``coords`` is not supplied.
        """
        self.f = args[0]
        self.beta = kwargs.get('beta')

        # Some setup
        self._init_gp(kwargs.get('kernel'), kwargs.get('sigma_noise'))
        coords = kwargs.get('coords')
        if coords is None:
            # Fail with a readable message rather than "argument after * must be an iterable".
            raise TypeError("UCB requires the 'coords' keyword argument (1-D coordinate arrays)")
        self._set_input_space(*coords)
        self._set_mu_sigma(kwargs.get('mu'), kwargs.get('sigma'))

        # Sampled inputs / observations, and posterior snapshots saved for plotting.
        self.X = ()
        self.Y = ()
        self.mu_plot = ()
        self.sigma_plot = ()

    def _set_input_space(self, *coords):
        """Build the grid: ``meshgrid`` is (n_dims, *grid_shape), ``input_space`` is (n_points, n_dims)."""
        self.meshgrid = np.array(np.meshgrid(*coords))
        self.input_space = self.meshgrid.reshape(self.meshgrid.shape[0], -1).T
        # Number of grid points (rows of input_space).
        self._nd = self.input_space.shape[0]

    def _set_mu_sigma(self, mu, sigma):
        """Initialise the prior mean / sigma by repeating the given values once per grid point."""
        self.mu = np.array([mu for _ in range(self._nd)])
        self.sigma = np.array([sigma for _ in range(self._nd)])

    def _init_gp(self, kernel, sigma_noise):
        """Create the GP regressor; ``alpha`` is the noise variance added to the kernel diagonal."""
        self.gp = GaussianProcessRegressor(kernel, alpha=sigma_noise ** 2)

    @property
    def x_n(self):
        """Most recently recorded input (last entry of ``self.X``)."""
        return self.X[-1]

    @property
    def y_n(self):
        """Most recently recorded observation (last entry of ``self.Y``)."""
        return self.Y[-1]

    def summary(self):
        """Print beta, the norms of mu and sigma, and the latest (x_n, y_n) pair."""
        print('beta: ', self.beta)
        print('mu: {:.6f}'.format(np.linalg.norm(self.mu)), 'sigma: {:.6f}'.format(np.linalg.norm(self.sigma)), )
        print('x_n: ', self.x_n, '\t y_n: ', self.y_n)

    def save_plot(self):
        """Snapshot the current posterior mean and sigma for later plotting."""
        self.mu_plot += (self.mu,)
        self.sigma_plot += (self.sigma,)

    def plot(self, rows=1, step=1):
        """Draw every ``step``-th saved posterior snapshot on a ``rows`` x ``cols`` grid.

        ``cols`` is chosen so that ``rows * cols`` panels fit in the saved
        snapshots. The band uses ``input_space[:, 0]`` as its x-axis, so the
        drawing is meant for a 1-D input space.

        Raises:
            ValueError: if there are too few snapshots to fill one column.
        """

        def _plot(p, i):
            # Exact function f, GP mean, +/- one-std band, and all points
            # selected up to (and including) the i-th iteration.
            mu = self.mu_plot[i]
            sigma = self.sigma_plot[i]

            p.plot(self.input_space, self.f(self.meshgrid), 'r', lw=4, zorder=9)

            # sigma is already a standard deviation (predict(..., return_std=True)),
            # so it is used directly as the band half-width.
            p.fill_between(self.input_space[:, 0], mu - sigma, mu + sigma,
                           alpha=0.5, color='b')
            p.plot(self.input_space, mu, 'k', lw=3, zorder=9)
            p.scatter(self.X[:i+1], self.Y[:i+1], c='r', s=50, zorder=10, edgecolors=(0, 0, 0))

        n_plot = len(self.mu_plot) // step
        cols = n_plot // rows
        if cols < 1:
            raise ValueError(
                'Not enough saved snapshots ({}) to fill {} row(s) with step {}'.format(
                    len(self.mu_plot), rows, step))

        # squeeze=False always yields a 2-D array of Axes, so indexing is uniform
        # even for a single row or a single panel.
        f, arr = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
        for k in range(rows * cols):
            r, c = divmod(k, cols)
            i = k * step
            p = arr[r, c]
            _plot(p, i)
            p.set_title('Iteration: ' + str(i))

    def _simple_regret(self, f_xstar):
        """Return ``f_xstar - f(x)`` for every recorded input ``x`` in ``self.X``.

        Plain broadcasting is used, so ``f_xstar`` may be a Python/NumPy scalar
        or an array matching the shape of ``f``'s output.
        """
        f_values = np.array([self.f(x) for x in self.X])
        return np.asarray(f_xstar) - f_values

    def plot_simple_regret(self, f_xstar):
        """Plot the simple regret f(xstar) - f(x_n) on log-log and linear axes."""
        simple_regret = self._simple_regret(f_xstar)
        plt.figure()
        plt.loglog(simple_regret, 'm')
        plt.title('Loglog simple regret: f(xstar) - f(x_n)')
        plt.figure()
        plt.plot(simple_regret, 'r')
        plt.title('simple regret: f(xstar) - f(x_n)')


class GP_UCB(UCB):
    r"""GP-UCB: sample the grid point maximising ``mu + beta * sigma``.

    Typical loop::

        for n in range(n_iter):
            # \epsilon ~ N(0, sigma_noise^2)
            epsilon = np.random.randn() * sigma_noise

            # x_n = argmax mu + beta*sigma
            gp_ucb.argmax_ucb()

            # y_n = f(x_n) + epsilon
            gp_ucb.sample_y(epsilon)

            # Bayesian update on mu and sigma
            gp_ucb.bayesian_update()

    ``beta`` may be a number or a callable ``beta(n)`` returning the
    exploration weight for the (1-based) iteration ``n``.
    """
    def __init__(self, *args, **kwargs):
        UCB.__init__(self, *args, **kwargs)

    def argmax_ucb(self, beta=None):
        """Append to ``self.X`` the grid point with the largest ``mu + beta * sigma``.

        beta: exploration weight; defaults to ``self.beta``. If the resulting
            value is callable it is evaluated at ``len(self.X) + 1``, the
            1-based index of the point being selected.
        """
        if beta is None:
            beta = self.beta
        if callable(beta):
            # Without this, a schedule-style beta would be multiplied with sigma and raise TypeError.
            beta = beta(len(self.X) + 1)
        x_n = np.argmax(self.mu + self.sigma * beta)
        self.X += (self.input_space[x_n],)

    def sample_y(self, epsilon):
        """Record the noisy observation ``f(x_n) + epsilon`` in ``self.Y``."""
        y_n = self.f(self.x_n) + epsilon
        self.Y += (y_n,)

    def bayesian_update(self):
        """Refit the GP on all (X, Y) pairs and refresh the posterior mean/std over the grid."""
        self.gp = self.gp.fit(self.X, self.Y)
        self.mu, self.sigma = self.gp.predict(self.input_space, return_std=True)


class MVR(UCB):
    r"""Maximum Variance Reduction.

    Every iteration samples where the posterior std is largest, then records
    the posterior-mean maximiser as the current recommendation::

        for n in range(n_iter):
            # \epsilon ~ N(0, sigma_noise^2)
            epsilon = np.random.randn() * sigma_noise

            mvr.argmax_sigma()      # x_temp_n = argmax sigma
            mvr.sample_y(epsilon)   # y_n = f(x_temp_n) + epsilon_n
            mvr.bayesian_update()   # posterior update of mu and sigma
            mvr.argmax_mean()       # x_hat_n = argmax mu

    Strictly, ``argmax_mean`` (and the final mean update) is only needed once
    after the loop; calling it each iteration keeps the history of
    recommendations in ``self.X``. Exploration points live in ``self.X_temp``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.X_temp = ()

    @property
    def x_temp(self):
        """Latest exploration point (last entry of ``X_temp``)."""
        return self.X_temp[-1]

    def argmax_sigma(self):
        """Queue the grid point with the largest posterior sigma as the next exploration point."""
        most_uncertain = np.argmax(self.sigma)
        self.X_temp = self.X_temp + (self.input_space[most_uncertain],)

    def argmax_mean(self):
        """Record the grid point with the largest posterior mean as the recommendation."""
        best = np.argmax(self.mu)
        self.X = self.X + (self.input_space[best],)

    def sample_y(self, epsilon):
        """Observe ``f`` at the latest exploration point, adding ``epsilon``."""
        observation = self.f(self.x_temp) + epsilon
        self.Y = self.Y + (observation,)

    def bayesian_update(self):
        """Refit the GP on (X_temp, Y) and refresh the posterior over the whole grid."""
        fitted = self.gp.fit(self.X_temp, self.Y)
        self.gp = fitted
        self.mu, self.sigma = fitted.predict(self.input_space, return_std=True)


class GP_PI(UCB):
    r"""GP-PI: sample the grid point with the largest probability of improvement.

    PI(x) = Phi((mu(x) - mu_cross - zeta) / sigma(x)). Here ``mu_cross`` is the
    largest posterior mean over the already-sampled points (0 before any
    sample), and ``zeta`` is the margin given as the ``zeta`` keyword.
    Typical loop::

        for n in range(n_iter):
            # \epsilon ~ N(0, sigma_noise^2)
            epsilon = np.random.randn() * sigma_noise

            gp_pi.argmax_ucb()

            # y_n = f(x_n) + epsilon
            gp_pi.sample_y(epsilon)

            # Bayesian update on mu and sigma
            gp_pi.bayesian_update()
    """
    def __init__(self, *args, **kwargs):
        UCB.__init__(self, *args, **kwargs)
        self.zeta = kwargs.get('zeta')

    def _find_mu_cross(self):
        """Return the largest posterior mean over the sampled points in ``self.X``, or 0 if none.

        Each sampled point is located in ``input_space`` by exact row equality.
        """
        find_x_n = lambda x: np.where(np.all(self.input_space == x, axis=1))[0].squeeze()
        return max(tuple(self.mu[find_x_n(x)] for x in self.X)) if len(self.X) > 0 else 0

    def argmax_ucb(self):
        """Append to ``self.X`` the grid point maximising the probability of improvement.

        Where sigma == 0 the ratio d / sigma is undefined (0/0 gives NaN, and
        np.argmax returns the first NaN). The deterministic limit 1[d > 0] is
        used instead.
        """
        mu_cross = self._find_mu_cross()
        d = self.mu - mu_cross - self.zeta
        has_sigma = self.sigma > 0
        # Divide by 1 where sigma is zero so no NaN/inf is produced; those entries are replaced below.
        safe_sigma = np.where(has_sigma, self.sigma, 1.0)
        pi = np.where(has_sigma, norm.cdf(d / safe_sigma), (d > 0).astype(float))
        x_n = np.argmax(pi)
        self.X += (self.input_space[x_n],)

    def sample_y(self, epsilon):
        """Record the noisy observation ``f(x_n) + epsilon`` in ``self.Y``."""
        y_n = self.f(self.x_n) + epsilon
        self.Y += (y_n,)

    def bayesian_update(self):
        """Refit the GP on all (X, Y) pairs and refresh the posterior mean/std over the grid."""
        self.gp = self.gp.fit(self.X, self.Y)
        self.mu, self.sigma = self.gp.predict(self.input_space, return_std=True)


class GP_EI(UCB):
    r"""GP-EI: sample the grid point with the largest expected improvement.

    With d = mu - mu_cross - zeta and z = d / sigma,
    EI = d * Phi(z) + sigma * phi(z). Here ``mu_cross`` is the largest posterior
    mean over already-sampled points (0 before any sample), and ``zeta`` is the
    ``zeta`` keyword. Typical loop::

        for n in range(n_iter):
            # \epsilon ~ N(0, sigma_noise^2)
            epsilon = np.random.randn() * sigma_noise

            gp_ei.argmax_ucb()

            # y_n = f(x_n) + epsilon
            gp_ei.sample_y(epsilon)

            # Bayesian update on mu and sigma
            gp_ei.bayesian_update()
    """
    def __init__(self, *args, **kwargs):
        UCB.__init__(self, *args, **kwargs)
        self.zeta = kwargs.get('zeta')

    def _find_mu_cross(self):
        """Return the largest posterior mean over the sampled points in ``self.X``, or 0 if none.

        Each sampled point is located in ``input_space`` by exact row equality.
        """
        find_x_n = lambda x: np.where(np.all(self.input_space == x, axis=1))[0].squeeze()
        return max(tuple(self.mu[find_x_n(x)] for x in self.X)) if len(self.X) > 0 else 0

    def argmax_ucb(self):
        """Append to ``self.X`` the grid point maximising expected improvement.

        Only points with non-zero sigma are scored. If every sigma is zero,
        the point with the largest d is chosen (EI's zero-variance limit is
        max(d, 0)).
        """
        mu_cross = self._find_mu_cross()
        d = self.mu - mu_cross - self.zeta

        # Indices (into the full grid) where sigma is non-zero.
        ind_nz = np.where(self.sigma != 0)[0]
        if ind_nz.size == 0:
            x_n = np.argmax(d)
        else:
            z = d[ind_nz] / self.sigma[ind_nz]
            ei = d[ind_nz] * norm.cdf(z) + self.sigma[ind_nz] * norm.pdf(z)
            # argmax indexes the filtered array; map it back to a grid index.
            x_n = ind_nz[np.argmax(ei)]
        self.X += (self.input_space[x_n],)

    def sample_y(self, epsilon):
        """Record the noisy observation ``f(x_n) + epsilon`` in ``self.Y``."""
        y_n = self.f(self.x_n) + epsilon
        self.Y += (y_n,)

    def bayesian_update(self):
        """Refit the GP on all (X, Y) pairs and refresh the posterior mean/std over the grid."""
        self.gp = self.gp.fit(self.X, self.Y)
        self.mu, self.sigma = self.gp.predict(self.input_space, return_std=True)


def run_GP_UCB(func, n_iter, beta, ker, sigma_noise, plot=False, nstep=10, mu_prior=None, sigma_prior=None, coords=None):
    """Run GP-UCB for ``n_iter`` iterations on the grid spanned by ``coords``.

    Args:
        func: objective, evaluated at grid points.
        n_iter: number of sampling iterations.
        beta: callable ``beta(n)`` giving the exploration weight at iteration n (1-based).
        ker: kernel passed to the GaussianProcessRegressor.
        sigma_noise: standard deviation of the Gaussian observation noise.
        plot: if True, draw posterior snapshots and the simple regret.
        nstep: stride between plotted snapshots.
        mu_prior, sigma_prior: prior values replicated over the grid.
        coords: 1-D coordinate arrays defining the grid.

    Returns:
        (gp_ucb, xstar, f_xstar): the optimiser after the run, plus the grid
        maximiser of ``func`` and its value, found by exhaustive evaluation.
    """
    gp_ucb = GP_UCB(func, kernel=ker, mu=mu_prior, sigma=sigma_prior, sigma_noise=sigma_noise, beta=beta, coords=coords)

    # Exhaustive search over the grid for the true optimum (reference for the regret).
    xstar, f_xstar = max(((x, gp_ucb.f(x)) for x in gp_ucb.input_space), key=lambda item: item[1])

    for i in range(n_iter):
        # \epsilon ~ N(0, sigma_noise^2): zero-mean Gaussian noise, consistent
        # with the GP's alpha = sigma_noise ** 2.
        epsilon = np.random.randn() * sigma_noise

        # x_n = argmax mu + beta_n*sigma
        beta_i = gp_ucb.beta(i + 1)
        gp_ucb.argmax_ucb(beta_i)

        # y_n = f(x_n) + epsilon
        gp_ucb.sample_y(epsilon)

        # Bayesian update on mu and sigma
        gp_ucb.bayesian_update()

        # Save data for plot
        gp_ucb.save_plot()

    if plot:
        gp_ucb.plot(rows=3, step=nstep)
        gp_ucb.plot_simple_regret(f_xstar)
        plt.show()

    return gp_ucb, xstar, f_xstar


def run_MVR(func, n_iter, beta, ker, sigma_noise, plot=False, nstep=10, mu_prior=None, sigma_prior=None, coords=None):
    """Run Maximum Variance Reduction for ``n_iter`` iterations on the ``coords`` grid.

    Args:
        func: objective, evaluated at grid points.
        n_iter: number of sampling iterations.
        beta: accepted for signature parity with the other runners; not used by MVR.
        ker: kernel passed to the GaussianProcessRegressor.
        sigma_noise: standard deviation of the Gaussian observation noise.
        plot: if True, draw posterior snapshots and the simple regret.
        nstep: stride between plotted snapshots.
        mu_prior, sigma_prior: prior values replicated over the grid.
        coords: 1-D coordinate arrays defining the grid.

    Returns:
        (mvr, xstar, f_xstar): the optimiser after the run, plus the grid
        maximiser of ``func`` and its value, found by exhaustive evaluation.
    """
    mvr = MVR(func, kernel=ker, mu=mu_prior, sigma=sigma_prior, sigma_noise=sigma_noise, coords=coords)

    # Exhaustive search over the grid for the true optimum (reference for the regret).
    xstar, f_xstar = max(((x, mvr.f(x)) for x in mvr.input_space), key=lambda item: item[1])

    for _ in range(n_iter):
        # \epsilon ~ N(0, sigma_noise^2): zero-mean Gaussian noise, consistent
        # with the GP's alpha = sigma_noise ** 2.
        epsilon = np.random.randn() * sigma_noise

        # x_temp_n = argmax sigma
        mvr.argmax_sigma()

        # y_n = f(x_temp_n) + epsilon_n
        mvr.sample_y(epsilon)

        # Bayesian update on mu and sigma
        mvr.bayesian_update()

        # Theoretically only needed once after the loop (like the final mean
        # update), but done every iteration to accumulate the x_hat history.
        # x_hat_n = argmax mu
        mvr.argmax_mean()

        # Save data for plot
        mvr.save_plot()

    if plot:
        mvr.plot(rows=3, step=nstep)
        mvr.plot_simple_regret(f_xstar)
        plt.show()

    return mvr, xstar, f_xstar


def run_GP_PI(func, n_iter, zeta, ker, sigma_noise, plot=False, nstep=10, mu_prior=None, sigma_prior=None, coords=None):
    """Run GP probability-of-improvement for ``n_iter`` iterations on the ``coords`` grid.

    Args:
        func: objective, evaluated at grid points.
        n_iter: number of sampling iterations.
        zeta: improvement margin forwarded to GP_PI.
        ker: kernel passed to the GaussianProcessRegressor.
        sigma_noise: standard deviation of the Gaussian observation noise.
        plot: if True, draw posterior snapshots and the simple regret.
        nstep: stride between plotted snapshots.
        mu_prior, sigma_prior: prior values replicated over the grid.
        coords: 1-D coordinate arrays defining the grid.

    Returns:
        (gp_pi, xstar, f_xstar): the optimiser after the run, plus the grid
        maximiser of ``func`` and its value, found by exhaustive evaluation.
    """
    gp_pi = GP_PI(func, kernel=ker, mu=mu_prior, sigma=sigma_prior, sigma_noise=sigma_noise, zeta=zeta, coords=coords)

    # Exhaustive search over the grid for the true optimum (reference for the regret).
    xstar, f_xstar = max(((x, gp_pi.f(x)) for x in gp_pi.input_space), key=lambda item: item[1])

    for _ in range(n_iter):
        # \epsilon ~ N(0, sigma_noise^2): zero-mean Gaussian noise, consistent
        # with the GP's alpha = sigma_noise ** 2.
        epsilon = np.random.randn() * sigma_noise

        # x_n = argmax probability of improvement
        gp_pi.argmax_ucb()

        # y_n = f(x_n) + epsilon
        gp_pi.sample_y(epsilon)

        # Bayesian update on mu and sigma
        gp_pi.bayesian_update()

        # Save data for plot
        gp_pi.save_plot()

    if plot:
        gp_pi.plot(rows=3, step=nstep)
        gp_pi.plot_simple_regret(f_xstar)
        plt.show()

    return gp_pi, xstar, f_xstar


def run_GP_EI(func, n_iter, zeta, ker, sigma_noise, plot=False, nstep=10, mu_prior=None, sigma_prior=None, coords=None):
    """Run GP expected-improvement for ``n_iter`` iterations on the ``coords`` grid.

    Args:
        func: objective, evaluated at grid points.
        n_iter: number of sampling iterations.
        zeta: improvement margin forwarded to GP_EI.
        ker: kernel passed to the GaussianProcessRegressor.
        sigma_noise: standard deviation of the Gaussian observation noise.
        plot: if True, draw posterior snapshots and the simple regret.
        nstep: stride between plotted snapshots.
        mu_prior, sigma_prior: prior values replicated over the grid.
        coords: 1-D coordinate arrays defining the grid.

    Returns:
        (gp_ei, xstar, f_xstar): the optimiser after the run, plus the grid
        maximiser of ``func`` and its value, found by exhaustive evaluation.
    """
    gp_ei = GP_EI(func, kernel=ker, mu=mu_prior, sigma=sigma_prior, sigma_noise=sigma_noise, zeta=zeta, coords=coords)

    # Exhaustive search over the grid for the true optimum (reference for the regret).
    xstar, f_xstar = max(((x, gp_ei.f(x)) for x in gp_ei.input_space), key=lambda item: item[1])

    for _ in range(n_iter):
        # \epsilon ~ N(0, sigma_noise^2): zero-mean Gaussian noise, consistent
        # with the GP's alpha = sigma_noise ** 2.
        epsilon = np.random.randn() * sigma_noise

        # x_n = argmax expected improvement
        gp_ei.argmax_ucb()

        # y_n = f(x_n) + epsilon
        gp_ei.sample_y(epsilon)

        # Bayesian update on mu and sigma
        gp_ei.bayesian_update()

        # Save data for plot
        gp_ei.save_plot()

    if plot:
        gp_ei.plot(rows=3, step=nstep)
        gp_ei.plot_simple_regret(f_xstar)
        plt.show()

    return gp_ei, xstar, f_xstar
