import torch


def rbf_kernel(x, sigma=None, sigma_scale=1.0, eps=1e-8):
    """
    Computes the RBF (Gaussian) kernel matrix exp(-||x_i - x_j||^2 / (2 sigma^2)).

    Args:
        x: 2-D tensor whose rows are the samples.
        sigma: Kernel bandwidth, a length in the same units as ``x``. If None,
            the median heuristic is used: the median Euclidean distance between
            distinct samples (i < j), computed without gradient, plus ``eps``.
        sigma_scale: Multiplier applied to ``sigma`` (explicit or heuristic).
        eps: Added to the heuristic bandwidth so duplicated/constant samples do
            not yield a zero bandwidth (which would produce 0/0 = NaN).

    Returns:
        (n, n) kernel matrix.
    """
    # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2<a, b>. Round-off can
    # make entries (notably the diagonal) slightly negative, which would push
    # kernel values above 1, so clamp at zero.
    x_norm = (x**2).sum(dim=1).view(-1, 1)
    dist_matrix = (x_norm + x_norm.t() - 2.0 * torch.mm(x, x.t())).clamp_min(0.0)

    if sigma is None:
        n = x.size(0)
        if n > 1:
            # Median over i < j only: the zero diagonal would bias the median
            # downwards (and makes it exactly 0 when n == 2).
            i, j = torch.triu_indices(n, n, offset=1, device=x.device)
            median_sq = torch.median(dist_matrix.detach()[i, j])
            # The median is over *squared* distances; take the square root to
            # get a length, since sigma is squared again in the exponent below.
            sigma = torch.sqrt(median_sq) + eps
        else:
            # No pairs to take a median over; every distance is 0, so any
            # bandwidth gives the same kernel (all ones).
            sigma = 1.0

    # Apply the scaling factor to sigma
    sigma = sigma * sigma_scale

    return torch.exp(-dist_matrix / (2.0 * sigma**2))


def hsic_loss(x, y, sigma_scale=1.0):
    """
    Computes the biased empirical HSIC between two variables x and y.

    This is the V-statistic trace(K H L H) / (n - 1)^2, where K and L are the
    ``rbf_kernel`` matrices of x and y and H = I - 11^T / n is the centering
    matrix. It is not the unbiased (U-statistic) HSIC estimator.

    Args:
        x, y: 2-D tensors with the same number of rows n (samples); n must be
            at least 2 (n == 1 raises ZeroDivisionError).
        sigma_scale: Bandwidth multiplier forwarded to ``rbf_kernel``.

    Returns:
        Scalar tensor with the HSIC estimate.
    """
    n = x.size(0)

    K = rbf_kernel(x, sigma_scale=sigma_scale)
    L = rbf_kernel(y, sigma_scale=sigma_scale)

    # trace(K H L H) == trace((H K H) L) == sum((H K H) * L^T). H K H is K with
    # its column and row means subtracted and the grand mean added back, so no
    # dense H is needed: O(n^2) instead of three O(n^3) matmuls, and no
    # default-dtype (float32) H gets multiplied with float64 kernels, which
    # made ``@`` raise a dtype-mismatch error.
    K_centered = K - K.mean(dim=0, keepdim=True) - K.mean(dim=1, keepdim=True) + K.mean()
    HSIC = (1 / (n - 1) ** 2) * (K_centered * L.t()).sum()

    return HSIC


# Interface similar to `mutual_information_mine`
def hsic_estimate(x, y, sigma_scale=1.0):
    """
    Dependence measure between x and y based on HSIC.

    Thin wrapper that forwards its arguments unchanged to ``hsic_loss``.

    Args:
        x, y: Tensors of shape (n_samples, d).
        sigma_scale: Bandwidth multiplier passed through to ``hsic_loss``.

    Returns:
        The value returned by ``hsic_loss(x, y, sigma_scale)``.
    """
    estimate = hsic_loss(x, y, sigma_scale=sigma_scale)
    return estimate


def nHSIC_estimate(x, y, sigma_scale=1.0):
    """
    Normalized HSIC: HSIC(x, y) / sqrt(HSIC(x, x) * HSIC(y, y)).

    Args:
        x, y: 2-D tensors with the same number of rows (samples).
        sigma_scale: Bandwidth multiplier forwarded to every ``hsic_loss``
            call; the default 1.0 reproduces the previous behaviour.

    Returns:
        Scalar tensor. No epsilon is added to the denominator, so a zero
        self-HSIC for x or y yields NaN/inf.
    """
    hsic_xy = hsic_loss(x, y, sigma_scale)
    hsic_xx = hsic_loss(x, x, sigma_scale)
    hsic_yy = hsic_loss(y, y, sigma_scale)
    return hsic_xy / torch.sqrt(hsic_xx * hsic_yy)


def rbf_kernel_2(x, sigma=None, multiscale_scales=None, eps=1e-8):
    """
    Computes an RBF kernel, or a sum of RBF kernels at several bandwidths, via torch.cdist.

    Args:
        x: 2-D tensor of shape (n, d).
        sigma: Base bandwidth. If None, the median Euclidean distance between
            distinct rows (i < j) plus ``eps`` is used, as a Python float (so no
            gradient flows through the bandwidth).
        multiscale_scales: Optional non-empty iterable of multipliers. If None,
            a single kernel with bandwidth ``sigma`` is returned; otherwise the
            kernels for bandwidths ``sigma * s`` are summed elementwise.
        eps: Added to the median-heuristic bandwidth to avoid a zero bandwidth.

    Returns:
        (n, n) kernel matrix (a sum of kernels in the multiscale case).

    Raises:
        ValueError: if ``multiscale_scales`` is given but empty (there would be
            no kernel to return).
    """
    n, _ = x.shape

    # Pairwise Euclidean distances ||x_i - x_j||, then squared for the kernel.
    dists = torch.cdist(x, x, p=2)
    dist2 = dists.pow(2)

    if sigma is None:
        if n > 1:
            # Median over the strict upper triangle (i < j) to skip the zero
            # diagonal; indices are created on x's device to index it directly.
            i, j = torch.triu_indices(n, n, offset=1, device=x.device)
            sigma = torch.median(dists[i, j]).item() + eps
        else:
            # No pairs to take a median over; every distance is 0, so any
            # bandwidth gives the same kernel (all ones).
            sigma = 1.0

    if multiscale_scales is None:
        return torch.exp(-dist2 / (2 * sigma**2))

    scales = list(multiscale_scales)
    if not scales:
        raise ValueError("multiscale_scales must contain at least one scale")
    return sum(torch.exp(-dist2 / (2 * (sigma * s) ** 2)) for s in scales)


def hsic_loss_2(x, y, multiscale_scales=(0.5, 1.0, 2.0)):
    """
    Biased empirical HSIC with z-scored inputs and cdist-based RBF kernels.

    Each column of x and y is standardized, kernels are built with
    ``rbf_kernel_2`` (median-heuristic bandwidth, summed over
    ``multiscale_scales``), and the V-statistic trace(K H L H) / (n - 1)^2 with
    H = I - 11^T / n is returned. This is the biased HSIC estimator.

    Args:
        x, y: 2-D tensors with the same number of rows n (samples), n >= 2.
        multiscale_scales: Bandwidth multipliers forwarded to ``rbf_kernel_2``;
            None means a single kernel at the median bandwidth. The default is
            an immutable tuple so it cannot be shared/mutated across calls.

    Returns:
        Scalar tensor with the HSIC estimate.
    """
    n = x.size(0)

    # z-score inputs per column
    x = (x - x.mean(0)) / (x.std(0) + 1e-6)
    y = (y - y.mean(0)) / (y.std(0) + 1e-6)

    # compute RBF kernels via cdist
    K = rbf_kernel_2(x, sigma=None, multiscale_scales=multiscale_scales)
    L = rbf_kernel_2(y, sigma=None, multiscale_scales=multiscale_scales)

    # trace(K H L H) == sum((H K H) * L^T). Double-centering K by subtracting its
    # column/row means (and adding back the grand mean) replaces the explicit
    # centering matrix: O(n^2) instead of O(n^3), and no default-dtype (float32)
    # H is multiplied with float64 kernels, which made ``@`` raise.
    K_centered = K - K.mean(dim=0, keepdim=True) - K.mean(dim=1, keepdim=True) + K.mean()
    return (1.0 / (n - 1) ** 2) * (K_centered * L.t()).sum()
