import numpy as np
from MetaLearnBO.functions import ContinuousObjFcn



import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

import numpy as np
import pandas as pd
from sklearn.gaussian_process.kernels import (RBF, Matern, RationalQuadratic,
                                              ExpSineSquared, DotProduct,
                                              ConstantKernel)
from matplotlib import pyplot as plt


class Simulation(ContinuousObjFcn):
    """Simulated objective built from one "canonical rat" GP model.

    On construction, the pre-fitted model ``allRats[ratNum]`` is sampled twice
    on a 100-point grid over [0, 4]. The average of those two draws is refit
    with a periodic (ExpSineSquared) GP whose noise level ``alpha`` is chosen
    per rat. Calling the instance returns ``x * f(x)``, where ``f`` is drawn
    from that refit GP.
    """

    def __init__(self, domain, allRats, ratNum=0, seed=None):
        """Build the simulated objective.

        Args:
            domain: Passed to ``ContinuousObjFcn``; ``len(domain[0])`` is stored
                as ``self.dim``.
            allRats: Sequence of fitted GP models exposing ``sample_y``
                (presumably ``CanonicalRats.getCanonicalRat()`` — confirm).
            ratNum: Index into ``allRats`` and into the six-entry per-rat
                ``alphas`` list below, so it must be in 0-5.
            seed: ``random_state`` used when sampling ``allRats[ratNum]``.
        """
        super(Simulation, self).__init__(domain)
        self.dim = len(domain[0])

        # Ground truth: average of two draws from the chosen rat's model.
        cRat = allRats[ratNum]
        x = np.atleast_2d(np.linspace(0, 4, 100)).T
        y_samples = np.array(cRat.sample_y(x, n_samples=2, random_state=seed))
        y_avg = np.mean(y_samples, axis=1)

        # Refit the averaged curve with a per-rat noise level (alpha).
        alphas = [.55, .09, .05, .05, .2, .35]
        kernel = 1.0 * ExpSineSquared(length_scale=1.0, periodicity=3.0,
                                      length_scale_bounds=(0.1, 10.0),
                                      periodicity_bounds=(1.0, 10.0))
        gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9,
                                      alpha=alphas[ratNum], random_state=3)
        gp.fit(x, y_avg)

        self.gp = gp
        self.x_groundTruth = x
        self.y_groundTruth = y_avg

        # Best parameter: the grid point maximising x * (predicted mean) on a
        # 500-point grid over [0, 4].
        x_fine = np.atleast_2d(np.linspace(0, 4, 500)).T
        y_fine = gp.predict(x_fine)
        self.bestP = x_fine[np.argmax(x_fine.ravel() * y_fine)]
        print("BESTP", self.bestP)

    def get_RatModel(self):
        """Return the refit GaussianProcessRegressor."""
        return self.gp

    def __call__(self, x):
        """Return ``x * f(x)`` for a draw ``f`` from the refit GP.

        ``x`` is reshaped into a single column, so this assumes a 1-D domain.
        The element-wise product is squeezed, giving a 0-d array for a single
        point.
        """
        # NOTE(review): GaussianProcessRegressor.sample_y defaults to
        # random_state=0, so identical inputs yield identical draws on every
        # call — confirm this determinism is intended.
        out = self.gp.sample_y(x.reshape(-1, 1))
        return np.multiply(x.reshape(-1), out.reshape(-1)).squeeze()

    def predictRat(self, action):
        """Return the refit GP's predicted mean at ``action``."""
        out = self.gp.predict(action)

        return out

    def updateState(self, state):
        """Store ``state`` as ``self.currentState``."""
        self.currentState = state

class CanonicalRats:
    """Fit one periodic-kernel GP model per "canonical rat" from a CSV file.

    The CSV is read for an ``amplitude`` column (model input) and a
    ``discriminant_score_b`` column (model target). Fixed row ranges of the
    table are assigned to six rats. Each rat is fitted with its own ``alpha``
    and ``random_state``.
    """

    # (start, stop, alpha, random_state) per rat, in order rat1..rat6.
    # NOTE(review): these half-open row ranges skip rows 25, 50, 74, 121, 148
    # and 149. This may be an inclusive/exclusive slip; confirm against the
    # data layout before changing it.
    _RAT_SPECS = (
        (0, 25, .45, 2),
        (26, 50, .09, 3),
        (51, 74, .05, 3),
        (75, 121, .05, 3),
        (122, 148, .2, 3),
        (150, 174, .35, 6),
    )

    # Raw string: the Windows path contains backslashes that are not valid
    # escape sequences in a normal string literal.
    def __init__(self, csv_path=r'E:\ResearchGatech\RNS-Meta\Data\Data_Memory.csv'):
        """Load ``csv_path`` and fit the six models ``rat1`` ... ``rat6``.

        Args:
            csv_path: Path to the data CSV. Defaults to the previously
                hard-coded location.
        """
        data = pd.read_csv(csv_path, sep=',')
        X = np.array(data['amplitude'])
        y = np.array(data['discriminant_score_b'])

        (self.rat1, self.rat2, self.rat3,
         self.rat4, self.rat5, self.rat6) = [
            self._fit_rat(X, y, *spec) for spec in self._RAT_SPECS]

    @staticmethod
    def _fit_rat(X, y, start, stop, alpha, random_state):
        """Fit a periodic-kernel GP on rows ``start:stop`` of ``X``/``y``."""
        kernel = 1.0 * ExpSineSquared(length_scale=1.0, periodicity=3.0,
                                      length_scale_bounds=(0.1, 10.0),
                                      periodicity_bounds=(1.0, 10.0))
        gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9,
                                      alpha=alpha, random_state=random_state)
        gp.fit(X[start:stop].reshape(-1, 1), y[start:stop].reshape(-1))
        return gp

    def getCanonicalRat(self):
        """Return the six fitted models as ``[rat1, ..., rat6]``."""
        return [self.rat1, self.rat2, self.rat3, self.rat4, self.rat5, self.rat6]


# Synthetic objective function for testing the continuous case
class SynthObjFcn(ContinuousObjFcn):
    """Random-feature synthetic objective for testing the continuous case.

    f(x) = sqrt(2 / n_features) * sum_k w_k * bfunc(b_k . x + c_k) + noise.
    The terms are drawn as follows:
      - frequencies b ~ N(0, wiggliness^2) and phases c ~ U(-pi, pi);
      - weights w ~ N(mu, Sigma);
      - observation noise is Gaussian with standard deviation ``noise_sigma``.
    Every draw comes from one ``RandomState`` seeded with ``rnd_seed``. The
    seed therefore fixes the sampled function, while evaluation noise keeps
    advancing the same stream.
    """

    def __init__(self, domain, rnd_seed, n_features, bfunc=np.cos,
                 noise_sigma=0.1, wiggliness=1.):
        """Create the objective and immediately sample its random features.

        Args:
            domain: Bounds passed to ``ContinuousObjFcn``; ``len(domain[0])``
                is the input dimension.
            rnd_seed: Seed of the internal ``RandomState``.
            n_features: Number of random basis functions.
            bfunc: Element-wise basis function (default ``np.cos``).
            noise_sigma: Standard deviation of the additive observation noise.
            wiggliness: Scale of the random frequencies; higher is more wiggly.
        """
        super(SynthObjFcn, self).__init__(domain)
        stream = np.random.RandomState(rnd_seed)
        self.rnd_stream = stream
        self.n_features = n_features
        # A A^T is symmetric positive semi-definite, so it is a valid covariance.
        root = stream.rand(n_features, 10 * n_features)
        self.Sigma = root.dot(root.T)
        self.mu = stream.rand(n_features)
        self.bfunc = bfunc
        self.noise_sigma = noise_sigma
        self.dim = len(domain[0])
        self.wiggliness = wiggliness
        self.sample_func()

    def phi(self, x):
        """Return the (n_features, N) feature matrix for inputs x of shape (N, d)."""
        # c has shape (n_features, 1) and broadcasts across the N columns.
        projections = self.b.dot(x.T) + self.c
        return np.sqrt(2.0 / self.n_features) * self.bfunc(projections)

    def sample_func(self):
        """Draw new frequencies ``b``, phases ``c`` and weights ``w`` from the stream."""
        n = self.n_features
        self.b = self.wiggliness * self.rnd_stream.randn(n, self.dim)
        self.c = self.rnd_stream.uniform(-np.pi, np.pi, (n, 1))
        self.w = self.rnd_stream.multivariate_normal(self.mu, self.Sigma)

    def __call__(self, x):
        """Evaluate the noisy objective at x of shape (N, d) or one point of shape (d,)."""
        points = x if x.ndim != 1 else x[None, :]
        clean = self.phi(points).T.dot(self.w)
        noise = self.rnd_stream.normal(0, self.noise_sigma, len(points))
        return np.asarray(clean + noise).squeeze()

def test():
    """Visual smoke test: plot a 1-D SynthObjFcn sample over [-10, 10]."""
    syn_fcn = SynthObjFcn([[-10.], [10.]], 7, 1000,
                          noise_sigma=0.01, wiggliness=2.)
    grid = np.linspace(-10, 10, 500).reshape(-1, 1)
    values = syn_fcn(grid)
    import matplotlib.pyplot as plt
    plt.plot(grid[:, 0], values)
    plt.show()

def test2d():
    """Visual smoke test: plot a 2-D SynthObjFcn sample as a 3-D surface."""
    import matplotlib.pyplot as plt
    from matplotlib import cm
    # Importing mplot3d registers the '3d' projection on older Matplotlib versions.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    domain = [[-10., -10.], [10., 10.]]
    syn_fcn = SynthObjFcn(domain, 12, 1000, noise_sigma=0.01)
    gx, gy = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))
    z = syn_fcn(np.column_stack([gx.ravel(), gy.ravel()]))

    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure in recent Matplotlib
    # releases, which leaves a blank window; add_subplot always attaches.
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(gx, gy, z.reshape(gx.shape), cmap=cm.coolwarm, linewidth=0)
    plt.show()

