import os
import pickle

import GPy
import numpy as np

from bayesian_optimization_bts_red_mean_var import BTS_RED_mean_var

# Fix NumPy's global RNG seed so the whole script, including the sampling
# done in synth_func, is reproducible.
np.random.seed(0)

# Number of BO iterations; passed to the optimizer as both ``T`` and ``n_iter``.
max_iter = 70

# In the visible code these values only tag the results filename (and
# noise_var_max also sets R2 when n_t is fixed). Presumably they are the
# settings the pickled synthetic function was generated with -- confirm.
ls, ls_noise_var = 0.04, 0.15
noise_var_min, noise_var_max = 1e-4, 0.2

# Load the precomputed synthetic objective. NOTE: pickle can execute arbitrary
# code while loading, so this file must come from a trusted source.
log_file_name = "obj_funcs/synth_func_mean_var.pkl"
with open(log_file_name, "rb") as fh:  # closes the handle as soon as it is loaded
    all_func_info = pickle.load(fh)
domain = all_func_info["domain"]  # grid of inputs that synth_func snaps queries onto
f = all_func_info["f"]  # true mean of the objective at each grid point
f_noise_var = all_func_info["f_noise_var"]  # true noise variance at each grid point

def synth_func(param, n_t):
    """Simulate ``n_t`` replicated noisy evaluations of the synthetic objective.

    The query is snapped to the nearest point of the module-level ``domain``
    grid. Observations there are Gaussian with mean ``f`` and variance
    ``f_noise_var`` at that grid point.

    Parameters
    ----------
    param : sequence
        Query point; only its first coordinate is used.
    n_t : int
        Number of replications to draw. The sample variance divides by
        ``n_t - 1``, so ``n_t`` should be at least 2.

    Returns
    -------
    tuple
        ``(sample mean, negated unbiased sample variance,
        true mean at the grid point, true noise variance at the grid point)``.
    """
    # Index of the grid point closest to the continuous query.
    nearest = np.argmin(np.abs(domain - param[0]))
    true_mean, true_noise_var = f[nearest], f_noise_var[nearest]

    draws = np.random.normal(true_mean, np.sqrt(true_noise_var), n_t)
    sample_mean = draws.mean()

    # Unbiased (Bessel-corrected) variance, computed explicitly.
    deviations = draws - sample_mean
    sample_var = np.sum(deviations ** 2) / (n_t - 1)

    # The variance is negated before it is returned.
    return sample_mean, -sample_var, true_mean, true_noise_var


# ---------------------------------------------------------------------------
# Batch and replication settings
# ---------------------------------------------------------------------------
batch_size = 50
n_min, n_max = 2, 50

# Scale factor passed to the optimizer as ``ratio``; it also tags the results
# filename. An alternative setting that was tried uses 0.2 in place of 0.3.
ratio = (np.sqrt(batch_size) + 1) / (batch_size - 1) * 0.3

# Fixed-n_t mode. When True, every query uses ``fix_nt_value`` replications
# (other values tried: 5 and 10).
fix_nt_flag = False
fix_nt_value = 20

# R2 is 0.02 by default; in fixed-n_t mode it is tied to the largest noise
# variance divided by the fixed number of replications.
R2 = noise_var_max / fix_nt_value if fix_nt_flag else 0.02

# Original note: when this is True the theory-inspired choice of R2 is used,
# so the R2 value set above should not affect the optimization. R2 is still
# passed to the optimizer and still appears in the results filename.
estimate_sigma_max = True

# ---------------------------------------------------------------------------
# Initial design
# ---------------------------------------------------------------------------
init_size = 10  # number of initial inputs
fix_nt_init = 10  # the same n_t is used for every queried initial input

# ---------------------------------------------------------------------------
# Optimizer knobs
# ---------------------------------------------------------------------------
# Constant sequences of ones handed over as ``beta_t`` / ``beta_t_var``
# (presumably per-iteration confidence multipliers -- confirm in the class).
beta_t = np.ones(5000)
beta_t_var = np.ones(5000)

gp_opt_schedule = 5
M_TS = 50

# Mean-variance objective and its weight ``omega``.
mean_var_obj = True
omega = 0.3

# ---------------------------------------------------------------------------
# Output and repetitions
# ---------------------------------------------------------------------------
log_dir = "results_bts_red_mean_var"
run_list = np.arange(50)  # run indices 0..49

def _build_log_file_name(itr):
    """Return the results-pickle path for run number ``itr``.

    The name encodes the settings that distinguish runs, so different
    configurations and repetitions write to different files. The suffix
    ``.pkl`` is appended exactly once, instead of repeatedly slicing ``[:-4]``
    off the name and re-appending it. ``!s`` forces ``str()`` so NumPy scalars
    (``itr`` comes from ``np.arange`` and ``ratio`` from ``np.sqrt``) render
    exactly as ``str(...)`` concatenation did.
    """
    name = (
        f"{log_dir}/res_ls_{ls!s}_ls_noise_var_{ls_noise_var!s}"
        f"_noise_range_{noise_var_min!s}_{noise_var_max!s}_iter_{itr!s}"
        f"_batch_size_{batch_size!s}_R2_{R2!s}"
    )
    if fix_nt_flag:
        name += f"_nt_{fix_nt_value!s}_init_{init_size!s}"
    else:
        name += (
            f"_R2var_0_n_min_{n_min!s}_n_max_{n_max!s}"
            f"_init_{init_size!s}_fix_nt_init_{fix_nt_init!s}"
        )
        # The ratio tag is only added when n_t is not fixed.
        if estimate_sigma_max:
            name += f"_ratio_{ratio!s}"
    if mean_var_obj:
        name += f"_omega_{omega!s}"
    return name + ".pkl"


# The optimizer receives ``log_file_name`` as its ``log_file``. Make sure the
# results directory exists in case it writes there; nothing above creates it.
os.makedirs(log_dir, exist_ok=True)

# One independent BO run per entry of run_list, each with its own initial
# design file and results file.
for itr in run_list:
    log_file_name = _build_log_file_name(itr)
    bo_ts = BTS_RED_mean_var(
        f=synth_func,
        pbounds={'x1': (0, 1)},
        gp_opt_schedule=gp_opt_schedule,
        log_file=log_file_name,
        M_TS=M_TS,
        n_min=n_min,
        n_max=n_max,
        noise_var_func=f_noise_var,
        domain=domain,
        batch_size=batch_size,
        R2=R2,
        beta_t=beta_t,
        fix_nt_flag=fix_nt_flag,
        fix_nt_value=fix_nt_value,
        # Presumably reuses a pre-generated initial design for this run
        # (save_init=False) -- confirm against BTS_RED_mean_var.
        use_init=f"inits/init_itr_{itr!s}_init_{init_size!s}.p",
        save_init=False,
        save_init_file=None,
        T=max_iter,
        beta_t_var=beta_t_var,
        estimate_sigma_max=estimate_sigma_max,
        ratio=ratio,
        fix_nt_init=fix_nt_init,
        mean_var_obj=mean_var_obj,
        omega=omega,
    )
    bo_ts.maximize(n_iter=max_iter, init_points=init_size)
