import warnings

import torch


def normalize_samples(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Standardize each feature (column) of x to zero mean and unit variance.

    This is useful if we have heteroscedastic noise in the two samples:
    standardizing each group separately removes per-feature location and
    scale differences between the groups.

    Args:
        x (torch.Tensor): A tensor of shape (n_samples, n_features)
        eps (float): Constant added to the per-feature standard deviation to
            avoid division by zero for constant features. Defaults to 1e-6.

    Returns:
        torch.Tensor: Normalized tensor of the same shape. With fewer than two
        samples the sample standard deviation is undefined, so the input is
        only centered (a single row therefore becomes all zeros) and a warning
        is emitted.
    """
    mean = x.mean(dim=0, keepdim=True)
    if x.shape[0] < 2:
        # The unbiased std of a single sample divides by n - 1 == 0 and yields NaN,
        # which would poison every downstream kernel value; center only instead.
        warnings.warn(
            "normalize_samples: fewer than two samples, so the standard deviation is "
            "undefined; only centering is applied."
        )
        return x - mean
    std = x.std(dim=0, unbiased=True, keepdim=True)
    return (x - mean) / (std + eps)


def gaussian_kernel(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """
    Compute the Gaussian (RBF) kernel between two sets of samples.

    For every pair of rows a of x and b of y this returns
    k(a, b) = exp(-||a - b||^2 / (2 * sigma^2)).

    Args:
        x (torch.Tensor): A tensor of shape (n_samples_x, n_features)
        y (torch.Tensor): A tensor of shape (n_samples_y, n_features)
        sigma (float): Bandwidth parameter for the RBF kernel; must be positive.
            The smaller sigma, the faster the kernel decays with distance
            (emphasizing local structure).

    Returns:
        torch.Tensor: Kernel matrix of shape (n_samples_x, n_samples_y)

    Raises:
        ValueError: If sigma is not positive.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    # Squared Euclidean distance for all pairs via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b,
    # so a single matrix multiply does the heavy lifting.
    x_norm = (x**2).sum(dim=1).view(-1, 1)  # (n, 1)
    y_norm = (y**2).sum(dim=1).view(1, -1)  # (1, m)
    distances = x_norm + y_norm - 2.0 * torch.mm(x, y.t())
    # Floating-point cancellation in the expansion can leave tiny negative values
    # (e.g. on the diagonal when x is y); clamp so that kernel values never exceed 1.
    distances = torch.clamp(distances, min=0.0)

    # TODO: check if we should have an amplitude parameter - this would be useful if we
    # have heteroscedastic noise; this might be the case depending on how we create the
    # inverse samples.
    return torch.exp(-distances / (2 * sigma**2))


def compute_mmd(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0, normalize: bool = False) -> float:
    """
    Estimate the squared Maximum Mean Discrepancy (MMD^2) between two samples
    with the Gaussian (RBF) kernel.

    The within-group kernel averages skip the diagonal (self-similarity)
    entries, which gives the unbiased estimate. A group that holds a single
    sample has no off-diagonal entries, so that group falls back to the biased
    average (the plain mean of its kernel matrix) and a warning is emitted.

    Args:
        x (torch.Tensor): Samples from distribution P, shape (n, d)
        y (torch.Tensor): Samples from distribution Q, shape (m, d)
        sigma (float): Bandwidth parameter for the RBF kernel.
        normalize (bool): If True, standardize each group independently with
            ``normalize_samples`` before computing the kernels.

    Returns:
        float: The estimated MMD^2 statistic as a Python float.
    """
    if normalize:
        x, y = normalize_samples(x), normalize_samples(y)

    # All three kernel matrices are built before any estimator term is evaluated.
    gram_x = gaussian_kernel(x, x, sigma)
    gram_y = gaussian_kernel(y, y, sigma)
    gram_cross = gaussian_kernel(x, y, sigma)

    def within_group_average(gram: torch.Tensor, count: int, fallback_message: str) -> torch.Tensor:
        # Average of the off-diagonal entries; a 1x1 matrix has none, so use its full mean.
        if count != 1:
            off_diagonal_sum = gram.sum() - torch.diag(gram).sum()
            return off_diagonal_sum / (count * (count - 1))
        warnings.warn(fallback_message)
        return gram.mean()

    term_x = within_group_average(
        gram_x,
        x.shape[0],
        "Using biased estimator for group X because sample size is 1. Ideally use more data.",
    )
    term_y = within_group_average(
        gram_y,
        y.shape[0],
        "Using biased estimator for group Y because sample size is 1. Ideally use more data.",
    )
    term_cross = gram_cross.mean()  # average over all n * m cross-group pairs

    return (term_x + term_y - 2 * term_cross).item()


def mmd_test_rbf(
    x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0, normalize: bool = False, num_perms: int = 1000
) -> float:
    """
    Perform a permutation test to calculate the p-value for the MMD statistic.

    The pooled samples are randomly re-split into groups of the original sizes
    ``num_perms`` times, and the MMD^2 statistic is recomputed for each split.
    The p-value is the add-one-corrected fraction of permuted statistics that
    are at least as large as the observed one:

        p = (1 + #{perm_stat >= observed_stat}) / (1 + num_perms)

    The "+1" counts the observed split as one of the permutations, so the
    p-value is never exactly zero. Using ``>=`` counts ties against rejection.

    Args:
        x (torch.Tensor): Samples from distribution P, shape (n, d)
        y (torch.Tensor): Samples from distribution Q, shape (m, d)
        sigma (float): Bandwidth parameter for the RBF kernel.
        normalize (bool): Whether to pre-normalize each group (applied to the
            observed groups and to every permuted split alike).
        num_perms (int): Number of random permutations to perform; must be >= 1.

    Returns:
        float: The permutation-test p-value, in the interval (0, 1].

    Raises:
        ValueError: If ``num_perms`` is less than 1.
    """
    if num_perms < 1:
        raise ValueError(f"num_perms must be at least 1, got {num_perms}")

    # Pool the samples; each permutation re-splits them into sizes (n, total - n).
    pooled = torch.cat((x, y), dim=0)
    n = x.shape[0]
    total = pooled.shape[0]

    observed_stat = compute_mmd(x, y, sigma, normalize)

    # float64 so statistics computed from double-precision inputs are not rounded to float32
    # before being compared with the (Python float) observed statistic.
    perm_statistics = torch.zeros(num_perms, dtype=torch.float64)
    for i in range(num_perms):
        perm = torch.randperm(total)
        perm_statistics[i] = compute_mmd(pooled[perm[:n]], pooled[perm[n:]], sigma, normalize)

    n_as_extreme = int((perm_statistics >= observed_stat).sum().item())
    return (1 + n_as_extreme) / (1 + num_perms)


# Example usage:
if __name__ == "__main__":
    # Demonstration with random high-dimensional data: many samples from group X and a
    # single sample from group Y (or, more generally, a small sample).
    torch.manual_seed(0)  # make the demo reproducible

    d = 100  # dimensionality
    x = torch.randn(1000, d)  # 1000 samples from distribution P
    y = torch.randn(1, d)  # 1 sample from distribution Q

    # Kernel bandwidth; this may require tuning. For standard-normal data the expected
    # squared distance between two points is 2 * d, so sigma = 1.0 is very narrow here.
    sigma = 1.0

    # Per-group normalization (useful in case of heteroscedasticity) needs a standard
    # deviation estimate for each group, which requires at least two samples per group.
    normalize = min(x.shape[0], y.shape[0]) >= 2

    mmd2_stat = compute_mmd(x, y, sigma, normalize=normalize)
    p_val = mmd_test_rbf(x, y, sigma, normalize=normalize, num_perms=1000)

    print("Squared MMD statistic:", mmd2_stat)
    print("p-value from permutation test:", p_val)
