from typing import Tuple, Union, Optional

import torch

def _defaults_to_union(
    value: Union[float, int, torch.Tensor],
    shape: Tuple[int, ...],          # full dense shape, e.g. (B1, ..., Bk, V)
    union_idx: torch.Tensor,         # [ndim, U] long; union coordinates
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Convert user-specified defaults into a length-U vector aligned with the union of
    sparse coordinates.

    What it does:
        Given a dense tensor shape `shape = (B1, ..., Bk, V)` and the union index
        set `union_idx ∈ ℕ^{ndim×U}` (coalesced COO coordinates of the union of
        nonzeros), produce a 1-D tensor `d ∈ ℝ^U` of default values on (device, dtype)
        to fill entries that are missing in either input sparse tensor.

    Supported inputs:
        - Scalar (Python number or 0-D/numel==1 tensor): expanded to all U positions.
        - Tensor of shape `shape[:-1]` (per-batch defaults): gathered by mapping the
          first k coordinates of each union entry to a linear batch index; the last
          (logit) dimension is ignored for this gather.

    Notes:
        - The returned defaults are detached (treated as constants; no gradients flow).
        - If U == 0, returns an empty (0,) tensor on (device, dtype).
        - Raises ValueError if `value` has any other shape.

    Args:
        value: Scalar or tensor of shape `shape[:-1]` providing per-batch defaults.
        shape: Full dense shape (..., V) shared by both sparse tensors.
        union_idx: Long tensor of shape [ndim, U] with union coordinates.
        device: Target device for the output.
        dtype:  Target dtype for the output.

    Returns:
        A contiguous tensor of shape (U,) with defaults aligned to `union_idx`.

    Raises:
        ValueError: If `value` is a tensor and its shape is neither `()` nor `shape[:-1]`.
    """
    U = union_idx.size(1)
    if U == 0:
        return torch.empty((0,), device=device, dtype=dtype)

    # scalar path
    if not isinstance(value, torch.Tensor):
        return torch.full((U,), value, device=device, dtype=dtype)

    v = value.to(device=device, dtype=dtype).detach()

    # scalar-like tensor
    if v.ndim == 0 or v.numel() == 1:
        return v.reshape(()).expand((U,)).clone()

    # per-batch tensor with shape[:-1]
    batch_shape = shape[:-1]
    if tuple(v.shape) == tuple(batch_shape):
        if len(batch_shape) == 0:
            # degenerate case: no batch dims -> treat as scalar
            return v.reshape(()).expand((U,)).clone()
        v_flat = v.reshape(-1).contiguous()

        # compute linear batch indices for union coords (exclude last dim)
        strides = [1] * len(batch_shape)
        for d in range(len(batch_shape) - 2, -1, -1):
            strides[d] = strides[d + 1] * batch_shape[d + 1]
        strides_t = torch.tensor(strides, device=union_idx.device, dtype=union_idx.dtype)

        lin = (union_idx[:len(batch_shape)].to(strides_t.dtype) * strides_t.unsqueeze(1)).sum(0).long()
        return v_flat.index_select(0, lin)

    raise ValueError(
        f"Unsupported default shape {tuple(v.shape)}. "
        f"Provide a scalar or a tensor of shape {batch_shape}."
    )

def sparse_align_tensors(
    log_target_prob: torch.Tensor,
    log_ref_prob: torch.Tensor,
    default_log_prob: Union[torch.Tensor, float, int],
    default_log_prob_ref: Optional[Union[torch.Tensor, float, int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Align two COO sparse log-probability tensors so they share the same sparsity pattern
    (the union of their nonzero coordinates across all batch dims and the last/logit dim).
    Existing nonzeros are preserved; indices missing in one tensor but present in the other
    are filled with provided defaults.

    Gradients flow only through original nonzero values; inserted defaults are treated
    as constants.

    Requirements:
    - `log_target_prob` and `log_ref_prob` are COO sparse tensors with identical dense
      shape `shape = (..., V)` and no dense (hybrid) value dimensions.
    - Outputs are coalesced COO tensors on the target's device/dtype. The reference's
      indices/values are moved/cast to that device/dtype before alignment (the cast is
      differentiable, so reference gradients still flow).

    Args:
        log_target_prob: Sparse COO tensor of shape (..., V) with log-probabilities.
        log_ref_prob:    Sparse COO tensor with the same shape.
        default_log_prob:
            Default for missing entries in the **target** tensor. Must be either:
              * a scalar (Python number or 0-D/numel==1 tensor), or
              * a tensor of shape `shape[:-1]` (per-batch; broadcast across the last dim).
            Tensor inputs are moved to (device, dtype) and detached.
        default_log_prob_ref:
            Same as `default_log_prob`, applied to the **reference** tensor. If `None`,
            uses `default_log_prob`.

    Returns:
        (target_aligned, ref_aligned): two COO sparse tensors with identical indices equal
        to the union of active coordinates of the inputs. For each tensor, values are the
        original nonzeros where present, otherwise the corresponding default.

    Raises:
        TypeError: If either input is not a sparse COO tensor.
        ValueError: If the dense shapes differ, if either input is a hybrid sparse tensor,
            or if a default has an unsupported shape.
    """
    if default_log_prob_ref is None:
        default_log_prob_ref = default_log_prob

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if not (log_target_prob.is_sparse and log_ref_prob.is_sparse):
        raise TypeError("Inputs must be sparse COO")

    t = log_target_prob.coalesce()
    r = log_ref_prob.coalesce()
    if t.shape != r.shape:
        raise ValueError("Both tensors must have identical shapes")
    # Values must be 1-D (one scalar per coordinate) for the scatter below to be valid.
    if t.sparse_dim() != t.dim() or r.sparse_dim() != r.dim():
        raise ValueError(
            "Hybrid sparse tensors (with dense value dimensions) are not supported; "
            "every dimension must be sparse"
        )

    shape = t.shape
    t_idx, t_val = t.indices(), t.values()
    device, dtype = t_val.device, t_val.dtype
    # Put the reference on the target's device/dtype so cat/scatter cannot fail on a
    # mismatch; `.to` is differentiable, so gradients still reach the reference values.
    r_idx = r.indices().to(device)
    r_val = r.values().to(device=device, dtype=dtype)

    # Union of coordinates. unique(dim=0) returns rows sorted lexicographically, which is
    # the same ordering a coalesced COO tensor uses.
    both_idx = torch.cat([t_idx, r_idx], dim=1)                         # [ndim, T+R]
    uniq, inv = torch.unique(both_idx.T, dim=0, return_inverse=True)    # uniq: [U, ndim]
    union_idx = uniq.T.contiguous()

    # Position inside the union of each target/ref entry. Coalesced inputs have unique
    # coordinates, so each position is hit at most once per tensor.
    num_target = t_idx.size(1)
    t_pos, r_pos = inv[:num_target], inv[num_target:]

    # Fresh, detached default vectors (one per tensor).
    default_vec_target = _defaults_to_union(default_log_prob, shape, union_idx, device, dtype)
    default_vec_ref = _defaults_to_union(default_log_prob_ref, shape, union_idx, device, dtype)

    # Overlay original values on the defaults. Out-of-place scatter keeps the autograd
    # path to the originals while the defaults remain constants.
    t_vals_aligned = (
        default_vec_target.scatter(0, t_pos, t_val) if t_pos.numel() else default_vec_target
    )
    r_vals_aligned = (
        default_vec_ref.scatter(0, r_pos, r_val) if r_pos.numel() else default_vec_ref
    )

    # Indices are already unique and sorted; coalesce() marks the result as coalesced.
    target_aligned = torch.sparse_coo_tensor(
        union_idx, t_vals_aligned, shape, device=device, dtype=dtype
    ).coalesce()
    ref_aligned = torch.sparse_coo_tensor(
        union_idx, r_vals_aligned, shape, device=device, dtype=dtype
    ).coalesce()
    return target_aligned, ref_aligned