import numpy as np
import torch
import math
import matplotlib.pyplot as plt
import pickle
import time
from util import Train_model_naive, Train_model_our_2, UCB_our_nfc, UCB_naive, UCB_our, optimize_acquisition_LB, BO_performe_2
from bayesian_optimisation import ground_truth_query
from target_functions import Branin, Hartmann, Levy
from multiprocessing import Pool
import multiprocessing
import argparse

def query_br(tr_x, index_x, x_bound):
    """Return noisy and noise-free rescaled Branin values for ``tr_x``.

    ``tr_x`` is mapped to the Branin domain as
    ``(x_bound[1] - x_bound[0]) * tr_x + x_bound[0]`` before evaluation, and
    the Branin output is rescaled as ``Branin / 100 + 1``.

    The noise scale of row ``k`` is affine in the *unmapped* ``tr_x[k]``; the
    coefficients are picked by ``index_x[k]`` from two hard-coded models:
    ``0.5 * x0 + 0.5 * x1 + 0`` for index 0 and ``-0.5 * x0 - 0.5 * x1 + 1``
    for index 1.

    Args:
        tr_x: 2-D tensor with one query point per row (two columns, matching
            the coefficient tensors below).
        index_x: per-row selector of the noise model; only its first
            dimension is used for the row count.
        x_bound: tensor whose row 0 holds lower bounds and row 1 upper bounds.

    Returns:
        Tuple ``(tr_y, clean)``: the noise-corrupted values and the rescaled
        noise-free values.
    """
    slopes = torch.tensor([[[0.5, 0.5]], [[-0.5, -0.50]]])
    offsets = torch.tensor([[0], [1]])

    lower = x_bound[0, :].unsqueeze(0)
    span = (x_bound[1, :] - x_bound[0, :]).unsqueeze(0)
    tr_y_gt = Branin(span * tr_x + lower)

    # One single-element tensor per row; concatenated into a 1-D noise scale.
    noise_level = [
        torch.mm(tr_x[k].unsqueeze(0), slopes[index_x[k]].T).squeeze()
        + offsets[index_x[k]]
        for k in range(index_x.shape[0])
    ]

    clean = tr_y_gt / 100 + 1
    tr_y = clean + torch.randn(tr_x.shape[0]) * torch.cat(noise_level)
    return tr_y, clean

def query_hm(tr_x, index_x, x_bound):
    """Return noisy and noise-free shifted Hartmann values for ``tr_x``.

    ``Hartmann`` is evaluated directly on ``tr_x`` (no rescaling of the
    inputs) and its output is shifted as ``Hartmann - 2``.

    The noise scale of row ``k`` is affine in ``tr_x[k]``; the coefficients
    are picked by ``index_x[k]`` from two hard-coded models that weight only
    the first three of six coordinates: ``0.5 * (x0 + x1 + x2) + 0`` for
    index 0 and ``-0.5 * (x0 + x1 + x2) + 1`` for index 1.

    Args:
        tr_x: 2-D tensor with one query point per row (six columns, matching
            the coefficient tensors below).
        index_x: per-row selector of the noise model; only its first
            dimension is used for the row count.
        x_bound: accepted for signature parity with the other query
            functions; not read here.

    Returns:
        Tuple ``(tr_y, clean)``: the noise-corrupted values and the shifted
        noise-free values.
    """
    slopes = torch.tensor([[[0.5, 0.5, 0.5, 0, 0, 0]], [[-0.5, -0.5, -0.5, 0, 0, 0]]])
    offsets = torch.tensor([[0], [1]])

    tr_y_gt = Hartmann(tr_x)

    # One single-element tensor per row; concatenated into a 1-D noise scale.
    noise_level = [
        torch.mm(tr_x[k].unsqueeze(0), slopes[index_x[k]].T).squeeze()
        + offsets[index_x[k]]
        for k in range(index_x.shape[0])
    ]

    clean = tr_y_gt - 2
    tr_y = clean + torch.randn(tr_x.shape[0]) * torch.cat(noise_level)
    return tr_y, clean

def query_lv(tr_x, index_x, x_bound):
    """Return noisy and noise-free rescaled Levy values for ``tr_x``.

    ``tr_x`` is mapped to the Levy domain as
    ``(x_bound[1] - x_bound[0]) * tr_x + x_bound[0]`` before evaluation, and
    the Levy output is rescaled as ``Levy / 40 - 1``.

    The noise scale of row ``k`` is affine in the *unmapped* ``tr_x[k]``; the
    coefficients are picked by ``index_x[k]`` from two hard-coded models that
    weight only the first two of three coordinates: ``0.5 * (x0 + x1) + 0``
    for index 0 and ``-0.5 * (x0 + x1) + 1`` for index 1.

    Args:
        tr_x: 2-D tensor with one query point per row (three columns,
            matching the coefficient tensors below).
        index_x: per-row selector of the noise model; only its first
            dimension is used for the row count.
        x_bound: tensor whose row 0 holds lower bounds and row 1 upper bounds.

    Returns:
        Tuple ``(tr_y, clean)``: the noise-corrupted values and the rescaled
        noise-free values.
    """
    slopes = torch.tensor([[[0.5, 0.5, 0]], [[-0.5, -0.5, 0]]])
    offsets = torch.tensor([[0], [1]])

    lower = x_bound[0, :].unsqueeze(0)
    span = (x_bound[1, :] - x_bound[0, :]).unsqueeze(0)
    tr_y_gt = Levy(span * tr_x + lower)

    # One single-element tensor per row; concatenated into a 1-D noise scale.
    noise_level = [
        torch.mm(tr_x[k].unsqueeze(0), slopes[index_x[k]].T).squeeze()
        + offsets[index_x[k]]
        for k in range(index_x.shape[0])
    ]

    clean = tr_y_gt / 40 - 1
    tr_y = clean + torch.randn(tr_x.shape[0]) * torch.cat(noise_level)
    return tr_y, clean
#
if __name__ == "__main__":
    # Command-line driver: for one benchmark and one set of initial samples,
    # run BO_performe_2 with three trainer/acquisition pairings and pickle the
    # three results as a list.
    #
    # Usage: <script> INDEX NAME     with NAME one of: br, hm, lv

    # Start timestamp; nothing in this block reads it.
    t = time.time()

    # Benchmark name -> (query function, dimension argument, x_bound rows).
    # A bound of None means BO_performe_2 is called without ``x_bound`` so its
    # own default applies (this is how the Hartmann runs are configured).
    benchmarks = {
        "hm": (query_hm, 6, None),
        "br": (query_br, 2, [[-5, 0], [10, 15]]),
        "lv": (query_lv, 3, [[-10, -10, -10], [10, 10, 10]]),
    }
    # (model trainer, acquisition function) pairs; they are executed and
    # stored in exactly this order.
    methods = [
        (Train_model_our_2, UCB_our),
        (Train_model_our_2, UCB_our_nfc),
        (Train_model_naive, UCB_naive),
    ]

    parser = argparse.ArgumentParser(
        description='Run BO_performe_2 with three trainer/acquisition pairings '
                    'on one synthetic benchmark and pickle the results.')
    parser.add_argument(
        'index', type=int,
        help="position of the initial-sample set inside the benchmark's "
             "initial-samples pickle; also used in the output file name")
    # ``choices`` makes an unknown benchmark name a usage error instead of a
    # silent no-op that exits successfully without writing anything.
    parser.add_argument(
        'name', type=str, choices=sorted(benchmarks),
        help="benchmark to run")
    args = parser.parse_args()
    i = args.index
    f = args.name
    iters = 60
    dirc = './BO_syn_12_9_2'

    query_fn, dim, bounds = benchmarks[f]

    # NOTE: pickle.load executes arbitrary code from the file; only point this
    # at initial-sample files you generated yourself.
    with open(dirc + '/BO_' + f + '_initals.pickle', 'rb') as handle:
        initial = pickle.load(handle)

    saves = []
    for train_fn, acquisition_fn in methods:
        # Build a fresh bounds tensor for every run so the runs cannot
        # influence each other through a shared tensor object.
        extra = {} if bounds is None else {'x_bound': torch.tensor(bounds)}
        # NOTE(review): the literal 2 and [1, 1] are passed through unchanged;
        # their meaning is defined by BO_performe_2 in util.
        saves.append(
            BO_performe_2(query_fn, 2, dim, [1, 1], train_fn, acquisition_fn,
                          iters, num_samples=64, initial_samples=initial[i],
                          **extra))

    out_path = dirc + '/BO_' + f + str(i) + '.pickle'
    with open(out_path, 'wb') as handle:
        pickle.dump(saves, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Read the file straight back as a round-trip check that it is loadable.
    with open(out_path, 'rb') as handle:
        b = pickle.load(handle)
