from illustrative.fl import  *
import torch
import numpy as np
from mlcf.cf_estimator import *
import random
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Calls, in order, the seeding functions of Python's ``random`` module,
    NumPy's global generator, and PyTorch (``torch.manual_seed``,
    ``torch.cuda.manual_seed`` and ``torch.cuda.manual_seed_all``).

    Args:
        seed: Value forwarded unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

# Fix all RNG seeds before any data is built.
set_seed(200)

# Training inputs: 4 evenly spaced points in [0.001, 0.999], converted to a
# float32 column tensor of shape (4, 1).
X = torch.tensor(np.linspace(0.001, 0.999, 4)).unsqueeze(1).float()

# Training targets: difference f2 - f1 evaluated at the inputs, cast to float32.
# NOTE(review): f1/f2 come from a wildcard import; their return type is not
# visible here, so the torch.tensor() conversion is kept as-is.
Y = torch.tensor(f2(X) - f1(X)).float()
print(Y.size())

# presumably the score function of a 1-d uniform density evaluated at X;
# confirm against multivariate_uniform's definition.
score_X = multivariate_uniform(1, None, X)
myCF = Simplied_CF(
    stein_base_kernel_MV_2,
    matern_25_test_kernel_boundcond,
    X,
    Y,
    score_X,
)

# Batch size passed to the tuner: the whole training set when it has fewer
# than 5 points; otherwise start at 5 and grow until the number of training
# points modulo the batch size is no longer exactly 1.
n_train = X.size(0)
bs = min(n_train, 5)
while n_train >= 5 and n_train % bs == 1:
    bs += 1


# Tune the kernel parameters (method name suggests a negative marginal
# log-likelihood objective; confirm in mlcf.cf_estimator), then compute the
# closed-form estimate and the plain sample mean of Y for comparison.
myCF.do_tune_kernelparams_negmllk(
    batch_size_tune=bs,
    flag_if_use_medianheuristic=False,
    beta_cstkernel=0.1,
    lr=0.005,
    epochs=10,
    verbose=True,
)
simp_CF_est = myCF.do_closed_form_est_for_simpliedCF()
MC_est = Y.mean()
# Evaluation grid: 300 evenly spaced points in [0.001, 0.999], stored as a
# float32 column tensor of shape (300, 1).
# np.linspace already returns ascending values and torch.tensor() already
# produces a fresh tensor, so the former torch.sort() and the second
# torch.tensor() call (which triggers PyTorch's copy-construct UserWarning
# when given a tensor) are dropped; values, shape and dtype are unchanged.
X_te = torch.tensor(np.linspace(0.001, 0.999, 300)).unsqueeze(1).float()

# Targets on the grid: difference f2 - f1.
Y_te = f2(X_te) - f1(X_te)
# NOTE(review): f1/f2 are defined elsewhere. If they already return tensors,
# this call emits the same copy-construct warning and could be replaced by
# Y_te.detach().clone(); kept as-is because their return type is not visible.
Y_te = torch.tensor(Y_te)
score_X_te = multivariate_uniform(1, None, X_te)

# Fit on the evaluation grid and persist the result together with the grid.
fCF = myCF.do_nonsim_CF_fit(X_te, Y_te, score_X_te)
torch.save(fCF, 'fMLCF_f2.pt')
torch.save(X_te, 'X_mlcf_f2.pt')


