from f012 import  *
import torch
import numpy as np
from mlcf.cf_estimator import *
import random
def set_seed(seed):
    """Seed every random number generator used by this script.

    The same ``seed`` is passed, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator, and PyTorch's
    CUDA generators (current device, then all devices).

    Args:
        seed: Value forwarded unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

# Fix all RNG seeds up front so repeated runs draw the same samples.
set_seed(200)


# Sample-count table: the base counts loaded from disk are scaled by each
# budget multiplier (rows = budget index, columns = entries of the loaded
# array), rounded up, and incremented by one.
budget = [1, 1.5, 2]
base_counts = np.load("nsamples_mc.npy")  # presumably one count per level -- TODO confirm
n_mc = np.ceil(np.outer(budget, base_counts)).astype(int) + 1

d = 2  # number of columns of each sampled X
iters = 50  # number of independent repetitions

# Result tables indexed by (repetition, n_level, l_level).
TestBed_Est3 = np.zeros((iters, 3, 3))  # closed-form estimates from Simplied_CF
TestBed_Est4 = np.zeros_like(TestBed_Est3)  # plain sample means of Y

# Function applied to X for l_level 0, 1 and 2 respectively.
level_functions = (f0, f10, f21)


for i in range(iters):

    # np.ndindex visits (n_level, l_level) in row-major order, i.e. l_level
    # varies fastest, exactly like two nested range(3) loops.
    for n_level, l_level in np.ndindex(3, 3):

        print(f"i:{i} n_level:{n_level} l_level:{l_level} n_cf:{n_mc[n_level,l_level]}")

        # Draw the points uniformly on [0, 1) in each of the d coordinates.
        X = torch.distributions.Uniform(0, 1).sample((n_mc[n_level, l_level], d))

        # Evaluate the level's function and store the values as a float
        # tensor with an added second axis of size one.
        Y = torch.tensor(level_functions[l_level](X)).unsqueeze(1).float()

        # NOTE(review): presumably the score function of the uniform
        # sampling density evaluated at X -- confirm against the helper.
        score_X = multivariate_uniform(d, None, X)
        cf_model = Simplied_CF(
            stein_base_kernel_MV_2,
            matern_25_test_kernel_boundcond,
            X,
            Y,
            score_X,
        )

        # Tuning batch size: everything at once for fewer than 5 points;
        # otherwise start at 5 and grow while the row count leaves a
        # remainder of exactly one (presumably to avoid a single-sample
        # final batch).
        n_rows = X.shape[0]
        batch_size = min(n_rows, 5)
        while n_rows >= 5 and n_rows % batch_size == 1:
            batch_size += 1

        cf_model.do_tune_kernelparams_negmllk(
            batch_size_tune=batch_size,
            flag_if_use_medianheuristic=False,
            beta_cstkernel=0.1,
            lr=0.005,
            epochs=10,
            verbose=True,
        )

        TestBed_Est3[i, n_level, l_level] = float(cf_model.do_closed_form_est_for_simpliedCF())
        TestBed_Est4[i, n_level, l_level] = float(Y.mean())


# Persist both result tables once every repetition has finished.
np.save("TestBed_Est3.npy", TestBed_Est3)
np.save("TestBed_Est4.npy", TestBed_Est4)