import numpy as np
import GPy
from helper_funcs_bts_red_mean_var import UtilityFunction, acq_max
import pickle
import itertools
import time

class BTS_RED_mean_var(object):
    """Batch Thompson-sampling optimiser over a mean/variance objective.

    Two GPs are kept: ``self.gp`` is fit to the observed values ``Y`` and
    ``self.gp_var`` to ``Y_var`` (the second value returned by ``f``). Each
    iteration draws candidate points by Thompson sampling on random-Fourier-
    feature approximations of these GPs (maximised with ``acq_max``) and
    gives every candidate a replication count ``n_t``; candidates are added
    until their ``n_t`` values fill the per-iteration budget ``batch_size``.

    ``f(x, n_t)`` must return ``(y, y_var, f_value, noise_var_value)``.
    NOTE(review): ``Y_var`` looks like a *negated* noise variance, since
    ``-mean`` of its GP is used as a variance upper bound -- confirm with ``f``.
    """

    def __init__(self, f, pbounds, gp_opt_schedule, ARD=False,
                 log_file=None, M_TS=100, verbose=1, n_min=3, n_max=50, noise_var_func=None, domain=None,
                 batch_size=20, R2=0.02, beta_t=None,
                 fix_nt_flag=False, fix_nt_value=3,
                 use_init=False, save_init=False, save_init_file=None,
                 T=50, beta_t_var=None, mean_var_obj=True, omega=0.5,
                 estimate_sigma_max=False, ratio=0.2, fix_nt_init=5):
        """Store the configuration; nothing is evaluated here.

        Args:
            f: objective, called as ``f(x, n_t)``; returns
                ``(y, y_var, f_value, noise_var_value)``.
            pbounds: mapping name -> ``(low, high)``; its key order fixes the
                column order of ``X``.
            gp_opt_schedule: GP hyper-parameters are re-optimised whenever the
                number of observations is a positive multiple of this value.
            ARD: forwarded to both RBF kernels.
            log_file: if given, ``self.res`` is pickled here every iteration.
            M_TS: number of random Fourier features per Thompson sample.
            verbose, noise_var_func: stored only; not used by this class.
            n_min, n_max: clip range of the adaptive replication count.
            domain: iterable of points where ``gp_var`` predictions are
                recorded each iteration (skipped when None).
            batch_size: replication budget (sum of ``n_t``) per iteration.
            R2: ``n_t = ceil(variance upper bound / R2)`` before clipping.
            beta_t, beta_t_var: exploration coefficients for the mean and
                variance GPs, indexed by (evaluations after init + 1).
            fix_nt_flag, fix_nt_value: use a constant ``n_t`` instead.
            use_init: path of a pickle with an ``"X"`` entry holding the
                initial inputs; False/None draws them uniformly in bounds.
                NOTE: pickle can execute code -- only load trusted files.
            save_init, save_init_file: when ``save_init`` is true, the initial
                design ``{"X", "Y", "Y_var"}`` is pickled to ``save_init_file``.
            T: enters the feature-model noise ``(1 + 2/T) * beta**2``.
            mean_var_obj: use the ``ts_mean_var`` utility and pick the
                incumbent by ``omega * Y + (1 - omega) * Y_var``.
            omega: weight of the mean term when ``mean_var_obj`` is set.
            estimate_sigma_max, ratio: if set, ``R2 = ratio * (-min(Y_var))``.
            fix_nt_init: ``n_t`` used for every initial point.
        """
        # Problem definition.
        self.f = f
        self.pbounds = pbounds
        self.keys = list(pbounds.keys())
        self.dim = len(pbounds)
        self.bounds = np.asarray([self.pbounds[key] for key in self.pbounds.keys()])

        # GP / sampling configuration.
        self.gp_opt_schedule = gp_opt_schedule
        self.ARD = ARD
        self.M_TS = M_TS
        self.T = T
        self.beta_t = beta_t
        self.beta_t_var = beta_t_var
        self.mean_var_obj = mean_var_obj
        self.omega = omega

        # Replication (n_t) configuration.
        self.n_min = n_min
        self.n_max = n_max
        self.batch_size = batch_size
        self.R2 = R2
        self.fix_nt_flag = fix_nt_flag
        self.fix_nt_value = fix_nt_value
        self.fix_nt_init = fix_nt_init
        self.estimate_sigma_max = estimate_sigma_max
        self.ratio = ratio

        # I/O and misc.
        self.log_file = log_file
        self.use_init = use_init
        self.save_init = save_init
        self.save_init_file = save_init_file
        self.noise_var_func = noise_var_func
        self.domain = domain
        self.verbose = verbose

        # Optimisation state.
        self.incumbent = None
        self.initialized = False
        self.init_points = []
        self.x_init = []
        self.y_init = []
        self.X = np.array([]).reshape(-1, 1)
        self.Y = np.array([])
        self.Y_var = np.array([])
        self.i = 0
        self.gp = None
        self.gp_var = None
        self.util = None

        self.res = {}
        self.res['max'] = {'max_val': None,
                           'max_params': None}
        self.res['all'] = {'values': [], 'params': [], 'init_values': [], 'init_params': [], 'init': [],
                           'f_values': [], 'init_f_values': [], 'noise_var_values': [], 'init_noise_var_values': [],
                           'n_t': [], 'init_n_t': [],
                           'values_batch': [], 'f_values_batch': [], 'noise_var_values_batch': [],
                           'incumbent_x': [], 'values_var': [], 'init_values_var': [],
                           'gp_var_pred': [], 'estimated_var_R2': []}

    def init(self, init_points, existing_init_x):
        """Evaluate the initial design and store it in ``X``, ``Y``, ``Y_var``.

        Args:
            init_points: number of points drawn uniformly within ``bounds``
                when ``existing_init_x`` is None.
            existing_init_x: initial inputs to reuse, or None.

        Every initial point is evaluated with ``n_t = fix_nt_init``. When
        ``save_init`` is set the design is pickled to ``save_init_file`` in
        the same layout that ``use_init`` reads back.
        """
        if existing_init_x is None:
            columns = [np.random.uniform(x[0], x[1], size=init_points) for x in self.bounds]
            self.init_points += list(map(list, zip(*columns)))
        else:
            self.init_points = existing_init_x

        y_init = []
        y_init_var = []
        for x in self.init_points:
            n_t = self.fix_nt_init
            y, y_var, f_value, noise_var_value = self.f(x, n_t)

            y_init.append(y)
            y_init_var.append(y_var)

            self.res['all']['init_values'].append(y)
            self.res['all']['init_values_var'].append(y_var)
            self.res['all']['init_f_values'].append(f_value)
            self.res['all']['init_noise_var_values'].append(noise_var_value)
            self.res['all']['init_n_t'].append(n_t)
            self.res['all']['init_params'].append(dict(zip(self.keys, x)))

        self.X = np.asarray(self.init_points)
        self.Y = np.asarray(y_init)
        self.Y_var = np.asarray(y_init_var)

        self.incumbent = np.max(y_init)
        self.initialized = True

        init = {"X": self.X, "Y": self.Y, "Y_var": self.Y_var}
        self.res['all']['init'] = init

        if self.save_init and self.save_init_file is not None:
            with open(self.save_init_file, "wb") as fh:
                pickle.dump(init, fh)

    def _load_init_x(self):
        """Return the ``"X"`` entry of the ``use_init`` pickle, or None if unset."""
        # ``use_init`` defaults to False; treating it as a path would call
        # open(False), i.e. read file descriptor 0 (stdin).
        if not self.use_init:
            return None
        # NOTE: pickle.load can execute arbitrary code -- trusted files only.
        with open(self.use_init, "rb") as fh:
            init = pickle.load(fh)
        return init["X"]

    def _beta_index(self):
        """Index into ``beta_t``/``beta_t_var``: evaluations after init, plus one."""
        return self.X.shape[0] - len(self.init_points) + 1

    @staticmethod
    def _random_feature_matrix(X, s, b, v_kernel, scale):
        """Return the ``(n, M)`` random-Fourier-feature matrix for the rows of ``X``.

        Row ``i`` is ``sqrt(2/M) * cos(X[i] @ s.T + b)``, L2-normalised, then
        multiplied by ``sqrt(v_kernel) * scale``.
        """
        M = s.shape[0]
        X = np.asarray(X, dtype=float)
        if X.shape[0] == 0:
            return np.zeros((0, M))
        X = X.reshape(X.shape[0], -1)
        Phi = np.sqrt(2 / M) * np.cos(X @ s.T + b)
        Phi = Phi / np.sqrt(np.sum(Phi * Phi, axis=1, keepdims=True))
        return np.sqrt(v_kernel) * Phi * scale

    def _sample_func(self, gp, beta, targets):
        """Draw one Thompson sample of the feature weights for ``gp``.

        Args:
            gp: model exposing ``gp["rbf.lengthscale"]`` / ``gp["rbf.variance"]``
                (only the first lengthscale entry is used, even with ARD).
            beta: exploration-coefficient sequence for this model.
            targets: observations the weights are regressed on.

        Returns:
            ``(random_features, w_sample)`` where ``w_sample`` has shape ``(1, M_TS)``.
        """
        M_target = self.M_TS
        ls_target = gp["rbf.lengthscale"][0]
        v_kernel = gp["rbf.variance"][0]

        beta_now = beta[self._beta_index()]
        obs_noise = (1 + 2 / self.T) * (beta_now ** 2)

        # Spectral samples of the RBF kernel; fall back to uniform noise if
        # the draw fails (e.g. degenerate lengthscale).
        try:
            s = np.random.multivariate_normal(np.zeros(self.dim), 1 / (ls_target ** 2) * np.identity(self.dim), M_target)
        except Exception:
            s = np.random.rand(M_target, self.dim) - 0.5

        b = np.random.uniform(0, 2 * np.pi, M_target)

        random_features_target = {"M": M_target, "length_scale": ls_target, "s": s, "b": b,
                                  "obs_noise": obs_noise, "v_kernel": v_kernel}
        Phi = self._random_feature_matrix(self.X, s, b, v_kernel, beta_now)

        Sigma_t = Phi.T @ Phi + obs_noise * np.identity(M_target)
        Sigma_t_inv = np.linalg.inv(Sigma_t)
        nu_t = Sigma_t_inv @ Phi.T @ np.asarray(targets).reshape(-1, 1)

        try:
            w_sample = np.random.multivariate_normal(np.squeeze(nu_t), obs_noise * Sigma_t_inv, 1)
        except Exception:
            w_sample = np.random.rand(1, M_target) - 0.5

        return random_features_target, w_sample

    def sample_func_mean(self):
        """Thompson-sample the mean model (``gp`` on ``Y`` with ``beta_t``)."""
        return self._sample_func(self.gp, self.beta_t, self.Y)

    def sample_func_var(self):
        """Thompson-sample the variance model (``gp_var`` on ``Y_var`` with ``beta_t_var``)."""
        # Both the noise term and the feature scaling use beta_t_var; the
        # features were previously scaled by the mean model's beta_t.
        return self._sample_func(self.gp_var, self.beta_t_var, self.Y_var)

    def _maybe_estimate_R2(self):
        """If ``estimate_sigma_max``, set ``R2 = ratio * (-min(Y_var))`` and log it."""
        if self.estimate_sigma_max:
            var_max_estimate = - np.min(self.Y_var)
            self.R2 = self.ratio * var_max_estimate
            self.res['all']['estimated_var_R2'] = [var_max_estimate, self.R2]

    def _new_gp(self, targets):
        """Build a GPy regression model on ``X`` with a fresh RBF kernel."""
        kernel = GPy.kern.RBF(input_dim=self.X.shape[1], lengthscale=1.0, variance=0.1, ARD=self.ARD)
        return GPy.models.GPRegression(self.X, targets.reshape(-1, 1), kernel)

    def _thompson_point(self):
        """Return the maximiser of one Thompson sample of the acquisition."""
        random_features_target, w_sample = self.sample_func_mean()
        if not self.mean_var_obj:
            return acq_max(ac=self.util_ts.utility, M=self.M_TS,
                           random_features=random_features_target, w_sample=w_sample,
                           random_features_var=None, w_sample_var=None,
                           bounds=self.bounds, omega=None)
        random_features_target_var, w_sample_var = self.sample_func_var()
        return acq_max(ac=self.util_ts_mean_var.utility, M=self.M_TS,
                       random_features=random_features_target, w_sample=w_sample,
                       random_features_var=random_features_target_var, w_sample_var=w_sample_var,
                       bounds=self.bounds, omega=self.omega)

    def _choose_n_t(self, x_max, log_bound=False):
        """Return the replication count for ``x_max``.

        With ``fix_nt_flag`` this is ``fix_nt_value``; otherwise
        ``ceil(UB / R2)`` clipped to ``[n_min, n_max]``, where
        ``UB = -mean + beta_t_var * std`` from ``gp_var`` (clipped at 0).
        """
        if self.fix_nt_flag:
            return self.fix_nt_value
        mean, var = self.gp_var.predict(x_max.reshape(1, -1))
        # Same beta index as the samplers (based on len(self.init_points)).
        scaled_std = np.sqrt(var) * self.beta_t_var[self._beta_index()]
        upper_bound_var = -mean + scaled_std
        if log_bound:
            print("-------noise upper bound: upper_bound_var {0}--------".format(upper_bound_var))
            print("-------noise -mean / std (scaled): {0} / {1}--------".format(-mean, scaled_std))
        upper_bound_var = np.clip(upper_bound_var, 0, 1e10)

        n_t = int(np.ceil(upper_bound_var[0][0] / self.R2))
        return np.clip(n_t, self.n_min, self.n_max)

    def _select_batch(self, nt_advance, query_advance, log_bound=False):
        """Draw the next batch of ``[x, n_t]`` queries.

        Candidates are added until their ``n_t`` sum reaches the remaining
        budget ``batch_size - nt_advance``. If the last candidate overshoots,
        it is held back as ``query_advance`` and the overshoot is returned as
        ``nt_advance`` to shrink the next budget. With ``fix_nt_flag`` the
        last candidate is always dropped and nothing is carried over.
        NOTE(review): that drop also happens when the sum hits ``batch_size``
        exactly -- preserved as-is; confirm intent.

        Returns:
            ``(selected_batch, nt_advance, query_advance)``.
        """
        selected_batch = [] if query_advance is None else [query_advance]
        budget = self.batch_size - nt_advance

        n_t_sum = 0
        while n_t_sum < budget:
            x_max = self._thompson_point()
            n_t = self._choose_n_t(x_max, log_bound)
            n_t_sum += n_t
            selected_batch.append([x_max, n_t])

        nt_advance = n_t_sum - budget
        query_advance = None

        if self.fix_nt_flag:
            nt_advance = 0
            selected_batch = selected_batch[:-1]

        if nt_advance > 0:
            query_advance = selected_batch[-1]
            selected_batch = selected_batch[:-1]

        print("batch size: ", len(selected_batch))
        return selected_batch, nt_advance, query_advance

    def _evaluate_batch(self, selected_batch):
        """Evaluate a non-empty batch, append it to the data and ``res``.

        Returns:
            ``(x, y)`` of the last evaluated query.
        """
        values_batch, f_values_batch, noise_var_values_batch = [], [], []
        x_max, y = None, None
        for x_max, n_t in selected_batch:
            print("n_t: ", n_t)

            y, y_var, f_value, noise_var_value = self.f(x_max, n_t)

            self.res['all']['f_values'].append(f_value)
            self.res['all']['noise_var_values'].append(noise_var_value)
            self.res['all']['n_t'].append(n_t)

            self.Y = np.append(self.Y, y)
            self.Y_var = np.append(self.Y_var, y_var)
            self.X = np.vstack((self.X, x_max.reshape((1, -1))))

            self.res['all']['values'].append(self.Y[-1])
            self.res['all']['values_var'].append(self.Y_var[-1])
            self.res['all']['params'].append(self.X[-1])

            values_batch.append(y)
            f_values_batch.append(f_value)
            noise_var_values_batch.append(noise_var_value)

        self.res['all']['values_batch'].append(np.max(values_batch))
        self.res['all']['f_values_batch'].append(np.max(f_values_batch))
        self.res['all']['noise_var_values_batch'].append(np.min(noise_var_values_batch))

        if self.mean_var_obj:
            incumbent_x = self.X[np.argmax(self.omega * self.Y + (1 - self.omega) * self.Y_var)]
        else:
            incumbent_x = self.X[np.argmax(self.Y)]
        self.res['all']['incumbent_x'].append(incumbent_x)

        return x_max, y

    def _record_gp_var_prediction(self):
        """Append ``gp_var`` means/stds over ``domain`` to ``res`` (no-op if ``domain`` is None)."""
        if self.domain is None:
            return
        all_means = []
        all_stds = []
        for x in self.domain:
            mean, var = self.gp_var.predict(x.reshape(1, -1))
            all_stds.append(np.sqrt(var))
            all_means.append(mean)
        self.res['all']['gp_var_pred'].append([all_means, all_stds])

    def maximize(self, n_iter=25, init_points=5):
        """Run the initial design (if needed), then ``n_iter`` batch iterations.

        Args:
            n_iter: number of batches to evaluate.
            init_points: size of the random initial design (unused when
                ``use_init`` supplies the inputs); hyper-parameters are
                optimised up front only when it is greater than 1.

        Results accumulate in ``self.res`` and, if ``log_file`` is set, are
        pickled there after each iteration.
        """
        self.util_ts = UtilityFunction(kind='ts')
        self.util_ts_mean_var = UtilityFunction(kind='ts_mean_var')

        if not self.initialized:
            self.init(init_points, self._load_init_x())

        self._maybe_estimate_R2()

        self.gp = self._new_gp(self.Y)
        self.gp_var = self._new_gp(self.Y_var)

        if init_points > 1:
            self.gp.optimize_restarts(num_restarts=10, messages=False)
            self.gp_var.optimize_restarts(num_restarts=10, messages=False)
            print("---Optimized hyper: ", self.gp)

        selected_batch, nt_advance, query_advance = self._select_batch(0, None)

        # Last evaluated query; stays None until something is evaluated
        # (the first batch can be empty when one n_t exceeds batch_size).
        last_x, last_y = None, None
        for _ in range(n_iter):
            if selected_batch:
                last_x, last_y = self._evaluate_batch(selected_batch)

            self.gp.set_XY(X=self.X, Y=self.Y.reshape(-1, 1))
            self.gp_var.set_XY(X=self.X, Y=self.Y_var.reshape(-1, 1))

            self._maybe_estimate_R2()
            self._record_gp_var_prediction()

            if len(self.X) >= self.gp_opt_schedule and len(self.X) % self.gp_opt_schedule == 0:
                self.gp.optimize_restarts(num_restarts=10, messages=False)
                self.gp_var.optimize_restarts(num_restarts=10, messages=False)
                print("---Optimized hyper: ", self.gp)

            selected_batch, nt_advance, query_advance = self._select_batch(
                nt_advance, query_advance, log_bound=True)

            print("iter {0} ------ x_t: {1}, y_t: {2}".format(self.i + 1, last_x, last_y))

            self.i += 1

            if self.log_file is not None:
                with open(self.log_file, "wb") as fh:
                    pickle.dump(self.res, fh)

