# Run command example (N=10, n_samples=2, n_points=20, methods=['MVR', 'GP_PI']):
#                  python3 main_hartman.py -N 10 -n_samples 2 -n_points 20 -methods MVR GP_PI
# Default values:
#                 - n_samples = 1
#                 - N = 50
#                 - n_points = 100
#                 - methods = ['Chowdhury_Gopalan', 'MVR', 'GP_PI', 'GP_EI']
import os
import time
import numpy as np
import sklearn
from sklearn.gaussian_process.kernels import RBF, Matern
import matplotlib.pyplot as plt

from function import f_hartman3_min, make_params
from algorithms import GP_UCB, MVR, run_GP_UCB, run_MVR, run_GP_PI, run_GP_EI

from functools import partial
import argparse


## Hartman function ##


# Retrieve arguments; options left unset fall back to the defaults listed in
# the header comment.
parser = argparse.ArgumentParser()
parser.add_argument('-N', type=int,
                    help='N passed to every run (default: 50)')
parser.add_argument('-n_samples', type=int,
                    help='number of repetitions per method and kernel (default: 1)')
parser.add_argument('-n_points', type=int,
                    help='grid points per axis of the unit cube (default: 100)')
parser.add_argument('-methods', '--methods', nargs='+',
                    help='methods to run (default: Chowdhury_Gopalan MVR GP_PI GP_EI)')
args = parser.parse_args()
# Echo every command-line option, including n_points (None = use the default).
my_dict = {'N': args.N, 'n_samples': args.n_samples, 'n_points': args.n_points,
           'methods': args.methods}
print('Input args: ', my_dict)

# Make the output directory. exist_ok avoids the race between checking for
# the directory and creating it, and makedirs also creates missing parents.
data_repo = './outputs_hartman'
os.makedirs(data_repo, exist_ok=True)

# Set inputs: a 3-dimensional problem sampled on a grid with n_points values
# in [0, 1] along each axis.
d = 3
n_points = args.n_points if args.n_points else 100
xD, yD, zD = (np.linspace(0, 1, n_points) for _ in range(3))

# Kernels: squared exponential (RBF) and Matern with smoothness nu_matern,
# both starting from length scale 0.2 with bounds (1e-7, 1e7).
squared_exp_kernel = RBF(
    length_scale=0.2,
    length_scale_bounds=(1e-7, 1e7),
)
nu_matern = 2.5
matern_kernel = Matern(
    length_scale=0.2,
    length_scale_bounds=(1e-7, 1e7),
    nu=nu_matern,
)

# Number of independent repetitions of each method
n_samples = args.n_samples if args.n_samples else 1

# delta enters the GP-UCB confidence width through log(1/delta)
delta = 0.1

# Per-kernel constants R and lambda, computed by make_params on the flattened
# grid (coords has shape (3, n_points**3)).
coords = np.stack(np.meshgrid(xD, yD, zD)).reshape(3, -1)
R_se, lamb_se = make_params(coords, f_hartman3_min, squared_exp_kernel)
R_matern, lamb_matern = make_params(coords, f_hartman3_min, matern_kernel)

# B**2 = <Kf, f> for each kernel, computed a priori.
# NOTE(review): these values are used directly as B in the confidence width
# (B + R*sqrt(...)); confirm whether a square root is intended.
B_se = 1307120
B_matern = 1298985

# N is passed to every run; the regret arrays have length N.
N = args.N if args.N else 50
print('\n N: ', N, ', n_points: ', n_points, ', n_samples: ', n_samples)

# Theoretical bounds on the maximum information gain gamma_n (cf. Srinivas
# et al., 2010): (log n)^(d+1) for the SE kernel and
# n^(d(d+1) / (2*nu + d(d+1))) * log n for the Matern-nu kernel.
# Both read the module-level d and nu_matern at call time.
def gamma_n_se(n):
    """Return (log n)**(d+1), the SE-kernel information-gain bound."""
    return np.log(n) ** (d + 1)


def gamma_n_matern(n):
    """Return log(n) * n**(d(d+1) / (2*nu_matern + d(d+1))), the Matern bound."""
    # The denominator is the SUM 2*nu + d(d+1), not a product.
    exponent = (d * (d + 1)) / (2 * nu_matern + d * (d + 1))
    return np.log(n) * n ** exponent

# Prior mean and standard deviation passed to every run function
mu_prior, sigma_prior = 0, 1

# Methods to run: the -methods command-line list, or all four by default
if args.methods:
    method_names = args.methods
else:
    method_names = ['Chowdhury_Gopalan', 'MVR', 'GP_PI', 'GP_EI']

def run_simulation_samples(kernel, f, lamb, R, B, gamma_n):
    """Run every method in ``method_names`` ``n_samples`` times with one kernel.

    Relies on the module-level settings ``method_names``, ``n_samples``, ``N``,
    ``delta``, ``mu_prior``, ``sigma_prior`` and the grid axes ``xD``, ``yD``,
    ``zD``.

    Parameters
    ----------
    kernel : kernel object forwarded to the run functions.
    f : objective function forwarded to the run functions.
    lamb : currently unused (see the note on ``sigma_noise``).
    R, B : constants of the Chowdhury-Gopalan confidence width.
    gamma_n : callable ``t -> gamma_t`` used in that confidence width.

    Returns
    -------
    (gp_ucb_dict, xstar_dict, f_xstar_dict)
        Three dicts keyed by method name. Each maps to a list with one entry
        per sample, holding the three values returned by the run function.
    """
    print('\n\n RUNNING FOR KERNEL: ', kernel)

    # Standard deviation of the observation noise epsilon ~ N(0, sigma_noise**2).
    # NOTE(review): fixed to sqrt(1e-2); an alternative is sqrt(lamb * v**2)
    # with v = 1, which is why ``lamb`` is still accepted.
    sigma_noise = np.sqrt(1e-2)

    def beta_Chowdhury_Gopalan(n, t):
        # B + R*sqrt(2*(gamma_t + 1 + log(1/delta))); the sample index ``n``
        # does not enter the formula.
        return B + R * np.sqrt(2 * (gamma_n(t) + 1 + np.log(1 / delta)))

    # Parameter handed to the non-UCB methods.
    zeta = 0.1

    # Runner for each named method; any other name (e.g. 'Chowdhury_Gopalan')
    # runs GP-UCB. The runner is looked up afresh for every method, so a UCB
    # entry listed after e.g. 'MVR' never inherits the previous runner.
    runners = {'MVR': run_MVR, 'GP_EI': run_GP_EI, 'GP_PI': run_GP_PI}

    gp_ucb_dict = {}
    xstar_dict = {}
    f_xstar_dict = {}

    for name in method_names:
        print('\n ##NAME: ', name, '##')
        run_fct = runners.get(name, run_GP_UCB)
        is_ucb = run_fct is run_GP_UCB

        gp_ucb_list = []
        xstar_list = []
        f_xstar_list = []

        for n in range(n_samples):
            # GP-UCB receives the schedule t -> beta(n, t); the other methods
            # receive the scalar zeta.
            param_argmax = partial(beta_Chowdhury_Gopalan, n) if is_ucb else zeta

            time_n = time.time()
            print('n: ', n, '/', n_samples)
            gp_ucb, xstar, f_xstar = run_fct(f, N, param_argmax, kernel, sigma_noise,
                                             mu_prior=mu_prior, sigma_prior=sigma_prior,
                                             coords=(xD, yD, zD))
            print('time elapsed: ', time.time() - time_n)
            gp_ucb_list.append(gp_ucb)
            xstar_list.append(xstar)
            f_xstar_list.append(f_xstar)

        gp_ucb_dict[name] = gp_ucb_list
        xstar_dict[name] = xstar_list
        f_xstar_dict[name] = f_xstar_list
    return gp_ucb_dict, xstar_dict, f_xstar_dict


# Run every method with the SE kernel, then with the Matern kernel, timing
# each kernel's batch of runs.
time_t = time.time()
gp_ucb_SE, xstar_SE, f_xstar_SE = run_simulation_samples(
    kernel=squared_exp_kernel, f=f_hartman3_min, lamb=lamb_se,
    R=R_se, B=B_se, gamma_n=gamma_n_se,
)
print('\n Runs for SE: ', time.time() - time_t)

time_t = time.time()
gp_ucb_Matern, xstar_Matern, f_xstar_Matern = run_simulation_samples(
    kernel=matern_kernel, f=f_hartman3_min, lamb=lamb_matern,
    R=R_matern, B=B_matern, gamma_n=gamma_n_matern,
)
print('\n Runs for Matern: ', time.time() - time_t)


## Post-processing ##

# Save the queried points X of every run, for both kernels.
# NOTE: ndarray.tofile without `sep` writes raw binary despite the .csv
# extension; read back with np.fromfile and the matching dtype.
for name in method_names:
    for n in range(n_samples):
        suffix = f'_{name}_n_{n}_N_{N}_nsamples_{n_samples}.csv'
        for kernel_tag, runs in (('SE', gp_ucb_SE), ('Matern', gp_ucb_Matern)):
            np.array(runs[name][n].X).tofile(f'{data_repo}/X_{kernel_tag}{suffix}')

# Compute average simple regret
def mean_simple_regret(runs, f_xstars, n_iter):
    """Return f(x*) - f(x_i) for each query i, averaged over the samples.

    ``runs`` holds one run object per sample (exposing ``.f`` and the queried
    points ``.X``), and ``f_xstars`` holds the matching f(x*) values. Each run
    must yield ``n_iter`` values; the result has shape ``(n_iter,)``.

    NOTE(review): this is the per-query (instantaneous) regret; the
    conventional simple regret uses the best value found so far — confirm
    which is intended.
    """
    if not runs:
        raise ValueError('mean_simple_regret needs at least one run')
    total = np.zeros((n_iter,))
    for run, f_xstar in zip(runs, f_xstars):
        # f(x_1), ..., f(x_N) for this sample
        f_values = np.array(tuple(map(run.f, run.X)))
        total += f_xstar - f_values
    return total / len(runs)


# One independent result array per method and kernel.
average_simple_regret_SE_dict = {
    name: mean_simple_regret(gp_ucb_SE[name], f_xstar_SE[name], N)
    for name in method_names
}
average_simple_regret_Matern_dict = {
    name: mean_simple_regret(gp_ucb_Matern[name], f_xstar_Matern[name], N)
    for name in method_names
}


# Save the average simple regret of every method, for both kernels.
# NOTE: ndarray.tofile without `sep` writes raw binary despite the .csv extension.
for name in method_names:
    tail = f'_{name}_N_{N}_nsamples_{n_samples}.csv'
    for kernel_tag, regrets in (('SE', average_simple_regret_SE_dict),
                                ('Matern', average_simple_regret_Matern_dict)):
        regrets[name].tofile(f'{data_repo}/Average_simple_regret_{kernel_tag}{tail}')
