from illustrative.fl import  *
from mlcf.cf_estimator import *
import random
def set_seed(seed):
    """Seed every random-number generator the experiment relies on.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG,
    torch's CPU generator, torch's current CUDA device generator, and
    all CUDA device generators.

    Args:
        seed: integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def _tuning_batch_size(n, default=5):
    """Return the mini-batch size for kernel-hyperparameter tuning.

    If there are fewer than ``default`` points, the whole set is one batch.
    Otherwise start at ``default`` and increase it while ``n % bs == 1``,
    so that the final mini-batch never contains a single point.
    """
    if n < default:
        return n
    bs = default
    while n % bs == 1:
        bs += 1
    return bs


set_seed(200)

# Training-set sizes, and (paired index-wise) the number of fresh test
# points drawn in each repetition.
n1 = [2, 3, 4]
n2 = [2, 3, 4]
# Repetitions per setting (named n_rep so the builtin ``iter`` is not shadowed).
n_rep = 50
# est[i, j] holds the estimate from repetition i for the j-th (n1, n2) pair.
est = np.zeros((n_rep, len(n1)))

for j, (n_train, n_test) in enumerate(zip(n1, n2)):
    # Evenly spaced training design on (0, 1), reshaped to (n_train, 1).
    X = torch.tensor(np.linspace(0.001, 0.999, n_train)).unsqueeze(1).float()
    # as_tensor avoids copy-constructing (and warning) if f2 already returns a tensor.
    Y = torch.as_tensor(f2(X)).float()
    print(Y.size())

    score_X = multivariate_uniform(1, None, X)
    myCF = Simplied_CF(stein_base_kernel_MV_2, matern_25_test_kernel_boundcond, X, Y, score_X)

    myCF.do_tune_kernelparams_negmllk(
        batch_size_tune=_tuning_batch_size(X.size(0)),
        flag_if_use_medianheuristic=False,
        beta_cstkernel=0.1,
        lr=0.005,
        epochs=10,
        verbose=True,
    )
    # NOTE(review): this result is never saved; the call is kept in case it
    # updates myCF's internal state before do_nonsim_CF_fit — confirm.
    simp_CF_est = myCF.do_closed_form_est_for_simpliedCF()

    for i in range(n_rep):
        # Fresh U(0, 1) test points with shape (n_test, 1). sample() already
        # returns a float32 tensor, so it is not re-wrapped in torch.tensor.
        X_te = torch.distributions.Uniform(0, 1).sample((n_test,)).unsqueeze(1)
        # Cast to float32 to match the dtype used for the training targets Y.
        Y_te = torch.as_tensor(f2(X_te)).float()
        score_X_te = multivariate_uniform(1, None, X_te)
        fCF = myCF.do_nonsim_CF_fit(X_te, Y_te, score_X_te)
        est[i, j] = fCF.mean().item()

np.save("CF_repetition.npy", est)
