from mlcf.cf_estimator import *
from ODE.ODE_Solver import *
from ODE.nsamples import *
import csv
import random
def set_seed(seed):
    """Seed every random number generator this script draws from.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global RNG, torch's CPU generator, the current CUDA device's
    generator, and the generators of all CUDA devices.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

set_seed(seed=200)  # fix all RNGs before any sampling so the CSV output is reproducible

def MC_Sampler(N):
    """Draw ``N`` independent Monte Carlo samples of the 2-D input ``(a, z)``.

    ``a`` is drawn from a normal distribution with mean 0 and standard
    deviation 0.2; ``z`` is drawn from a standard normal distribution. Both
    come from NumPy's global RNG, with all ``a`` values drawn before all ``z``
    values.

    Args:
        N: Number of samples; converted with ``int()``, so a float count is
            accepted.

    Returns:
        A numpy array of shape ``(N, 2)``: column 0 holds ``a``, column 1
        holds ``z``.
    """
    n_samples = int(N)
    # Draw order is significant: it fixes which random numbers land in which
    # column for a given seed.
    a_col = np.random.normal(0, 0.2, n_samples)
    z_col = np.random.normal(0, 1, n_samples)
    return np.column_stack((a_col, z_col))

# Output row reused for every CSV line. Column 0 holds the simplified
# control-functional estimate and column 1 the plain Monte Carlo mean of Y.
# Columns 2 and 3 are never written here and stay 0; they are kept so the
# 4-column CSV layout is unchanged for anything that reads ODE_MC.csv.
Est = np.zeros(4)

dim = 2
factor = torch.ones(1) * 1  # NOTE(review): not used in this script; kept as-is.
# Gaussian parameters of the sampled inputs (a, z): zero mean and covariance
# diag(0.2**2, 1). These must match the standard deviations in MC_Sampler.
# Both are loop-invariant, so they are built once rather than per iteration.
mu = torch.zeros(dim, 1)
var = torch.eye(dim, dtype=torch.float) * torch.tensor([0.2**2, 1])
cov = var


def _tune_batch_size(n_samples, preferred=5):
    """Return the mini-batch size used for kernel-hyperparameter tuning.

    With fewer than ``preferred`` samples, the whole set is one batch.
    Otherwise the size starts at ``preferred`` and is increased until
    ``n_samples % size != 1``. Presumably this avoids a trailing mini-batch
    with a single point; confirm against do_tune_kernelparams_negmllk.
    The loop terminates because ``n_samples >= preferred > 1`` on that path.
    """
    if n_samples < preferred:
        return n_samples
    size = preferred
    while n_samples % size == 1:
        size += 1
    return size


# newline='' is required by the csv module, otherwise rows are separated by
# blank lines on platforms that translate '\n'. The `with` block guarantees
# the file is flushed and closed even if a solver or tuning step raises.
with open('ODE_MC.csv', 'w', newline='') as ODE_Est:
    writer = csv.writer(ODE_Est)

    for _ in range(100):  # independent repetitions of the whole experiment
        # NOTE(review): MLN presumably holds per-configuration sample counts,
        # one column per level l; confirm in ODE.nsamples.
        for n_level in range(np.shape(MLN)[0]):
            for l in range(3):
                X = MC_Sampler(MLN[n_level, l])
                # NOTE(review): uf / uc are presumably fine- and coarse-level
                # solutions, so Y is the level correction.
                uf, uc = ODE_Solver(X, l=l)

                X = torch.tensor(X).float()
                Y = torch.unsqueeze(torch.tensor(uf - uc), 1).float()

                score_X = multivariate_Normal_score(mu, cov, X)
                myCF = Simplied_CF(stein_base_kernel_MV_2, rbf_kernel, X, Y, score_X)

                myCF.do_tune_kernelparams_negmllk(
                    batch_size_tune=_tune_batch_size(X.shape[0]),
                    flag_if_use_medianheuristic=False,
                    beta_cstkernel=0,
                    lr=0.1,
                    epochs=10,
                    verbose=True,
                )

                simp_CF_est = myCF.do_closed_form_est_for_simpliedCF()
                MC_est = Y.mean()

                # Convert the torch scalars explicitly before storing them in
                # the numpy row.
                Est[0] = float(simp_CF_est)
                Est[1] = float(MC_est)
                writer.writerow(Est)