import sys, os
sys.path.append('/scratch/user/mzfan/MFBO_v2')
sys.path.append('/Users/mzfan/MFBO_v2')
import torch
import math
import gpytorch
import matplotlib.pyplot as plt
from pyro.infer.mcmc import NUTS, MCMC, HMC
from gpytorch.likelihoods import GaussianLikelihood
import pyro
import pickle
from gpytorch.priors import LogNormalPrior, NormalPrior, UniformPrior, MultivariateNormalPrior
from gpytorch.likelihoods import GaussianLikelihood, FixedNoiseGaussianLikelihood
import pyro.distributions as dist
from torch.nn import Module
import torch.nn as nn
import torch.optim as optim
import time
import scipy
from scipy.optimize import minimize
import time
from funcs import sin_target, band_gap_target
from surrogates import NVSurrogate, SepSurrogate
from experiment import experiment_num, experiment_continious, experiment_two_step
from acquisition_function import UCB_NVUCB, UCB, UCB_NUCB
import argparse

# Fixed surrogate hyperparameters shared by every variant of this toy experiment.
LENGTHSCALE = 0.25
OUTPUTSCALE = 0.774
NUM_SAMPLES = 128  # passed as ``num_samples`` to NVSurrogate only; SepSurrogate never receives it
ITERATION_NUM = 20
LOAD_DIR = './Toy_example_sin/'
SAVE_DIR = './Toy_example_sin/saves/'

# Experiment variants selectable from the command line:
#   name -> (surrogate class, acquisition class, experiment class, fidelity_fix)
# fidelity_fix=None: the target is ``sin_target()`` and the surrogate's first positional
#   argument is 2 (NOTE(review): presumably input dim = x + fidelity; confirm).
# fidelity_fix=0/1: the target is ``sin_target(fidelity_fix=k)``, the surrogate's first
#   positional argument is 1, and every initial fidelity index is reset to 0.
# Results are always saved as ``SAVE_DIR + f'ex_{index}_res_{name}'``.
METHODS = {
    'NVUCB': (NVSurrogate, UCB_NVUCB, experiment_continious, None),
    'NUCB': (NVSurrogate, UCB_NUCB, experiment_continious, None),
    'UCB': (SepSurrogate, UCB, experiment_continious, None),
    'UCBTS': (NVSurrogate, UCB, experiment_two_step, None),
    'NVUCB0': (NVSurrogate, UCB_NVUCB, experiment_continious, 0),
    'NUCB0': (NVSurrogate, UCB_NUCB, experiment_continious, 0),
    'UCB0': (SepSurrogate, UCB, experiment_continious, 0),
    'UCBC0': (NVSurrogate, UCB, experiment_continious, 0),
    'NVUCB1': (NVSurrogate, UCB_NVUCB, experiment_continious, 1),
    'NUCB1': (NVSurrogate, UCB_NUCB, experiment_continious, 1),
    'UCB1': (SepSurrogate, UCB, experiment_continious, 1),
    'UCBC1': (NVSurrogate, UCB, experiment_continious, 1),
}


def _build_surrogate(surrogate_cls, n_dim):
    """Construct ``surrogate_cls(n_dim, 1, ...)`` with the fixed toy hyperparameters.

    Args:
        surrogate_cls: NVSurrogate or SepSurrogate.
        n_dim: first positional constructor argument (2 for the runs over both
            fidelities, 1 for the fixed-fidelity runs).
    """
    kwargs = dict(
        lengthscale=torch.tensor([[LENGTHSCALE]]),
        outputscale=torch.tensor([OUTPUTSCALE]),
        # A fresh Parameter per surrogate so no state is shared between models.
        mean=torch.nn.Parameter(torch.tensor([0.])),
    )
    if surrogate_cls is NVSurrogate:
        kwargs['num_samples'] = NUM_SAMPLES
    return surrogate_cls(n_dim, 1, **kwargs)


def main(argv=None):
    """Parse ``index name`` from the command line and run one experiment.

    Loads row ``index`` of the saved initial points (``init_x.ts`` /
    ``init_x_ind.ts`` in LOAD_DIR). It then builds the surrogate / acquisition /
    experiment combination registered under ``name`` in METHODS and runs
    ITERATION_NUM iterations. The result is saved to
    ``SAVE_DIR + f'ex_{index}_res_{name}'``. Prints the variant name and the
    elapsed wall-clock seconds.

    Args:
        argv: argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.

    An unrecognised ``name`` is rejected by argparse (usage message, exit status 2)
    instead of silently exiting with status 0.
    """
    t1 = time.time()

    parser = argparse.ArgumentParser(
        description='Run one experiment variant on the toy sine target.')
    parser.add_argument('index', type=int,
                        help='experiment id: row of the saved initial points to use')
    parser.add_argument('name', type=str, choices=list(METHODS),
                        help='experiment variant to run')
    args = parser.parse_args(argv)
    ex_id = args.index
    name = args.name

    init_x = torch.load(LOAD_DIR + 'init_x.ts')[ex_id, :, :]
    init_x_index = torch.load(LOAD_DIR + 'init_x_ind.ts')[ex_id, :].to(torch.long)
    print(name)

    surrogate_cls, acqui_cls, experiment_cls, fidelity_fix = METHODS[name]
    if fidelity_fix is None:
        target = sin_target()
        n_dim = 2
    else:
        target = sin_target(fidelity_fix=fidelity_fix)
        n_dim = 1
        # Indices are reset to 0 for both fixed fidelities (0 and 1).
        # NOTE(review): presumably index 0 addresses the single pinned fidelity; confirm.
        init_x_index = torch.zeros_like(init_x_index, dtype=torch.long)

    surrogate = _build_surrogate(surrogate_cls, n_dim)
    experiment = experiment_cls(target, surrogate, acqui_cls())
    experiment.initialize_given(init_x, init_x_index)
    experiment.run_iterations(ITERATION_NUM)
    experiment.save(SAVE_DIR + f'ex_{ex_id}_res_{name}')
    print(time.time() - t1)


if __name__ == "__main__":
    main()
