import sys, os
sys.path.append('/scratch/user/mzfan/MFBO_v2')
import torch
import math
import gpytorch
import matplotlib.pyplot as plt
from pyro.infer.mcmc import NUTS, MCMC, HMC
from gpytorch.likelihoods import GaussianLikelihood
import pyro
import pickle
from gpytorch.priors import LogNormalPrior, NormalPrior, UniformPrior, MultivariateNormalPrior
from gpytorch.likelihoods import GaussianLikelihood, FixedNoiseGaussianLikelihood
import pyro.distributions as dist
from torch.nn import Module
import torch.nn as nn
import torch.optim as optim
import time
import scipy
from scipy.optimize import minimize
import time
from funcs import sin_target, band_gap_target
from surrogates import NVSurrogate, SepSurrogate
from experiment import experiment_num, experiment_continious, experiment_two_step
from acquisition_function import UCB_NVUCB, UCB, UCB_NUCB


def sin_design_stats(x: torch.Tensor):
    """Return the maximum and mean of ``sin(2 * pi * x)`` over all elements.

    Args:
        x: Tensor of input locations (any shape).

    Returns:
        A ``(maximum, mean)`` pair of 0-dim tensors.
    """
    # Named ``values`` rather than ``sin_target`` so the ``sin_target``
    # imported from ``funcs`` at the top of the file is not rebound.
    values = torch.sin(2 * torch.pi * x)
    return torch.max(values), torch.mean(values)


def main(num_trials: int = 1000, threshold: float = 0.75) -> int:
    """Sample random input sets and report the sine statistics of each.

    Each trial draws a ``[10, 4, 1]`` tensor with ``torch.rand``, prints the
    maximum and mean of ``sin(2 * pi * x)`` over that tensor, and prints a
    ``'!!!!!!!'`` marker line when the maximum is below ``threshold``.

    Args:
        num_trials: Number of random draws to inspect.
        threshold: A trial is flagged when its maximum is strictly below
            this value.

    Returns:
        The number of flagged trials.
    """
    flagged = 0
    for _ in range(num_trials):
        init_x_sets = torch.rand([10, 4, 1])
        maximum, mean = sin_design_stats(init_x_sets)
        print(maximum, mean)
        if maximum < threshold:
            print('!!!!!!!')
            flagged += 1
    return flagged


if __name__ == "__main__":
    main()
