
import torch
from typing import Union

from discrete_trpl.sparse.broadcasts import maybe_broadcast

def densify(
        x: torch.Tensor,
        fill_value: Union[float, int, torch.Tensor],
        dim: int = -1,
) -> torch.Tensor:
    """
    Convert a sparse COO tensor to dense, filling only *implicit* (unstored) entries
    along axis `dim` with `fill_value`.

    - Preserves explicit entries exactly (including explicit zeros).
    - Calls `coalesce()` to sum duplicates before writing.
    - `fill_value` may be a scalar or a tensor with shape equal to the *batch shape*
      `x.shape[:dim] + x.shape[dim+1:]`; the same fill value is repeated for every
      implicit position along `dim` within one batch slice.
    - Inputs for which `x.is_sparse` is False are returned unchanged (same object).

    Args:
        x: Input tensor. Only sparse COO tensors (`x.is_sparse`) are converted.
        fill_value: Default value(s) for implicit entries, broadcast to the batch
            shape by `maybe_broadcast`.
        dim: Axis along which `fill_value` is repeated. Negative values count
            from the end.

    Returns:
        A newly allocated, contiguous dense tensor with the same shape as `x`,
        requested on the device/dtype of `x.values()` (assumes `maybe_broadcast`
        honours the `device`/`dtype` it is given — confirm). The result never
        shares storage with `fill_value`.

    Raises:
        ValueError: If `dim` is out of range for `x`.
    """
    if not x.is_sparse:
        return x

    x = x.coalesce()
    ndim = x.dim()
    axis = dim if dim >= 0 else ndim + dim
    if not (0 <= axis < ndim):
        # Report the caller's original `dim`, not the normalized value.
        raise ValueError(f"`dim` out of range for tensor of dim {ndim}: got {dim}")

    idx = x.indices()
    vals = x.values()
    out_shape = tuple(x.shape)

    # Per-batch defaults, with a singleton inserted at `axis` so they broadcast
    # along that axis.
    batch_shape = out_shape[:axis] + out_shape[axis + 1:]
    batch_fill = maybe_broadcast(
        fill_value, batch_shape, device=vals.device, dtype=vals.dtype,
    )
    # expand() alone yields a stride-0 view: writing into it is invalid and it may
    # alias `fill_value` (if maybe_broadcast passes a tensor through as-is).
    # clone() materializes independent, writable storage.
    out = batch_fill.unsqueeze(axis).expand(out_shape).clone(
        memory_format=torch.contiguous_format
    )

    if vals.numel() > 0:
        # One index tensor per sparse dimension; explicit values overwrite the fill.
        out.index_put_(tuple(idx.unbind(0)), vals)
    return out
