import torch 
import math

def averaged_sliced_wasserstein(
    X: torch.Tensor,
    Y: torch.Tensor,
    max_dim: int,
    p: float = 1.0,
) -> torch.Tensor:
    """
    Averaged coordinate-wise ("axis-sliced") Wasserstein distance.

    The slices are the coordinate axes, not random projections. For each
    coordinate j, the 1D p-Wasserstein distance W_j between the empirical
    distributions X[:, j] and Y[:, j] is computed. For every prefix length
    d in 1..max_dim, the W_j of the first d coordinates are averaged. The
    result is the mean of those max_dim prefix averages:

        (1 / max_dim) * sum_{d=1}^{max_dim} (1 / d) * sum_{j<d} W_j

    Args:
        X: Tensor of shape (B, D).
        Y: Tensor of shape (B, D); must have the same shape as X.
        max_dim: number of leading coordinates to use (1 <= max_dim <= D).
        p: order of the Wasserstein distance (p >= 1).

    Returns:
        Scalar (0-d) tensor holding the averaged distance.

    Raises:
        ValueError: if p < 1, if the shapes differ, or if max_dim is not in
            [1, D].
    """
    if p < 1:
        raise ValueError("p must be >= 1")

    if X.shape != Y.shape:
        raise ValueError("X and Y must have the same shape")

    B, D = X.shape
    if max_dim > D:
        raise ValueError(f"max_dim={max_dim} exceeds data dimension D={D}")
    if max_dim < 1:
        # Without this check, 0 would divide by zero and a negative value
        # would silently return a Python float.
        raise ValueError(f"max_dim must be >= 1, got {max_dim}")

    # In 1D, with equal sample counts, the optimal coupling pairs the sorted
    # samples. Sort each needed column once along the batch axis.
    x_sorted = torch.sort(X[:, :max_dim], dim=0).values
    y_sorted = torch.sort(Y[:, :max_dim], dim=0).values
    diff = (x_sorted - y_sorted).abs()  # (B, max_dim)

    if p == 1:
        per_coord = diff.mean(dim=0)
    else:
        per_coord = diff.pow(p).mean(dim=0).pow(1.0 / p)  # (max_dim,)

    # Entry d-1 is the mean of per_coord[:d], i.e. the original inner average.
    counts = torch.arange(
        1, max_dim + 1, dtype=per_coord.dtype, device=per_coord.device
    )
    prefix_means = torch.cumsum(per_coord, dim=0) / counts
    return prefix_means.mean()




def log_density_Matern_mixture(X, u, C, alpha, *, reduction="mean"):
    """
    Log-density of a symmetric two-component diagonal Gaussian mixture:

        alpha * N(+u, diag(C)) + (1 - alpha) * N(-u, diag(C))

    Despite the function name, the density computed here is Gaussian.

    Args:
        X: samples, shape (B, D).
        u: mean vector with D elements (e.g. shape (D,) or (1, D)), or a
            single element shared by all coordinates.
        C: diagonal variances with D elements (e.g. shape (D,) or (1, D)),
            or a single element shared by all coordinates. Must be positive.
        alpha: mixture weight, strictly between 0 and 1.
        reduction: "mean" (default) returns the batch mean of log p(x) as a
            scalar tensor; "sum" returns the batch sum; "none" returns the
            per-sample log-densities with shape (B,).

    Returns:
        Tensor: the reduced log-density as described under ``reduction``.

    Raises:
        ValueError: if alpha is not in (0, 1) or reduction is unknown.
    """
    if not (0.0 < alpha < 1.0):
        raise ValueError("alpha must be strictly between 0 and 1 for log-density")
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(
            f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}"
        )

    B, D = X.shape

    # Expand to (1, D). A single shared variance then contributes D * log(C)
    # to the log-determinant instead of being counted once. reshape also
    # accepts inputs that .view() cannot handle.
    u = u.reshape(1, -1).expand(1, D)
    C = C.reshape(1, -1).expand(1, D)

    # Shared Gaussian normaliser: -0.5 * (D log 2*pi + log det diag(C))
    log_det = torch.sum(torch.log(C))  # scalar
    log_norm = -0.5 * (D * math.log(2.0 * math.pi) + log_det)

    invC = 1.0 / C

    # Mahalanobis terms for the +u and -u components
    quad_pos = torch.sum((X - u).pow(2) * invC, dim=1)  # (B,)
    quad_neg = torch.sum((X + u).pow(2) * invC, dim=1)  # (B,)

    log_p_pos = log_norm - 0.5 * quad_pos + math.log(alpha)
    log_p_neg = log_norm - 0.5 * quad_neg + math.log(1.0 - alpha)

    # Numerically stable log(exp(a) + exp(b))
    log_p = torch.logaddexp(log_p_pos, log_p_neg)  # (B,)

    if reduction == "none":
        return log_p
    if reduction == "sum":
        return log_p.sum()
    return log_p.mean()