import torch
from mlcf.stein_operators import *
from mlcf.score_funcs import *



def f_x_batch(X, constant_1, constant_vec2, constant_3, base_kernel_class,
              kernel_parm1=None, kernel_parm2=None):
    """Evaluate a scaled, shifted Stein kernel between each row of X and one fixed point.

    For every row ``x_i`` of ``X`` the returned tensor holds
    ``constant_1 * (K[i, i] + constant_3)``. Here ``K`` is the matrix returned by
    ``stein_base_kernel_MV_2(...).cal_stein_base_kernel(X, Z, score_X, score_Z)``,
    and every row of ``Z`` is ``constant_vec2``. So ``K[i, i]`` pairs ``x_i`` with
    ``constant_vec2``. Both score tensors come from ``multivariate_uniform``.

    Args:
        X: Tensor of shape (n, d); one evaluation point per row.
        constant_1: Multiplicative factor applied to the final values.
        constant_vec2: Tensor of shape (d,); the fixed second kernel argument.
        constant_3: Additive offset applied to the kernel values before scaling.
        base_kernel_class: Passed as ``base_kernel`` to ``stein_base_kernel_MV_2``.
        kernel_parm1: Assigned to ``base_kernel_parm1`` on the kernel object.
            Defaults to a fresh ``torch.tensor([1.0])`` on ``X.device``.
        kernel_parm2: Assigned to ``base_kernel_parm2`` on the kernel object.
            Defaults to a fresh ``torch.tensor([1.0])`` on ``X.device``.

    Returns:
        The diagonal of the shifted kernel matrix times ``constant_1``. This is
        presumably shape (n,), assuming ``cal_stein_base_kernel`` returns an
        (n, n) matrix (TODO confirm).

    Raises:
        ValueError: If X is not 2-D, constant_vec2 is not 1-D, or their feature
            dimensions differ.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if X.ndim != 2 or constant_vec2.ndim != 1:
        raise ValueError("X must be (n,d), constant_vec2 must be (d,)")
    n, d = X.shape
    if d != constant_vec2.shape[0]:
        raise ValueError("Dimension mismatch")

    # Build default parameters per call. Tensors are mutable, and they are handed
    # directly to the kernel object. A tensor created once in the signature would
    # be shared (and could be mutated) across calls. Creating the tensors on
    # X.device also avoids CPU/GPU device mismatches.
    if kernel_parm1 is None:
        kernel_parm1 = torch.tensor([1.0], device=X.device)
    if kernel_parm2 is None:
        kernel_parm2 = torch.tensor([1.0], device=X.device)

    # Every row of Z is the same fixed point, so the diagonal entry K[i, i]
    # pairs x_i with constant_vec2.
    Z = constant_vec2.unsqueeze(0).repeat(n, 1)

    score_X = multivariate_uniform(dim=d, nullparm=None, X=X)
    score_Z = multivariate_uniform(dim=d, nullparm=None, X=Z)

    kernel_obj = stein_base_kernel_MV_2(base_kernel=base_kernel_class)
    kernel_obj.base_kernel_parm1 = kernel_parm1
    kernel_obj.base_kernel_parm2 = kernel_parm2

    # NOTE(review): this builds the full (n, n) kernel matrix only to read its
    # diagonal, which costs O(n^2) time and memory. If cal_stein_base_kernel is a
    # purely pairwise function of (x_i, z_j, score_x_i, score_z_j), then
    # evaluating against a single row of Z and taking column 0 would give the
    # same values in O(n). TODO confirm against its implementation before
    # switching.
    K = kernel_obj.cal_stein_base_kernel(X, Z, score_X, score_Z) + constant_3

    return constant_1 * K.diag()

