import numpy as np
from sklearn.preprocessing import MinMaxScaler
import random
from RatBrainModel import RatEst
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel
from sklearn.gaussian_process.kernels import (RBF, Matern, RationalQuadratic,
                                              ExpSineSquared, DotProduct,
                                              ConstantKernel)
from matplotlib import pyplot as plt

import pandas as pd

from scipy.stats import norm

from scipy.optimize import minimize

#from bayesian_optimization_util import plot_approximation, plot_acquisition


def propose_location(acquisition, X_sample, Y_sample, gpr, bounds, n_restarts=25):
    '''
    Proposes the next sampling point by optimizing the acquisition function.

    The acquisition function is maximized by minimizing its negative with
    L-BFGS-B, started from ``n_restarts`` points drawn uniformly inside
    ``bounds``; the best local optimum over all restarts is returned.

    Args:
        acquisition: Acquisition function, called as
            ``acquisition(X, X_sample, Y_sample, gpr)`` with ``X`` of shape
            (1 x d); it may return a scalar or a size-1 array.
        X_sample: Sample locations (n x d).
        Y_sample: Sample values (n x 1).
        gpr: A GaussianProcessRegressor fitted to samples.
        bounds: Search box (d x 2); column 0 holds lower bounds, column 1
            upper bounds.
        n_restarts: Number of random starting points for the optimizer.

    Returns:
        Location of the acquisition function maximum, shape (1 x d)
        (one row per point, so it can be stacked onto ``X_sample``).

    Raises:
        RuntimeError: If no restart produced a finite objective value.
    '''
    bounds = np.asarray(bounds, dtype=float)
    dim = X_sample.shape[1]
    best_val = np.inf
    best_x = None

    def min_obj(X):
        # Minimizing the negative acquisition maximizes the acquisition.
        # Reduce to a plain float: acquisition functions such as
        # expected_improvement return a (1 x 1) array, and how SciPy reports
        # res.fun for non-scalar objectives differs between versions.
        value = acquisition(X.reshape(-1, dim), X_sample, Y_sample, gpr)
        return -float(np.ravel(value)[0])

    # Find the best optimum by starting from n_restarts different random points.
    for x0 in np.random.uniform(bounds[:, 0], bounds[:, 1], size=(n_restarts, dim)):
        res = minimize(min_obj, x0=x0, bounds=bounds, method='L-BFGS-B')
        # res.fun may be a scalar or an array depending on the SciPy version.
        fun = float(np.ravel(res.fun)[0])
        # NaN never compares less than best_val, so failed restarts are skipped.
        if fun < best_val:
            best_val = fun
            best_x = res.x

    if best_x is None:
        raise RuntimeError(
            'propose_location: acquisition optimization produced no finite value')

    return best_x.reshape(1, -1)


def expected_improvement(X, X_sample, Y_sample, gpr, xi=0.01):
    '''
    Computes the EI at points X based on existing samples X_sample
    and Y_sample using a Gaussian process surrogate model.

    EI(x) = (mu(x) - f* - xi) * Phi(Z) + sigma(x) * phi(Z),
    Z = (mu(x) - f* - xi) / sigma(x), where f* is the largest GP mean
    over the existing sample locations. EI is defined as 0 wherever
    sigma(x) == 0.

    Args:
        X: Points at which EI shall be computed (m x d).
        X_sample: Sample locations (n x d).
        Y_sample: Sample values (n x 1). Not used in the computation (the
            incumbent is taken from the GP mean instead); kept so all
            acquisition functions share one signature.
        gpr: A GaussianProcessRegressor fitted to samples.
        xi: Exploitation-exploration trade-off parameter.

    Returns:
        Expected improvements at points X, shape (m x 1).
    '''
    mu, sigma = gpr.predict(X, return_std=True)
    mu_sample = gpr.predict(X_sample)

    # Force both to column vectors. Newer scikit-learn versions return 1-D
    # predictions even when fitted on (n x 1) targets; a (m,) mean mixed with
    # an (m x 1) std would silently broadcast EI to an (m x m) matrix.
    mu = np.asarray(mu, dtype=float).reshape(-1, 1)
    sigma = np.asarray(sigma, dtype=float).reshape(-1, 1)

    # Needed for noise-based model,
    # otherwise use np.max(Y_sample).
    # See also section 2.4 in [...]
    mu_sample_opt = np.max(mu_sample)

    # sigma == 0 yields x/0 (divide) or 0/0 (invalid); those entries are
    # overwritten with 0 below, so the warnings are pure noise.
    with np.errstate(divide='ignore', invalid='ignore'):
        imp = mu - mu_sample_opt - xi
        Z = imp / sigma
        ei = imp * norm.cdf(Z) + sigma * norm.pdf(Z)
    ei[sigma == 0.0] = 0.0

    return ei





# ---------------------------------------------------------------------------
# "True" model: fit a GP to one experimental subset of the memory data set
# and plot its mean with a 95% band.
# ---------------------------------------------------------------------------
file = 'E:/ResearchGatech/RNS-Meta/Data/Data_Memory.csv'
data = pd.read_csv(file, sep=',')

# Row range of the experimental subset to model. One slice object is used for
# BOTH inputs and targets so they cannot drift out of alignment when the
# subset is switched. Previously used alternatives (the trailing number is the
# annotation the original author attached to each; its meaning is unconfirmed):
#   slice(0, 25)     # .45
#   slice(26, 50)    # .09
#   slice(51, 74)    # .05
#   slice(75, 121)   # .05
#   slice(122, 148)  # .2
#   slice(150, 174)  # .35
# NOTE(review): slice ends are exclusive, so these ranges skip rows 25, 50,
# 74, 121 and 149 — confirm whether those rows are meant to be excluded.
subset = slice(75, 121)

X = data['amplitude'].to_numpy()[subset].reshape(-1, 1)
y = data['discriminant_score_b'].to_numpy()[subset].reshape(-1)

kernel = 1.0 * ExpSineSquared(length_scale=1.0, periodicity=3.0,
                              length_scale_bounds=(0.1, 10.0),
                              periodicity_bounds=(1.0, 10.0))
# gpTrue plays the role of the unknown objective; gpEst is the surrogate that
# the optimisation loop refits from samples. sklearn clones the kernel on
# fit, so sharing the same kernel object between them is safe.
gpTrue = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9, alpha=.05)
gpTrue.fit(X, y)
gpEst = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9, alpha=.05)

x = np.atleast_2d(np.linspace(0, 4, 100)).T
y_pred, sigma = gpTrue.predict(x, return_std=True)
plt.plot(x, y_pred, 'b-', label='Prediction')

plt.fill(np.concatenate([x, x[::-1]]),
         np.concatenate([y_pred - 1.9600 * sigma,
                         (y_pred + 1.9600 * sigma)[::-1]]),
         alpha=.5, fc='b', ec='None', label='95% confidence interval', )
plt.legend()
plt.show()


# ---------------------------------------------------------------------------
# Bayesian optimisation: the surrogate gpEst is refit on samples drawn from
# gpTrue at locations chosen by maximising expected improvement.
# ---------------------------------------------------------------------------
n_iter = 40

num_burnin = 1
bounds = np.array([[0, 4.0]])

# One shared RNG for every draw from gpTrue. GaussianProcessRegressor.sample_y
# defaults to random_state=0, which reseeds on each call, so every "noisy"
# observation would otherwise reuse the exact same standard-normal draw.
rng = np.random.RandomState(0)

# Dense grid over the search interval, used only for plotting (loop-invariant).
x = np.atleast_2d(np.linspace(bounds[0, 0], bounds[0, 1], 100)).T
x_fill = np.concatenate([x.ravel(), x.ravel()[::-1]])

X_sample = np.atleast_2d(np.linspace(bounds[0, 0], bounds[0, 1], num_burnin)).T
Y_sample = gpTrue.sample_y(X_sample, random_state=rng).reshape(-1, 1)

for i in range(n_iter):
    # Update Gaussian process with existing samples
    gpEst.fit(X_sample, Y_sample)

    # Obtain next sampling point from the acquisition function (expected_improvement);
    # force a single row so it can be stacked onto X_sample.
    X_next = propose_location(expected_improvement, X_sample, Y_sample, gpEst, bounds)
    X_next = X_next.reshape(1, -1)

    # Obtain next noisy sample from the objective function
    Y_next = gpTrue.sample_y(X_next, random_state=rng).reshape(-1, 1)
    print(f'Iteration {i + 1}: X_next = {X_next.ravel()}')

    # Plot the surrogate mean, its 95% band and the samples used to fit it.
    plt.figure()
    y_pred, sigma = gpEst.predict(x, return_std=True)
    # gpEst is fit on (n x 1) targets; depending on the sklearn version the
    # mean may come back as (m x 1), which would broadcast against sigma.
    y_pred = np.ravel(y_pred)
    sigma = np.ravel(sigma)
    plt.plot(x, y_pred, 'b-', label='Prediction')
    plt.fill(x_fill,
             np.concatenate([y_pred - 1.9600 * sigma,
                             (y_pred + 1.9600 * sigma)[::-1]]),
             alpha=.5, fc='b', ec='None', label='95% confidence interval', )
    plt.scatter(X_sample, Y_sample)
    plt.title(f'Iteration {i + 1}')
    plt.show()

    # Add sample to previous samples
    X_sample = np.vstack((X_sample, X_next))
    Y_sample = np.vstack((Y_sample, Y_next))




"""
gp.fit(X, y)

x = np.atleast_2d(np.linspace(0,4 , 10)).T

y_pred, sigma = gp.predict(x, return_std=True)

###############Estimation Model
num_samples=5
X_samples=np.random.uniform(low=0.0, high=4.0, size=(num_samples,1))
Y_samples=gp.sample_y(X_samples)
print(x)
print(Y_samples)
"""