# Run command example (N=10, n_samples=25, n_points=50, methods=['MVR', 'GP_PI']):
#                  python3 main_synthetic_test_fct.py -N 10 -n_samples 25 -n_points 50 -methods MVR GP_PI
# Default values:
#                 - n_samples = 25
#                 - N = 100
#                 - n_points = 100
#                 - methods = ['Chowdhury_Gopalan', 'MVR', 'GP_PI', 'GP_EI']
import os
import time
import numpy as np
import sklearn
from sklearn.gaussian_process.kernels import RBF, Matern
import matplotlib.pyplot as plt

from function import f_demo_1d, f_synthetic_test_chowdhury_gopalan
from algorithms import GP_UCB, MVR, run_GP_UCB, run_MVR, run_GP_PI, run_GP_EI

from functools import partial
import argparse


## Setup: arguments, kernels and RKHS test functions ##


# Retrieve command-line arguments. The defaults mirror the values documented in
# the header comment; the later `args.X or <default>` fallbacks keep working.
parser = argparse.ArgumentParser(
    description='Run BO methods on synthetic RKHS test functions (SE and Matern kernels).')
parser.add_argument('-N', type=int, default=100,
                    help='budget N passed to every run_* method (default: 100)')
parser.add_argument('-n_samples', type=int, default=25,
                    help='number of synthetic test functions drawn per kernel (default: 25)')
parser.add_argument('-n_points', type=int, default=100,
                    help='number of grid points on [0, 1] (default: 100)')
# Restrict to the supported method names so a typo fails immediately instead of
# silently running (and saving results under) the wrong method.
parser.add_argument('-methods', '--methods', nargs='+',
                    choices=['Chowdhury_Gopalan', 'MVR', 'GP_PI', 'GP_EI'],
                    default=['Chowdhury_Gopalan', 'MVR', 'GP_PI', 'GP_EI'],
                    help='methods to run (default: all)')
args = parser.parse_args()
my_dict = {'N': args.N, 'n_samples': args.n_samples, 'n_points': args.n_points,
           'methods': args.methods}
print('Input args: ', my_dict)

# Make the output directory. makedirs(exist_ok=True) avoids the
# exists-then-mkdir race and also creates missing parent directories.
data_repo = './outputs'
os.makedirs(data_repo, exist_ok=True)

# --- Problem setup -----------------------------------------------------------
# One-dimensional inputs on a uniform grid over [0, 1].
d = 1
n_points = args.n_points if args.n_points else 100
xD = np.linspace(0, 1, n_points)

# Number of synthetic test functions drawn for each kernel.
n_samples = args.n_samples if args.n_samples else 25

# Confidence parameter delta (enters the GP-UCB exploration weight).
delta = 0.1

# Kernels: squared exponential and Matern (nu = 2.5), both with length scale 0.2.
nu_matern = 2.5
squared_exp_kernel = 1.0 * RBF(length_scale=0.2)
matern_kernel = 1.0 * Matern(length_scale=0.2, nu=nu_matern)

# Draw the RKHS test functions together with their constants (R, lambda, B),
# one set per kernel; the SE draw is done before the Matern draw.
y_se, R_se, lamb_se, B_se = f_synthetic_test_chowdhury_gopalan(xD, squared_exp_kernel, n_samples)
y_matern, R_matern, lamb_matern, B_matern = f_synthetic_test_chowdhury_gopalan(xD, matern_kernel, n_samples)

# Plot every sampled RKHS function: SE kernel on top, Matern kernel below.
# np.rollaxis(..., -1) moves the last (grid) axis first so each sample is one line.
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(15, 15))
for axis, samples, title in ((ax1, y_se, 'Square exponential kernel'),
                             (ax2, y_matern, 'Matern kernel')):
    axis.plot(xD, np.rollaxis(samples, -1), lw=1)
    axis.set_title(title)
plt.show()

# Wrap each sampled RKHS function (row n of y_se / y_matern, tabulated on xD) as
# a callable f(x) that returns the tabulated value at grid point x[0, ...].
# NOTE(review): the lookup uses exact float equality, so x must be one of the
# grid points in xD; otherwise the index array is empty and an empty array is
# returned. Confirm callers only ever query grid points.
def _grid_function(values, n):
    """Return a callable evaluating row ``n`` of ``values`` at a grid point of ``xD``."""
    def evaluate(x, n=n):
        hit = np.where(xD == x[0, ...])[0].squeeze()
        return values[n, hit]
    return evaluate


f_se = [_grid_function(y_se, sample) for sample in range(n_samples)]
f_matern = [_grid_function(y_matern, sample) for sample in range(n_samples)]

# Set N: the budget passed to every run_* call. The regret curves computed later
# have length N (presumably one entry per BO iteration; confirm in algorithms).
N = args.N or 100
print('\n N: ', N, ', n_points: ', n_points, ', n_samples: ', n_samples)


# Theoretical bounds (up to constants) on the maximum information gain gamma_n,
# cf. Srinivas et al. (2010):
#   SE kernel:     gamma_n = O((log n)^(d+1))
#   Matern kernel: gamma_n = O(n^(d(d+1) / (2 nu + d(d+1))) * log n)
def gamma_n_se(n):
    """Information-gain bound for the squared exponential kernel."""
    return np.log(n) ** (d + 1)


def gamma_n_matern(n):
    """Information-gain bound for the Matern kernel with smoothness ``nu_matern``."""
    # The denominator is the SUM 2*nu + d(d+1); a stray `*` here would instead
    # parse as 2*nu*(+d(d+1)) and give the wrong exponent (1/5 instead of 2/7 for d=1).
    exponent = (d * (d + 1)) / (2 * nu_matern + d * (d + 1))
    return np.log(n) * n ** exponent


# Prior parameters (mean mu_prior, scale sigma_prior) passed to every run_* call.
mu_prior = 0
sigma_prior = 1

# Methods to run (defaults match the header comment).
method_names = args.methods or ['Chowdhury_Gopalan', 'MVR', 'GP_PI', 'GP_EI']

def run_simulation_samples(kernel, f, lamb, R, B, gamma_n):
    """Run every method in ``method_names`` on each of the ``n_samples`` test functions.

    Parameters
    ----------
    kernel : kernel object handed unchanged to each ``run_*`` function.
    f : sequence of callables; ``f[n]`` is the n-th synthetic objective.
    lamb, R, B : per-sample sequences indexed by sample ``n`` (as returned by
        ``f_synthetic_test_chowdhury_gopalan``). ``lamb`` sets the noise
        variance; ``R`` and ``B`` enter the GP-UCB exploration weight.
    gamma_n : callable ``t -> gamma_t`` (information-gain bound) used in beta.

    Returns
    -------
    (gp_ucb_dict, xstar_dict, f_xstar_dict)
        Dicts keyed by method name. Each value is a list with one entry per
        sample, holding the three values returned by the corresponding ``run_*``.

    Raises
    ------
    ValueError
        If ``method_names`` contains a name with no associated runner.
    """
    # Resolve the runner per method name. The runner is looked up fresh for
    # each name rather than carried over from the previous iteration, so
    # 'Chowdhury_Gopalan' always runs GP-UCB regardless of its list position.
    runners = {
        'Chowdhury_Gopalan': run_GP_UCB,
        'MVR': run_MVR,
        'GP_PI': run_GP_PI,
        'GP_EI': run_GP_EI,
    }
    unknown = [name for name in method_names if name not in runners]
    if unknown:
        # Fail before any (long) simulation instead of mislabelling results.
        raise ValueError(f"Unsupported method(s) {unknown}; expected one of {sorted(runners)}")

    print('\n\n RUNNING FOR KERNEL: ', kernel)
    # Noise epsilon ~ N(0, sigma_noise**2) with sigma_noise**2 = lamb * v**2.
    v = 1.
    sigma_noise = np.sqrt(lamb * v**2)

    def beta_Chowdhury_Gopalan(n, t):
        """GP-UCB exploration weight for sample n at step t."""
        return B[n] + R[n] * np.sqrt(2 * (gamma_n(t) + 1 + np.log(1 / delta)))

    # Fixed acquisition parameter passed to the non-UCB runners; how it is used
    # is up to each run_* function.
    zeta = 0.1

    gp_ucb_dict = {}
    xstar_dict = {}
    f_xstar_dict = {}

    for name in method_names:
        print('\n ##NAME: ', name, '##')
        run_fct = runners[name]

        gp_ucb_list = []
        xstar_list = []
        f_xstar_list = []
        for n in range(n_samples):
            if run_fct is run_GP_UCB:
                # GP-UCB takes a per-sample schedule t -> beta(n, t).
                param_argmax = partial(beta_Chowdhury_Gopalan, n)
            else:
                param_argmax = zeta

            time_n = time.time()
            print('n: ', n, '/', n_samples)
            gp_ucb, xstar, f_xstar = run_fct(f[n], N, param_argmax, kernel, sigma_noise[n],
                                             mu_prior=mu_prior, sigma_prior=sigma_prior,
                                             coords=(xD,))
            print('time elapsed: ', time.time() - time_n)
            gp_ucb_list.append(gp_ucb)
            xstar_list.append(xstar)
            f_xstar_list.append(f_xstar)

        gp_ucb_dict[name] = gp_ucb_list
        xstar_dict[name] = xstar_list
        f_xstar_dict[name] = f_xstar_list
    return gp_ucb_dict, xstar_dict, f_xstar_dict


def _timed_simulation(label, kernel, f, lamb, R, B, gamma_n):
    """Call ``run_simulation_samples`` for one kernel and print the wall-clock time."""
    started = time.time()
    outputs = run_simulation_samples(kernel, f, lamb, R, B, gamma_n)
    print('\n Runs for ' + label + ': ', time.time() - started)
    return outputs


gp_ucb_SE, xstar_SE, f_xstar_SE = _timed_simulation(
    'SE', squared_exp_kernel, f_se, lamb_se, R_se, B_se, gamma_n_se)
gp_ucb_Matern, xstar_Matern, f_xstar_Matern = _timed_simulation(
    'Matern', matern_kernel, f_matern, lamb_matern, R_matern, B_matern, gamma_n_matern)


## Post-processing ##

# Save the queried points X of every run, SE file first then Matern, per sample.
# NOTE(review): ndarray.tofile without `sep` writes raw binary (native dtype,
# no header). Despite the '.csv' suffix, read these files back with
# np.fromfile and the matching dtype.
for name in method_names:
    for n in range(n_samples):
        for kernel_tag, runs in (('SE', gp_ucb_SE), ('Matern', gp_ucb_Matern)):
            target = f'{data_repo}/X_{kernel_tag}_{name}_n_{n}_N_{N}_nsamples_{n_samples}.csv'
            np.array(runs[name][n].X).tofile(target)

# Compute the average (over the n_samples runs) regret curve of every method.
# NOTE(review): entry i is f(x_star^n) - f(x_i^n), the regret of the i-th query
# itself. The usual "simple regret" uses the best value found so far,
# f(x_star^n) - max_{j<=i} f(x_j^n). If that is intended, apply
# np.maximum.accumulate to f_n below; confirm with the downstream analysis.
def _average_regret(runs, f_xstar, name):
    """Return the mean over samples of f(x_star^n) - f(x_i^n), for i = 1..N.

    runs:    dict mapping method name -> per-sample run objects exposing ``f``
             (the objective) and ``X`` (the queried points).
    f_xstar: dict mapping method name -> per-sample values f(x_star^n).
    name:    method name used to index both dicts.
    """
    total = np.zeros((N,))
    for n in range(n_samples):
        run = runs[name][n]
        # f(x_1^{n}), ..., f(x_N^{n}); values for the sample n
        f_n = np.array(tuple(map(run.f, run.X)))
        # f(x_star^{n}) - f(x_i^{n}), for i=1,...,N
        total += f_xstar[name][n] - f_n
    return total / n_samples


# Each entry is a freshly computed array. No placeholder array is shared
# between keys, so an in-place update of one method cannot affect another.
average_simple_regret_SE_dict = {
    name: _average_regret(gp_ucb_SE, f_xstar_SE, name) for name in method_names
}
average_simple_regret_Matern_dict = {
    name: _average_regret(gp_ucb_Matern, f_xstar_Matern, name) for name in method_names
}


# Save the averaged regret curves, SE file first then Matern, per method.
# As with the X files, tofile without `sep` writes raw binary despite '.csv'.
for name in method_names:
    for kernel_tag, regret_by_method in (('SE', average_simple_regret_SE_dict),
                                         ('Matern', average_simple_regret_Matern_dict)):
        target = f'{data_repo}/Average_simple_regret_{kernel_tag}_{name}_N_{N}_nsamples_{n_samples}.csv'
        regret_by_method[name].tofile(target)
