'''Implementation of kernel functions.'''

import torch

# Module-level numerical tolerance.
# NOTE(review): not referenced by the functions in this section; confirm
# whether other code imports it before changing or removing it.
eps = 1e-12


def euclidean(samples, centers, squared=True):
    '''Compute pairwise Euclidean distances between the rows of two matrices.

    Relies on the expansion ||x - c||^2 = ||x||^2 - 2 <x, c> + ||c||^2, so the
    result is one matrix product followed by two broadcast additions, all
    performed in place on the product tensor.

    Args:
        samples: tensor of shape (n_sample, n_feature).
        centers: tensor of shape (n_center, n_feature).
        squared: if True, return squared distances; otherwise return their
            square roots (negative rounding residue is clamped to 0 first).

    Returns:
        Tensor of shape (n_sample, n_center) with the distances.
    '''
    sample_sq = (samples ** 2).sum(dim=1, keepdim=True)
    # When both arguments are the very same tensor object, reuse the norms.
    center_sq = (sample_sq if centers is samples
                 else (centers ** 2).sum(dim=1, keepdim=True)).reshape(1, -1)

    # -2 <x, c> + ||x||^2 + ||c||^2, accumulated in place on the product.
    dist = torch.mm(samples, centers.t())
    dist.mul_(-2).add_(sample_sq).add_(center_sq)
    if squared:
        return dist
    # Cancellation can leave tiny negatives; zero them before the square root.
    return dist.clamp_(min=0).sqrt_()



def hilbert(samples, centers, d, epsilon=1e-10):
    '''Hilbert kernel for batched inputs.

    Computes K[b, i, j] = 1 / ||samples[b, i] - centers[b, j]||^d. Every entry
    whose denominator ``||.||^d`` is <= ``epsilon`` (including coincident
    points, whose distance is 0) is set to ``1 / epsilon`` instead.

    Args:
        samples: of shape (batch_size, n_sample, n_feature).
        centers: of shape (batch_size, n_center, n_feature). Batch
            dimensions broadcast as in ``torch.cdist``.
        d: degree (exponent) of the kernel.
        epsilon: lower threshold applied to ``dist ** d`` before taking the
            reciprocal; caps kernel values at ``1 / epsilon``.

    Returns:
        kernel matrix of shape (batch_size, n_sample, n_center).
    '''
    # cdist avoids materialising the (batch, n_sample, n_center, n_feature)
    # difference tensor. The matmul-based fast path is disabled on purpose:
    # it loses precision through cancellation and can report small non-zero
    # distances between identical points, which this kernel would map to
    # values far from 1 / epsilon.
    dist = torch.cdist(samples, centers, p=2,
                       compute_mode='donot_use_mm_for_euclid_dist')
    # Clamping the denominator reproduces the threshold (values <= epsilon
    # become epsilon, i.e. 1 / epsilon after the reciprocal) without ever
    # evaluating 1 / 0, which would otherwise turn gradients into NaN.
    # All entries are non-negative, so no further clamping is needed.
    return 1.0 / (dist ** d).clamp(min=epsilon)













