from TestBed.f012 import  *
import torch
import numpy as np
from mlcf.cf_estimator import *
import random
def set_seed(seed):
    """Seed every random number generator the experiment draws from.

    Applies ``seed`` to Python's ``random`` module, NumPy's global
    generator, PyTorch's CPU generator, the current CUDA device and all
    CUDA devices, in that order.

    Args:
        seed: integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed Python, NumPy and torch (CPU and CUDA) before any sampling happens.
set_seed(200)


# Sample sizes per level: a base of round(200 / 3) = 67 points scaled by each
# budget multiplier and rounded up, giving [67, 101, 134] for the values below.
budget = [1,1.5,2]
n =  np.round(200/3)*np.array(budget)
n = np.ceil(n).astype(int)

d = 2        # dimension of the sampled inputs X
iters = 50   # number of independent replicates

# Rows = replicates, columns = sample-size levels. The column count is taken
# from `budget` (instead of a literal 3) so the arrays always match `n`.
# TestBed_Est5 receives the Simplied_CF estimates, TestBed_Est6 the plain
# sample means of Y.
TestBed_Est5 = np.zeros((iters, len(budget)))
TestBed_Est6 = np.zeros((iters, len(budget)))


def _pick_batch_size(n_obs, base=5):
    """Return the mini-batch size used when tuning the kernel parameters.

    With fewer than ``base`` samples the whole sample forms one batch.
    Otherwise the size starts at ``base`` and is increased while
    ``n_obs % size == 1``, presumably so that no batch is left holding a
    single sample.

    Args:
        n_obs: number of samples in the current data set.
        base: preferred batch size.

    Returns:
        The chosen batch size.
    """
    if n_obs < base:
        return n_obs
    size = base
    # Always terminates: at size == n_obs the remainder is 0.
    while n_obs % size == 1:
        size += 1
    return size


# For each replicate and each sample size in `n`: draw X ~ Uniform(0, 1)^d,
# evaluate Y = f0(X) + f10(X) + f21(X), and estimate the mean of Y twice --
# with the tuned Simplied_CF closed-form estimator (-> TestBed_Est5) and with
# the plain sample mean (-> TestBed_Est6).
for i in range(iters):


    for n_level, n_samples in enumerate(n):

        print(f"i:{i} n_level:{n_level} n:{n_samples}")

        X = torch.distributions.Uniform(0, 1).sample((n_samples, d))
        # torch.as_tensor accepts either an ndarray or a tensor (torch.tensor
        # warns when copy-constructing from a tensor); cast to float32 and
        # make Y a single column.
        Y = torch.as_tensor(f0(X) + f10(X) + f21(X), dtype=torch.float32)
        Y = Y.unsqueeze(1)

        # NOTE(review): presumably the score of the uniform sampling density
        # evaluated at X -- confirm against mlcf.cf_estimator.
        score_X = multivariate_uniform(d, None, X)
        myCF = Simplied_CF(stein_base_kernel_MV_2, matern_25_test_kernel_boundcond, X, Y, score_X)

        bs = _pick_batch_size(X.shape[0])
        myCF.do_tune_kernelparams_negmllk(
            batch_size_tune=bs,
            flag_if_use_medianheuristic=False,
            beta_cstkernel=0.1,
            lr=0.005,
            epochs=10,
            verbose=True,
        )
        simp_CF_est = myCF.do_closed_form_est_for_simpliedCF()
        MC_est = Y.mean()
        TestBed_Est5[i, n_level] = float(simp_CF_est)
        TestBed_Est6[i, n_level] = float(MC_est)


# Write both result matrices (CF estimates, then MC means) as .npy files in the
# current working directory.
np.save(file="TestBed_Est5.npy", arr=TestBed_Est5)
np.save(file="TestBed_Est6.npy", arr=TestBed_Est6)