import torch

def sparse_validate_distribution(log_prob_tensor: torch.Tensor,
                                 default_log_prob: float, val_eps: float = 1e-5) -> bool:
    """
    Validate that a tensor of log probabilities forms a valid distribution.

    The last dimension is treated as the distribution axis; every slice along
    it must exponentiate and sum to 1 within ``val_eps``. For a sparse (COO)
    tensor, every position without a stored value is assumed to carry
    ``default_log_prob``, and that implicit mass is added to the stored mass.

    Args:
        log_prob_tensor: Dense or sparse COO tensor of log probabilities.
        default_log_prob: Log probability assigned to every position that has
            no stored value in a sparse tensor. Ignored for dense input.
        val_eps: Absolute tolerance on ``|total_mass - 1|``.

    Returns:
        bool: True if every distribution sums to 1 within tolerance. For
        dense input, False is returned when validation fails.

    Raises:
        ValueError: If the input is sparse and any distribution's total mass
            deviates from 1 by ``val_eps`` or more.
    """
    if not log_prob_tensor.is_sparse:
        # Dense input: plain sum over the last axis. Wrap in bool() so the
        # caller gets a Python bool as annotated, not a 0-dim tensor.
        row_mass = log_prob_tensor.exp().sum(dim=-1)
        return bool(torch.all(torch.abs(row_mass - 1) < val_eps))

    # Coalesce so each position is stored at most once; otherwise the
    # per-row count of stored entries below would over-count duplicates.
    # NOTE(review): assumes a pure COO tensor with no dense dimensions
    # (values is 1-D) - confirm against callers.
    coalesced = log_prob_tensor.coalesce()
    indices = coalesced.indices()
    values = coalesced.values()
    shape = log_prob_tensor.shape
    stored_probs = values.exp()

    if len(shape) == 1:
        # A single distribution: one scalar sum and one scalar count.
        sparse_sum = stored_probs.sum()
        num_zeros = shape[0] - values.shape[0]
    else:
        # All leading dimensions are batch dimensions; flatten them so that
        # each distribution ("row") gets a single integer id.
        batch_dims = shape[:-1]
        vocab_size = shape[-1]

        num_rows = 1
        for dim_size in batch_dims:
            num_rows *= dim_size

        # Row-major flattening of the batch indices (Horner's scheme):
        # id = ((i0 * d1 + i1) * d2 + i2) ...
        row_ids = torch.zeros(indices.shape[1], dtype=torch.long, device=indices.device)
        for dim_size, dim_index in zip(batch_dims, indices[:-1]):
            row_ids = row_ids * dim_size + dim_index.long()

        # Stored probability mass per row.
        sparse_sum = torch.zeros(num_rows, device=values.device, dtype=values.dtype)
        sparse_sum.scatter_add_(0, row_ids, stored_probs)

        # Positions without a stored value per row; minlength keeps rows
        # with no stored entries in the result.
        stored_per_row = torch.bincount(row_ids, minlength=num_rows)
        num_zeros = vocab_size - stored_per_row

    # Add the implicit mass contributed by positions without a stored value.
    default_prob = torch.exp(torch.tensor(default_log_prob, device=values.device, dtype=values.dtype))
    total_prob_mass = sparse_sum + num_zeros * default_prob

    is_valid = bool(torch.all(torch.abs(total_prob_mass - 1.0) < val_eps))
    if not is_valid:
        raise ValueError("Sparse tensor does not represent a valid probability distribution."
                         f" Given total probability mass(es): {total_prob_mass}, ")

    return is_valid
