from ODE.ODE_Solver import *
from mlcf.cf_estimator import *
from ODE.nsamples import *
import csv
import random
def set_seed(seed):
    """Seed every random number generator this script draws from.

    Covers Python's ``random`` module, NumPy's global generator, torch's CPU
    generator, and torch's CUDA generators (current device and all devices),
    so reruns with the same ``seed`` draw the same numbers from each of them.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

# Make the whole run reproducible before any random sampling happens.
set_seed(200)

# Two-slot row buffer; it is written to the CSV once per fit.
Est = np.zeros(2)

# Dimension of the sampled inputs and the diagonal covariance
# diag(0.2**2, 1) that is later passed as `cov` to the score function.
dim = 2
mu = torch.zeros(dim, dtype=torch.float)
var = torch.diag(torch.tensor([0.2 ** 2, 1.0], dtype=torch.float))

# Number of independent repetitions of the full sweep over sample sizes.
N_REPEATS = 100


def _tuning_batch_size(n_points, base=5):
    """Return the mini-batch size used when tuning the kernel hyper-parameters.

    With fewer than ``base`` points, all of them form one batch. Otherwise,
    start at ``base`` and grow the size until ``n_points % bs != 1``
    (presumably so that the trailing mini-batch is never a single point).
    The loop always terminates, at the latest at ``bs == n_points``.
    """
    if n_points < base:
        return n_points
    bs = base
    while n_points % bs == 1:
        bs += 1
    return bs


# The arguments to multivariate_Normal_score do not change between fits, so
# build them once.
# NOTE(review): this mu has shape (dim, 1), unlike the (dim,) mu defined
# earlier. The original code also used (dim, 1) here; confirm that is intended.
mu = torch.zeros(dim, 1)
cov = var

# The context manager closes the file even if a solve/tune step raises.
# newline='' is what the csv module requires for files passed to csv.writer.
with open('./ODE_CF.csv', 'w', newline='') as ODE_Est:
    writer = csv.writer(ODE_Est)

    for _ in range(N_REPEATS):
        # N (presumably provided by `from ODE.nsamples import *`) holds the
        # sample sizes to sweep; one CSV row is written per sample size.
        for n_samples in N:
            X = ODE_Sampler(N=int(n_samples))
            Pf, Pc = ODE_Solver(X, l=2)

            X = torch.tensor(X).float()
            # Add a trailing axis to the solver output (an (n, 1) column if Pf is 1-D).
            Y = torch.tensor(Pf).unsqueeze(1).float()

            score_X = multivariate_Normal_score(mu, cov, X)
            myCF = Simplied_CF(stein_base_kernel_MV_2, rbf_kernel, X, Y, score_X)
            bs = _tuning_batch_size(X.size()[0])

            myCF.do_tune_kernelparams_negmllk(
                batch_size_tune=bs,
                flag_if_use_medianheuristic=False,
                beta_cstkernel=0,
                lr=0.1,
                epochs=20,
                verbose=True,
            )
            simp_CF_est = myCF.do_closed_form_est_for_simpliedCF()
            MC_est = Y.mean()

            # Column 0: closed-form estimate from the fitted Simplied_CF model.
            Est[0] = float(simp_CF_est)
            # Column 1: plain Monte Carlo estimate (sample mean of Y). It was
            # computed before but never stored, so this column was always 0.
            Est[1] = float(MC_est)

            writer.writerow(Est)