import numpy as np
from datetime import datetime
from scipy.optimize import minimize
import pickle

def acq_max(ac, M, random_features, w_sample, random_features_var, w_sample_var, bounds, omega,
            log_file_name="obj_funcs/synth_func.pkl"):
    """Return the candidate point in a pickled finite domain that maximises ``ac``.

    The candidate set is read from ``log_file_name`` (key ``"domain"``) and every
    row is scored by ``ac(x.reshape(1, -1), para_dict)``.

    Args:
        ac: acquisition callable ``ac(x, para_dict)`` returning a scalar score.
        M: number of random features, forwarded in ``para_dict["M"]``.
        random_features, w_sample: feature dict / weights for the first sampled
            function, forwarded in ``para_dict``.
        random_features_var, w_sample_var: feature dict / weights for the second
            sampled function, forwarded in ``para_dict``.
        bounds: accepted for interface compatibility; not used by this function.
        omega: mixing weight, forwarded in ``para_dict["omega"]``.
        log_file_name: path of the pickle holding ``{"domain": ...}``. The default is
            the original hard-coded path.

    Returns:
        The row ``domain[i, :]`` with the largest ``ac`` value. Ties go to the
        first such row.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    # Use a context manager so the file handle is closed (the original leaked it).
    with open(log_file_name, "rb") as fh:
        all_func_info = pickle.load(fh)
    domain = all_func_info["domain"]  # indexed as domain[i, :] below, i.e. 2-D

    para_dict = {
        "M": M,
        "random_features": random_features,
        "w_sample": w_sample,
        "random_features_var": random_features_var,
        "w_sample_var": w_sample_var,
        "omega": omega,
    }

    # Negate scores so argmin picks the maximiser (first one on ties).
    ys = np.squeeze(np.array([-ac(x.reshape(1, -1), para_dict) for x in domain]))
    argmin_ind = np.argmin(ys)
    return domain[argmin_ind, :]


class UtilityFunction(object):
    """Acquisition (utility) function evaluated on random-Fourier-feature samples.

    Supported ``kind`` values:
      * ``'ts'``          -- value of the first sampled function at ``x``.
      * ``'ts_mean_var'`` -- ``omega * f(x) + (1 - omega) * g(x)``, where ``f`` is the
        first sampled function and ``g`` the second one (labelled "var").
    """

    _KINDS = ("ts", "ts_mean_var")

    def __init__(self, kind):
        self.kind = kind

    def utility(self, x, para_dict):
        """Evaluate the configured utility at ``x``.

        Args:
            x: query point. It is flattened and reshaped to one row internally.
            para_dict: dict with keys ``"M"``, ``"random_features"``, ``"w_sample"``.
                For ``'ts_mean_var'`` it also needs ``"random_features_var"``,
                ``"w_sample_var"`` and ``"omega"``.

        Returns:
            The utility value, as a squeezed NumPy value.

        Raises:
            ValueError: if ``self.kind`` is not one of the supported kinds. The
                original code silently returned ``None`` here.
        """
        if self.kind == 'ts':
            return self._ts(x, para_dict["M"], para_dict["random_features"],
                            para_dict["w_sample"])
        if self.kind == 'ts_mean_var':
            return self._ts_mean_var(
                x, para_dict["M"], para_dict["random_features"], para_dict["w_sample"],
                para_dict["random_features_var"], para_dict["w_sample_var"],
                para_dict["omega"])
        raise ValueError(
            "Unsupported utility kind %r; expected one of %r" % (self.kind, self._KINDS))

    @staticmethod
    def _rff_sample_value(x, M, random_features, w_sample):
        """Return ``w_sample . phi(x)`` for a random-Fourier-feature map ``phi``.

        ``phi(x) = sqrt(v_kernel) * z / ||z||`` with ``z = sqrt(2 / M) * cos(s x + b)``.
        The feature vector is assumed to be ``s`` of shape (M, d) and ``b`` of
        shape (M,) -- TODO confirm against the code that builds ``random_features``.
        """
        s = random_features["s"]
        b = random_features["b"]
        v_kernel = random_features["v_kernel"]

        x = np.squeeze(x).reshape(1, -1)
        features = np.sqrt(2 / M) * np.cos(np.squeeze(np.dot(x, s.T)) + b)
        features = features.reshape(-1, 1)

        # Normalise to unit L2 norm. This cancels the sqrt(2 / M) factor above.
        features = features / np.linalg.norm(features)
        features = np.sqrt(v_kernel) * features  # v_kernel is set to be 1 here in the synthetic experiments

        return np.squeeze(np.dot(w_sample, features))

    @staticmethod
    def _ts(x, M, random_features, w_sample):
        """Thompson-sampling utility: the first sampled function evaluated at ``x``."""
        return UtilityFunction._rff_sample_value(x, M, random_features, w_sample)

    @staticmethod
    def _ts_mean_var(x, M, random_features, w_sample, random_features_var, w_sample_var, omega):
        """Convex mix ``omega * f(x) + (1 - omega) * g(x)`` of the two sampled functions."""
        f_value = UtilityFunction._rff_sample_value(x, M, random_features, w_sample)
        f_value_var = UtilityFunction._rff_sample_value(x, M, random_features_var, w_sample_var)
        return f_value * omega + f_value_var * (1 - omega)
    
