from mlcf.cf_estimator import *
from ODE.ODE_Solver import *
from ODE.nsamples import *
import csv
import random
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators.

    Args:
        seed: Value handed unchanged to every generator's seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every visible CUDA device
    )
    for seeder in seeders:
        seeder(seed)

# Seed the Python, NumPy and torch (CPU + CUDA) RNGs with a fixed value at
# module load so the random draws made later in this script are reproducible.
set_seed(200)


def LHS_Sampler(N):
    """Draw ``N`` input samples via ``ODE_Sampler`` with ``method="LHS"``.

    ``N`` is coerced with ``int`` first, so a non-integer count (e.g. a float
    entry of a sample-size table) is truncated before being forwarded.
    The sampler's result is returned unchanged.
    """
    return ODE_Sampler(int(N), method="LHS")


def _choose_batch_size(n_samples):
    """Return the mini-batch size used when tuning the CF kernel parameters.

    Sample sets with fewer than 5 points are tuned as one batch. Otherwise
    start at 5 and increase until ``n_samples`` is no longer one more than a
    multiple of the batch size (presumably so the last mini-batch never holds
    a single point).

    Args:
        n_samples: Number of samples drawn for the current level (int >= 0).

    Returns:
        The batch size as an int.
    """
    if n_samples < 5:
        return n_samples
    batch_size = 5
    while n_samples % batch_size == 1:
        batch_size += 1
    return batch_size


def main(n_repeats=100, out_path='ODE_LHS.csv', n_levels=3):
    """Run the repeated LHS / simplified-CF experiment and write a CSV.

    For every repeat, every row of ``MLN`` and every level ``0..n_levels-1``:
    draw ``MLN[row, level]`` LHS samples, evaluate ``ODE_Solver`` on them,
    fit ``Simplied_CF`` to ``uf - uc`` and write one CSV row
    ``[simplified-CF estimate, plain Monte Carlo mean of uf - uc]``.

    Args:
        n_repeats: Number of independent repetitions of the whole sweep.
        out_path: Destination CSV path (overwritten).
        n_levels: Number of levels; each is passed to ``ODE_Solver`` as ``l``.

    NOTE(review): ``MLN``, ``ODE_Solver``, ``multivariate_Normal_score``,
    ``Simplied_CF``, ``stein_base_kernel_MV_2`` and ``rbf_kernel`` come from
    the star imports at the top of the file; their semantics (e.g. that
    ``uf``/``uc`` are fine/coarse solutions) are assumed, not verified here.
    """
    dim = 2
    # Mean and covariance handed to multivariate_Normal_score. The covariance
    # is diag(0.2**2, 1): eye(dim) * row-vector broadcasts onto the diagonal.
    # Both are loop-invariant, so build them once.
    mu = torch.zeros(dim, 1)
    cov = torch.eye(dim, dtype=torch.float) * torch.tensor([0.2**2, 1])

    # newline='' is what the csv module requires for files it writes (avoids
    # blank lines on Windows); `with` closes/flushes the file even on error.
    with open(out_path, 'w', newline='') as out_file:
        writer = csv.writer(out_file)
        for _ in range(n_repeats):
            for samples_per_level in MLN:
                for level in range(n_levels):
                    X = LHS_Sampler(samples_per_level[level])
                    uf, uc = ODE_Solver(X, l=level)

                    X = torch.tensor(X).float()
                    # Column vector (n, 1) of level differences.
                    Y = torch.unsqueeze(torch.tensor(uf - uc), 1).float()

                    score_X = multivariate_Normal_score(mu, cov, X)
                    my_cf = Simplied_CF(stein_base_kernel_MV_2, rbf_kernel,
                                        X, Y, score_X)
                    my_cf.do_tune_kernelparams_negmllk(
                        batch_size_tune=_choose_batch_size(X.size()[0]),
                        flag_if_use_medianheuristic=False,
                        beta_cstkernel=0, lr=0.1, epochs=10, verbose=True,
                    )
                    simp_cf_est = my_cf.do_closed_form_est_for_simpliedCF()
                    mc_est = Y.mean()

                    # Second column previously stayed 0.0 because the MC
                    # estimate was computed but never stored.
                    writer.writerow([float(simp_cf_est), float(mc_est)])


if __name__ == "__main__":
    main()
