from illustrative.fl import  *
import torch
import numpy as np
from mlcf.cf_estimator import *
import random
def set_seed(seed):
    """Seed every random number generator this script draws from.

    Applies ``seed``, in order, to Python's ``random`` module, NumPy's global
    generator, PyTorch's default generator, and PyTorch's two CUDA seeding
    entry points (current device, then all devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def _level_increment(level, x):
    """Evaluate the integrand used at ``level`` on the points ``x``.

    Level 0 uses ``f0`` itself; level 1 uses ``f1 - f0``; any other level uses
    ``f2 - f1``.
    """
    if level == 0:
        return f0(x)
    if level == 1:
        return f1(x) - f0(x)
    return f2(x) - f1(x)


def _as_fresh_tensor(values):
    """Return a detached copy of ``values`` as a tensor.

    Stands in for ``torch.tensor(values)``: when ``values`` is already a
    tensor, ``torch.tensor`` emits a copy-construct UserWarning, whereas
    ``detach().clone()`` is the form PyTorch recommends for the same result.
    """
    return torch.as_tensor(values).detach().clone()


def _tuning_batch_size(n_points, preferred=5):
    """Pick the batch size passed to kernel-parameter tuning.

    Uses all points when there are fewer than ``preferred``; otherwise starts
    at ``preferred`` and grows until ``n_points % batch_size != 1``, so the
    last batch never holds exactly one point.
    """
    if n_points < preferred:
        return n_points
    batch_size = preferred
    while n_points % batch_size == 1:
        batch_size += 1
    return batch_size


set_seed(200)

# n1[j][level]: number of equally spaced points in (0, 1) used to build and
# tune the estimator for budget setting j at the given level.
n1 = [[6, 2, 2], [7, 6, 2], [12, 9, 3]]
# n2[j][level]: number of fresh Uniform(0, 1) draws passed to
# do_nonsim_CF_fit in each repetition.
n2 = [[5, 2, 1], [7, 5, 2], [11, 8, 2]]
n_repeats = 50  # formerly named `iter`, which shadowed the builtin
L = 3  # number of levels
# est[i, j, level]: mean of the fitted values for repetition i.
est = np.zeros((n_repeats, len(n1), L))

# Building the distribution draws no random numbers, so creating it once here
# leaves the random stream of the loops below unchanged.
unit_uniform = torch.distributions.Uniform(0, 1)

for level in range(L):
    for j in range(len(n1)):
        grid = np.linspace(0.001, 0.999, n1[j][level])
        X = torch.tensor(grid).unsqueeze(1).float()
        Y = _as_fresh_tensor(_level_increment(level, X)).float()

        score_X = multivariate_uniform(1, None, X)
        myCF = Simplied_CF(stein_base_kernel_MV_2, matern_25_test_kernel_boundcond, X, Y, score_X)
        bs = _tuning_batch_size(X.size()[0])

        myCF.do_tune_kernelparams_negmllk(
            batch_size_tune=bs,
            flag_if_use_medianheuristic=False,
            beta_cstkernel=0.1,
            lr=0.005,
            epochs=10,
            verbose=True,
        )
        # NOTE(review): the returned closed-form estimate is not stored in
        # `est`; the call is kept in case it updates state on `myCF` that
        # do_nonsim_CF_fit relies on -- confirm against mlcf.cf_estimator.
        myCF.do_closed_form_est_for_simpliedCF()

        for i in range(n_repeats):
            # sample() already returns a tensor, so no torch.tensor() wrap.
            X_te = unit_uniform.sample((n2[j][level],)).unsqueeze(1).float()
            Y_te = _as_fresh_tensor(_level_increment(level, X_te))
            score_X_te = multivariate_uniform(1, None, X_te)
            fCF = myCF.do_nonsim_CF_fit(X_te, Y_te, score_X_te)
            est[i, j, level] = fCF.mean().item()

np.save("MLCF_repetition.npy", est)
