import torch


def HSIC(x, y, kernel_x, kernel_y, K=None, L=None):
    """Estimate the Hilbert-Schmidt Independence Criterion between x and y.

    Computes ``tr(L @ H @ K @ H) / (m - 1) ** 2``, where ``K`` and ``L`` are
    the Gram matrices of ``x`` and ``y`` and ``H = I - (1/m) * 1 1^T`` is the
    centering matrix (the biased empirical HSIC estimator).

    Args:
        x: Samples of the first variable, indexed along dim 0. A 1-D tensor
            is treated as ``(m, 1)``.
        y: Samples of the second variable, indexed along dim 0. A 1-D tensor
            is treated as ``(m, 1)``.
        kernel_x: Callable mapping ``x`` to its ``(m, m)`` Gram matrix. Only
            called when ``K`` is None.
        kernel_y: Callable mapping ``y`` to its ``(m, m)`` Gram matrix. Only
            called when ``L`` is None.
        K: Optional precomputed Gram matrix for ``x``.
        L: Optional precomputed Gram matrix for ``y``.

    Returns:
        A 0-dim tensor holding the estimate. It lives on ``K``'s device, in
        the promoted floating dtype of ``K`` and ``L``.

    Note:
        ``m`` is taken from ``x``. With ``m < 2`` the ``(m - 1) ** 2``
        normaliser is zero, so the result is not finite.
    """
    if x.dim() == 1:
        x = x.reshape(x.size(0), -1)
    if y.dim() == 1:
        y = y.reshape(y.size(0), -1)

    m = x.shape[0]

    if K is None:
        K = kernel_x(x)
    if L is None:
        L = kernel_y(y)

    # torch.mm neither promotes dtypes nor moves tensors across devices.
    # Build H (and align K, L) in a common floating dtype on K's device, so
    # float64, integer or CUDA Gram matrices work instead of raising.
    dtype = torch.promote_types(K.dtype, L.dtype)
    if not (dtype.is_floating_point or dtype.is_complex):
        dtype = torch.get_default_dtype()
    K = K.to(dtype)
    L = L.to(dtype)
    device = K.device
    H = torch.eye(m, dtype=dtype, device=device) - 1.0 / m * torch.ones(
        (m, m), dtype=dtype, device=device
    )

    hsic = torch.trace(torch.mm(L, torch.mm(H, torch.mm(K, H)))) / ((m - 1) ** 2)
    return hsic


def MMR(X, A, kernel_A):
    """Compute a kernel-weighted moment loss of X with respect to A.

    Evaluates ``sum_ij (X @ X.T)_ij * W_ij / n ** 2`` with
    ``W = kernel_A(A)``, which appears to be the maximum moment restriction
    (MMR) objective. NOTE(review): X is presumably a residual vector or
    matrix; confirm against callers.

    Args:
        X: Tensor indexed by sample along dim 0. A 1-D tensor is treated as
            ``(n, 1)``.
        A: Input passed to ``kernel_A``.
        kernel_A: Callable returning an ``(n, n)`` weight (Gram) matrix for
            ``A``.

    Returns:
        A 0-dim tensor with the loss.
    """
    if X.dim() == 1:
        # For 1-D X, ``X.T`` is a no-op and ``X @ X.T`` collapses to a scalar
        # dot product, which would uniformly scale W instead of weighting
        # each pair (i, j). Promote to a column so we get the outer product.
        X = X.reshape(X.size(0), -1)

    W = kernel_A(A)

    L = (X @ X.T) * W
    n = X.shape[0]
    loss = L.sum() / n ** 2

    return loss
