import numpy as np
import GPy
from helper_funcs_bts_red import UtilityFunction, acq_max
import pickle
import itertools
import time

class BTS_RED(object):
    """Batch Thompson-sampling optimiser with a per-query replication count.

    Each round samples functions from a random-Fourier-feature approximation
    of the GP posterior, maximises every sample with ``acq_max`` to obtain a
    query point, and assigns that point a replication count ``n_t``.  Queries
    are collected until their ``n_t`` values fill the batch budget
    (``batch_size``); a query that overshoots the budget is carried over to
    the next round.
    """

    def __init__(self, f, pbounds, gp_opt_schedule, ARD=False, \
                 log_file=None, M_TS=100, verbose=1, n_min=3, n_max=50, noise_var_func=None, domain=None, \
                 batch_size=20, R2=0.02, beta_t=None, \
                 fix_nt_flag=False, fix_nt_value=3, \
                 use_init=False, save_init=False, save_init_file=None, \
                 T=50, fix_nt_init=10):
        """Store the configuration and create empty result containers.

        Args:
            f: objective, called as ``f(x, n_t)``; must return the triple
                ``(y, f_value, noise_var_value)``.
            pbounds: dict mapping a parameter name to its ``(low, high)``
                bounds; the dict order defines the input dimensions.
            gp_opt_schedule: GP hyper-parameters are re-optimised whenever
                the number of stored inputs is a positive multiple of this.
            ARD: forwarded to the RBF kernel constructor.
            log_file: if not None, ``self.res`` is pickled to this path after
                every iteration of ``maximize``.
            M_TS: number of random features used per posterior sample.
            verbose: stored on the instance; not consulted by this class.
            n_min, n_max: clipping range for ``n_t``.  ``n_max`` is
                overwritten during ``maximize`` when ``fix_nt_flag`` is off.
            noise_var_func: indexable of noise variances, indexed by the
                position of the ``domain`` entry closest to the query.
            domain: array compared against a query point to find that
                position (assumed broadcastable against a query -- TODO
                confirm against the caller).
            batch_size: total ``n_t`` budget of one batch.
            R2: divisor turning a noise variance into a replication count.
            beta_t: indexable of scaling factors, indexed by
                ``len(X) - init_points + 1``.
            fix_nt_flag: if true, every query uses ``fix_nt_value`` and no
                query is carried over between batches.
            fix_nt_value: the fixed ``n_t`` used when ``fix_nt_flag`` is set.
            use_init: path of a pickle holding a dict with key ``"X"`` to use
                as initial inputs; ``False``/``None`` means "draw randomly".
            save_init, save_init_file: stored on the instance; not consulted
                by this class.
            T: enters the observation-noise term ``(1 + 2 / T) * beta**2``.
            fix_nt_init: ``n_t`` used for every initial point.
        """
        # Run configuration.
        self.fix_nt_init = fix_nt_init
        self.T = T

        self.use_init = use_init
        self.save_init = save_init
        self.save_init_file = save_init_file

        self.fix_nt_flag = fix_nt_flag
        self.fix_nt_value = fix_nt_value

        self.beta_t = beta_t
        self.n_min = n_min
        self.n_max = n_max
        self.noise_var_func = noise_var_func
        self.domain = domain
        self.batch_size = batch_size
        self.R2 = R2

        self.M_TS = M_TS
        self.ARD = ARD
        self.log_file = log_file
        self.pbounds = pbounds
        self.incumbent = None

        # Search space: one (low, high) row per parameter, in dict order.
        self.keys = list(pbounds.keys())
        self.dim = len(pbounds)
        self.bounds = np.asarray([pbounds[key] for key in self.keys])

        self.f = f

        # Initial-design state.
        self.initialized = False
        self.init_points = []
        self.x_init = []
        self.y_init = []

        # Observed inputs / outputs; replaced by the initial design later.
        self.X = np.array([]).reshape(-1, 1)
        self.Y = np.array([])

        self.i = 0

        self.gp = None
        self.gp_opt_schedule = gp_opt_schedule

        self.util = None

        # Everything recorded during a run; pickled to ``log_file``.
        self.res = {}
        self.res['max'] = {'max_val': None,
                           'max_params': None}
        self.res['all'] = {'values':[], 'params':[], 'init_values':[], 'init_params':[], 'init':[], \
                          'f_values':[], 'init_f_values':[], 'noise_var_values':[], 'init_noise_var_values':[], \
                          'n_t':[], 'init_n_t':[], \
                          'values_batch':[], 'f_values_batch':[], 'noise_var_values_batch':[], \
                          'incumbent_x':[]}

        self.verbose = verbose

    def init(self, init_points, existing_init_x):
        """Evaluate the initial design and seed ``self.X`` / ``self.Y``.

        Args:
            init_points: number of random points to draw when
                ``existing_init_x`` is None.
            existing_init_x: iterable of input points to use instead of
                random ones, or None.
        """
        if existing_init_x is None:
            # One uniform draw per dimension (kept in this order so a seeded
            # run reproduces the same design), then transposed into points.
            per_dim = [np.random.uniform(low_high[0], low_high[1], size=init_points)
                       for low_high in self.bounds]
            self.init_points += list(map(list, zip(*per_dim)))
        else:
            self.init_points = existing_init_x

        n_t = self.fix_nt_init
        records = self.res['all']
        y_init = []
        for x in self.init_points:
            y, f_value, noise_var_value = self.f(x, n_t)

            y_init.append(y)
            records['init_values'].append(y)
            records['init_f_values'].append(f_value)
            records['init_noise_var_values'].append(noise_var_value)
            records['init_n_t'].append(n_t)
            records['init_params'].append(dict(zip(self.keys, x)))

        self.X = np.asarray(self.init_points)
        self.Y = np.asarray(y_init)

        self.incumbent = np.max(y_init)
        self.initialized = True

        records['init'] = {"X": self.X, "Y": self.Y}

    def _sample_query(self, init_points):
        """Draw one posterior sample via random features and return its maximiser."""
        M_target = self.M_TS
        ls_target = self.gp["rbf.lengthscale"][0]
        v_kernel = self.gp["rbf.variance"][0]

        beta = self.beta_t[self.X.shape[0] - init_points + 1]
        obs_noise = (1 + 2 / self.T) * (beta ** 2)

        # Spectral frequencies of the RBF kernel.  Best effort: if the
        # Gaussian draw fails, fall back to uniform frequencies.
        try:
            s = np.random.multivariate_normal(np.zeros(self.dim), 1 / (ls_target**2) * np.identity(self.dim), M_target)
        except Exception:
            s = np.random.rand(M_target, self.dim) - 0.5
        b = np.random.uniform(0, 2 * np.pi, M_target)

        random_features_target = {"M":M_target, "length_scale":ls_target, "s":s, "b":b, "obs_noise":obs_noise, "v_kernel":v_kernel}

        # Feature matrix, one row per observed input: unit-normalised cosine
        # features scaled by sqrt(kernel variance) and by beta.
        Phi = np.sqrt(2 / M_target) * np.cos(np.dot(self.X, s.T) + b)
        Phi = Phi / np.linalg.norm(Phi, axis=1, keepdims=True)
        Phi = np.sqrt(v_kernel) * Phi * beta

        Sigma_t = np.dot(Phi.T, Phi) + obs_noise * np.identity(M_target)
        Sigma_t_inv = np.linalg.inv(Sigma_t)
        nu_t = np.dot(np.dot(Sigma_t_inv, Phi.T), self.Y.reshape(-1, 1))

        # Sample feature weights; same best-effort fallback as above.
        try:
            w_sample = np.random.multivariate_normal(np.squeeze(nu_t), obs_noise * Sigma_t_inv, 1)
        except Exception:
            w_sample = np.random.rand(1, M_target) - 0.5

        return acq_max(ac=self.util_ts.utility, M=M_target, random_features=random_features_target, \
                       w_sample=w_sample, bounds=self.bounds)

    def _choose_n_t(self, x_max, n_iter):
        """Return the replication count for ``x_max`` (updates ``self.n_max``)."""
        if self.fix_nt_flag:
            return self.fix_nt_value

        # Noise variance at the domain entry closest to the query.
        ind = np.argmin(np.abs(self.domain - x_max))
        n_t = int(np.ceil(self.noise_var_func[ind] / self.R2))

        # Cap at half the batch while few non-initial inputs are stored,
        # at the full batch afterwards.
        if (len(self.X) - len(self.init_points)) <= 0.5 * n_iter:
            self.n_max = int(self.batch_size / 2)
        else:
            self.n_max = self.batch_size

        return np.clip(n_t, self.n_min, self.n_max)

    def _select_batch(self, n_iter, init_points, nt_advance, carried_query):
        """Assemble the next batch of ``[x, n_t]`` queries.

        Returns ``(batch, nt_advance, query_advance, last_x)`` where
        ``query_advance`` is the query deferred to the next batch (or None)
        and ``last_x`` is the last point sampled here (or None).
        """
        selected_batch = []
        if carried_query is not None:
            selected_batch.append(carried_query)

        # The carried query is not counted in the sum; instead the budget is
        # reduced by the previous overshoot.
        budget = self.batch_size - nt_advance
        n_t_sum = 0
        last_x = None
        while n_t_sum < budget:
            last_x = self._sample_query(init_points)
            n_t = self._choose_n_t(last_x, n_iter)
            n_t_sum += n_t
            selected_batch.append([last_x, n_t])

        nt_advance = n_t_sum - budget

        if self.fix_nt_flag:  # we don't do the advance thing if we want to fix n_t
            nt_advance = 0
            selected_batch = selected_batch[:-1]

        query_advance = None
        if nt_advance > 0:
            # The last query overshoots the budget: defer it.
            query_advance = selected_batch[-1]
            selected_batch = selected_batch[:-1]

        return selected_batch, nt_advance, query_advance, last_x

    def maximize(self, n_iter=25, init_points=5):
        """Run ``n_iter`` batch iterations.

        Args:
            n_iter: number of batches to evaluate.
            init_points: number of random initial points (ignored for the
                design itself when ``use_init`` names a file); also offsets
                the index into ``beta_t``.

        Raises:
            ValueError: if a selected batch contains no query.
        """
        self.util_ts = UtilityFunction()

        if not self.initialized:
            # ``use_init`` defaults to False, so test truthiness rather than
            # ``!= None``; otherwise the default would be passed to ``open``.
            if self.use_init:
                # NOTE: pickle executes arbitrary code on load -- only point
                # ``use_init`` at trusted files.
                with open(self.use_init, "rb") as init_file:
                    init = pickle.load(init_file)

                self.X = init["X"]
                self.init(init_points, self.X)
            else:
                self.init(init_points, None)

        self.gp = GPy.models.GPRegression(self.X, self.Y.reshape(-1, 1), \
                GPy.kern.RBF(input_dim=self.X.shape[1], lengthscale=1.0, variance=1e-4, ARD=self.ARD))

        # optimize GP hypers
        if init_points > 1:
            self.gp.optimize_restarts(num_restarts = 10, messages=False)
            print("---Optimized hyper: ", self.gp)

        selected_batch, nt_advance, query_advance, x_max = \
            self._select_batch(n_iter, init_points, 0, None)

        print("batch size: ", len(selected_batch))

        records = self.res['all']
        for _ in range(n_iter):
            if not selected_batch:
                raise ValueError("selected batch is empty; check batch_size against "
                                 "fix_nt_value / n_min / n_max")

            # Evaluate every query of the current batch.
            values_batch, f_values_batch, noise_var_values_batch = [], [], []
            for x_max, n_t in selected_batch:
                print("n_t: ", n_t)

                y, f_value, noise_var_value = self.f(x_max, n_t)

                records['f_values'].append(f_value)
                records['noise_var_values'].append(noise_var_value)
                records['n_t'].append(n_t)

                self.Y = np.append(self.Y, y)
                self.X = np.vstack((self.X, x_max.reshape((1, -1))))

                records['values'].append(self.Y[-1])
                records['params'].append(self.X[-1])

                values_batch.append(y)
                f_values_batch.append(f_value)
                noise_var_values_batch.append(noise_var_value)

            records['values_batch'].append(np.max(values_batch))
            records['f_values_batch'].append(np.max(f_values_batch))
            records['noise_var_values_batch'].append(np.min(noise_var_values_batch))

            incumbent_x = self.X[np.argmax(self.Y)]
            records['incumbent_x'].append(incumbent_x)

            self.gp.set_XY(X=self.X, Y=self.Y.reshape(-1, 1))

            if len(self.X) >= self.gp_opt_schedule and len(self.X) % self.gp_opt_schedule == 0:
                self.gp.optimize_restarts(num_restarts = 10, messages=False)
                print("---Optimized hyper: ", self.gp)

            # Choose the next batch, starting with any deferred query.
            selected_batch, nt_advance, query_advance, last_x = \
                self._select_batch(n_iter, init_points, nt_advance, query_advance)
            if last_x is not None:
                x_max = last_x

            print("batch size: ", len(selected_batch))

            print("iter {0} ------ x_t: {1}, y_t: {2}".format(self.i+1, x_max, y))

            self.i += 1

            if self.log_file is not None:
                with open(self.log_file, "wb") as log_handle:
                    pickle.dump(self.res, log_handle)

