from mlcf.score_funcs import *
from mlcf.base_kernels import *
from mlcf.stein_operators import *

class Simplied_CF(object):
    """Kernel-based estimator built on a training Gram matrix.

    The class holds a kernel *class* (``prior_kernel``), a base-kernel
    factory (``base_kernel``), training inputs/targets and the score tensor
    evaluated at the training inputs.  All estimates rely on the matrix
    returned by ``prior_kernel(...).cal_stein_base_kernel`` regularised with
    a small diagonal jitter (``_JITTER``).

    With ``K = k_XX + _JITTER * I`` and ``1`` the all-ones vector, the
    constant offset used throughout is::

        o = (1^T K^{-1} Y_train) / (1^T K^{-1} 1)

    Attributes set after construction:
        optim_base_kernel_parms: written by ``do_tune_kernelparams_negmllk``;
            required by every method that uses tuned hyperparameters.
        Gram: written by ``do_closed_form_est_for_simpliedCF`` (the
            unregularised training Gram matrix).
    """

    # Diagonal jitter added to the training Gram matrix before inversion.
    _JITTER = 0.001

    def __init__(self, prior_kernel, base_kernel, X_train, Y_train, score_tensor):
        """Store the model ingredients and select the matching utils module.

        Args:
            prior_kernel: callable accepting ``base_kernel=...`` and returning
                an object exposing ``base_kernel_parm1``,
                ``base_kernel_parm2`` and ``cal_stein_base_kernel``.
            base_kernel: zero-argument callable (class) producing the base
                kernel; it is instantiated once here for a type check.
            X_train: training inputs; the first dimension is the sample count.
            Y_train: training targets.
            score_tensor: score values paired with ``X_train``.
        """
        self.prior_kernel = prior_kernel
        self.base_kernel = base_kernel
        self.X_train = X_train
        self.Y_train = Y_train
        self.score_tensor = score_tensor

        # The two Matern-2.5 test kernels use ``mlcf.utils2``; every other
        # base kernel uses ``mlcf.utils1``.
        probe = base_kernel()
        if isinstance(probe, (matern_25_test_kernel, matern_25_test_kernel_boundcond)):
            import mlcf.utils2 as utils_module
        else:
            import mlcf.utils1 as utils_module

        self.utils_module = utils_module

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _make_kernel(self, use_tuned=True):
        """Instantiate the prior kernel and assign its two hyperparameters.

        With ``use_tuned`` the values come from ``optim_base_kernel_parms``
        (an ``AttributeError`` is raised if tuning has not been run);
        otherwise parameter 1 is ``torch.ones(1)`` and parameter 2 is the
        median of ``X_train``.
        """
        kernel_obj = self.prior_kernel(base_kernel=self.base_kernel)
        if use_tuned:
            kernel_obj.base_kernel_parm1 = self.optim_base_kernel_parms[0]
            kernel_obj.base_kernel_parm2 = self.optim_base_kernel_parms[1]
        else:
            kernel_obj.base_kernel_parm1 = torch.ones(1) * 1.
            kernel_obj.base_kernel_parm2 = torch.median(self.X_train)
        return kernel_obj

    def _train_gram(self, kernel_obj):
        """Return the kernel matrix of the training inputs against themselves."""
        return kernel_obj.cal_stein_base_kernel(
            self.X_train, self.X_train, self.score_tensor, self.score_tensor
        )

    def _regularised_inverse(self, k_XX):
        """Return ``(k_XX + _JITTER * I)^{-1}``.

        The identity is created with the dtype/device of ``k_XX`` so that
        float64 or non-CPU Gram matrices do not trigger a mismatch.
        """
        m = self.X_train.size()[0]
        identity = torch.eye(m, dtype=k_XX.dtype, device=k_XX.device)
        return (k_XX + self._JITTER * identity).inverse()

    def _constant_offset(self, gram_inv):
        """Return ``(1^T K^{-1} Y_train) / (1^T K^{-1} 1)`` for a given inverse."""
        m = self.X_train.size()[0]
        ones_row = torch.ones(1, m, dtype=gram_inv.dtype, device=gram_inv.device)
        # ``1^T K^{-1}`` is shared by numerator and denominator.
        weights = ones_row @ gram_inv
        return (weights @ self.Y_train) / (weights @ ones_row.t())

    def _test_residuals(self, X_te, Y_te, score_tensor):
        """Return ``Y_te - k_ZX K^{-1} (Y_train - o)`` using tuned hyperparameters."""
        kernel_obj = self._make_kernel(use_tuned=True)
        k_ZX = kernel_obj.cal_stein_base_kernel(
            X_te, self.X_train, score_tensor, self.score_tensor
        )
        gram_inv = self._regularised_inverse(self._train_gram(kernel_obj))
        offset = self._constant_offset(gram_inv)
        centred = (self.Y_train.squeeze() - offset).squeeze()
        # Apply the inverse to the centred targets first: this avoids forming
        # the (n_test x m) product ``k_ZX @ K^{-1}``.
        fit = k_ZX @ (gram_inv @ centred)
        return Y_te.squeeze() - fit.squeeze()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def do_tune_kernelparams_negmllk(self, batch_size_tune, flag_if_use_medianheuristic=False, beta_cstkernel=1, lr=0.1, epochs=100, verbose=True):
        """Tune the two base-kernel hyperparameters and cache them.

        Builds ``utils_module.TuneKernelParams_mllk_MRI_singledat`` from the
        stored training data and forwards all arguments, in order, to its
        ``do_optimize_logmll`` method.  The resulting
        ``neg_mll.base_kernel_parm1`` / ``base_kernel_parm2`` values are
        stored, detached, in ``self.optim_base_kernel_parms``.

        Returns:
            A detached tensor holding the two tuned hyperparameters.
        """
        tuner = self.utils_module.TuneKernelParams_mllk_MRI_singledat(
            self.prior_kernel, self.base_kernel, self.X_train, self.Y_train, self.score_tensor
        )
        tuner.do_optimize_logmll(
            batch_size_tune, flag_if_use_medianheuristic, beta_cstkernel, lr, epochs, verbose
        )
        optim_base_kernel_parms = torch.Tensor(
            [tuner.neg_mll.base_kernel_parm1, tuner.neg_mll.base_kernel_parm2]
        )
        self.optim_base_kernel_parms = optim_base_kernel_parms.detach()
        return optim_base_kernel_parms.detach()

    def do_closed_form_est_for_simpliedCF(self, opt_hyperparms=True):
        """Return the closed-form constant ``(1^T K^{-1} Y) / (1^T K^{-1} 1)``.

        Args:
            opt_hyperparms: when true, use the hyperparameters cached by
                ``do_tune_kernelparams_negmllk`` (raises ``AttributeError``
                if tuning has not been run); otherwise use ``torch.ones(1)``
                and the median of ``X_train``.

        Side effects:
            Stores the unregularised training Gram matrix in ``self.Gram``.
        """
        kernel_obj = self._make_kernel(use_tuned=bool(opt_hyperparms))
        k_XX = self._train_gram(kernel_obj)
        self.Gram = k_XX
        return self._constant_offset(self._regularised_inverse(k_XX))

    def do_nonsim_CF(self, X_te, Y_te, score_tensor):
        """Return the mean of the test residuals ``Y_te - fit``.

        ``fit`` is ``k_ZX K^{-1} (Y_train - o)`` evaluated with the tuned
        hyperparameters; see ``do_nonsim_CF_fit`` for the per-sample values.
        Requires ``do_tune_kernelparams_negmllk`` to have been called.
        """
        return self._test_residuals(X_te, Y_te, score_tensor).mean()

    def do_nonsim_CF_fit(self, X_te, Y_te, score_tensor):
        """Return the per-sample test residuals ``Y_te - fit``.

        Args:
            X_te: test inputs.
            Y_te: test targets (squeezed before the subtraction).
            score_tensor: score values paired with ``X_te``.

        Requires ``do_tune_kernelparams_negmllk`` to have been called.
        """
        return self._test_residuals(X_te, Y_te, score_tensor)


