import torch
from _operator import mul
from functools import reduce

def sparse_log_softmax(x_coo: torch.Tensor, default_log_prob, dim: int = -1):
    """
    x_coo: N-D sparse COO tensor with logits at stored indices.
    default_log_prob: scalar or broadcastable to batch_shape = x.shape[:dim]+x.shape[dim+1:].
    dim: axis along which to apply log-softmax.

    Returns:
      y_coo: sparse COO with log-softmax values at stored indices (same nnz/indices).
      new_default_log_prob: dense tensor of shape batch_shape with the log-softmax value
                            for all missing entries per batch position.
    """
    assert x_coo.layout == torch.sparse_coo
    x = x_coo.coalesce()
    sizes = x.size()
    N = len(sizes)
    dim = dim if dim >= 0 else N + dim
    assert 0 <= dim < N

    idx = x.indices()  # [N, nnz]
    vals = x.values()  # [nnz]
    device, dtype = vals.device, vals.dtype

    # Batch shape and V
    batch_sizes = sizes[:dim] + sizes[dim + 1:]
    V = sizes[dim]
    if len(batch_sizes) == 0:
        # Degenerates to 1 "batch"; keep code unified
        batch_sizes = (1,)

    # Flatten batch coordinates to a single id
    # Compute strides for linearization over batch dims
    batch_dims = [d for d in range(N) if d != dim]
    strides = []
    prod = 1
    for k in reversed(batch_dims):
        strides.append(prod)
        prod *= sizes[k]
    strides = list(reversed(strides))  # align with batch_dims order

    if idx.numel() == 0:
        # No stored entries: everything is default
        # Prepare d (broadcast to batch_sizes), then just normalize defaults
        d = torch.as_tensor(default_log_prob, device=device, dtype=dtype)
        d = d.expand(batch_sizes).contiguous()
        # m = d, sum_exp_stored = 0, n_default = V
        # lse = d + log(V)  (numerically stable form is trivial here)
        lse = d + torch.log(torch.full_like(d, V, dtype=d.dtype))
        new_default_log_prob = d - lse
        y = torch.sparse_coo_tensor(idx, vals, sizes, device=device, dtype=dtype).coalesce()
        return y, new_default_log_prob

    # Compute linear batch ids for each nnz
    batch_linear = torch.zeros(idx.size(1), device=device, dtype=torch.long)
    for k, s in zip(batch_dims, strides):
        batch_linear = batch_linear + idx[k] * s
    num_batches = reduce(mul, batch_sizes, 1)

    # Prepare per-batch default logits d[b]
    d = torch.as_tensor(default_log_prob, device=device, dtype=dtype)
    d = d.expand(batch_sizes).reshape(num_batches)

    # Counts per batch and number of defaults
    nnz_per_batch = torch.bincount(batch_linear, minlength=num_batches)
    n_default = torch.as_tensor(V, device=device) - nnz_per_batch

    # Rowwise max over stored vs default: m = max(max_stored, d)
    base = torch.full((num_batches,), float("-inf"), device=device, dtype=dtype)
    m_stored = torch.scatter_reduce(base, 0, batch_linear, vals, reduce="amax", include_self=True)
    m = torch.maximum(m_stored, d)

    # Sum exp for stored entries (shifted by m)
    exp_stored = torch.exp(vals - m[batch_linear])
    sum_exp_stored = torch.zeros(num_batches, device=device, dtype=dtype)
    sum_exp_stored.index_add_(0, batch_linear, exp_stored)

    # Sum exp for default entries
    sum_exp_default = n_default.to(dtype) * torch.exp(d - m)

    # logsumexp per batch and outputs
    lse = m + torch.log(sum_exp_stored + sum_exp_default)  # [num_batches]

    out_vals = vals - lse[batch_linear]  # [nnz]
    new_default_value = (d - lse).reshape(batch_sizes)  # batch_shape

    y = torch.sparse_coo_tensor(idx, out_vals, sizes, device=device, dtype=dtype)
    return y.coalesce(), new_default_value