import math
from typing import Optional

import torch

from discrete_trpl.sparse.align_tensors import sparse_align_tensors
from discrete_trpl.sparse.broadcasts import maybe_broadcast


def sparse_kl(
        log_probs1: torch.Tensor,
        log_probs2: torch.Tensor,
        default_log_prob: torch.Tensor | float,
        default_log_prob2: Optional[torch.Tensor | float]=None) -> torch.Tensor:
    """Compute ``KL(log_probs1 || log_probs2)`` for two sparse log-probability tensors.

    The work is done in two steps:
        1. ``sparse_align_tensors`` is called on both inputs, receiving
           ``default_log_prob`` for the target and the reference default for the
           reference. It is expected to return both tensors with an identical
           set of stored coordinates (defined in another module -- see its
           documentation for the exact fill semantics).
        2. ``sparse_kl_aligned`` computes the per-batch KL on the aligned pair.

    How the implicit (unstored) entries are treated depends on
    ``default_log_prob2``:
        * ``None``: the reference reuses ``default_log_prob``. Both
          distributions then agree on every implicit entry, so those entries are
          taken to contribute 0 and no defaults are forwarded to
          ``sparse_kl_aligned``.
        * given: both defaults are forwarded to ``sparse_kl_aligned`` so that it
          can add the contribution of the implicit entries.

    Parameters
    ----------
    log_probs1 : torch.Tensor (sparse COO)
            Target distribution log-probabilities.
    log_probs2 : torch.Tensor (sparse COO)
            Reference distribution log-probabilities.
    default_log_prob : torch.Tensor or float
            Log-prob used for the implicit (unstored) entries of ``log_probs1``.
    default_log_prob2 : torch.Tensor or float, optional
            Log-prob used for the implicit (unstored) entries of ``log_probs2``.
            Defaults to ``default_log_prob`` when omitted.

    Returns
    -------
    torch.Tensor
            Whatever ``sparse_kl_aligned`` returns for the aligned pair: the
            per-batch KL divergence ``KL(log_probs1 || log_probs2)``.
    """
    # A missing reference default means "same default as the target".
    defaults_are_shared = default_log_prob2 is None
    ref_default = default_log_prob if defaults_are_shared else default_log_prob2

    # Bring both tensors onto a common set of explicit coordinates.
    target_aligned, ref_aligned = sparse_align_tensors(
        log_probs1,
        log_probs2,
        default_log_prob=default_log_prob,
        default_log_prob_ref=ref_default
    )

    if defaults_are_shared:
        # Identical defaults: implicit entries cancel, so only the explicit
        # coordinates need to be considered.
        return sparse_kl_aligned(target_aligned,
                                 ref_aligned,
                                 default_log_prob=None,
                                 default_log_prob_ref=None
                                 )

    # Distinct defaults: let the aligned routine account for implicit entries.
    return sparse_kl_aligned(target_aligned,
                             ref_aligned,
                             default_log_prob=default_log_prob,
                             default_log_prob_ref=ref_default
                             )


def sparse_kl_aligned(target_aligned: torch.Tensor,
                      ref_aligned: torch.Tensor,
                      default_log_prob: Optional[torch.Tensor | float] = None,
                      default_log_prob_ref: Optional[torch.Tensor | float] = None

                      ) -> torch.Tensor:
    """
    Compute KL(target || ref) divergence for aligned sparse tensors.
    We assume that both tensors have the same sparsity pattern.

    Requirements:
    - `target_aligned` and `ref_aligned` are COO sparse tensors with identical dense
      shape `shape = (..., V)` and identical indices.

    Args:
        target_aligned: Sparse COO tensor of shape (..., V) with log-probabilities (same indices as ref_aligned).
        ref_aligned:    Sparse COO tensor with the same shape (same indices as target_aligned).
        default_log_prob:
            Default for missing entries in the **target** tensor. Must be either:
              * a scalar (Python number or 0-D/numel==1 tensor), or
              * a tensor of shape `shape[:-1]` (per-batch; broadcast across the last dim).
            Tensor inputs are moved to (device, dtype) and detached.
            If `None`, assumes that both tensors have the same values in the missing positions,
            so the KL contribution from those positions is zero.
        default_log_prob_ref:
            Same as `default_log_prob`, applied to the **reference** tensor.
            If `None`, assumes that both tensors have the same values in the missing positions,
            so the KL contribution from those positions is zero.


    Returns:
        KL divergence per batch element
    """
    # Assertions
    # Must provide either both log probs or none.
    assert (default_log_prob is None) == (default_log_prob_ref is None), \
        "Provide either both defaults or neither."

    # Both tensors have the same sparsity pattern
    target_indices = target_aligned.indices()
    target_values = target_aligned.values()
    ref_values = ref_aligned.values()  # Same indices as target_aligned

    shape = target_aligned.shape
    batch_dims = shape[:-1]
    logits_size = shape[-1]

    # Convert multi-dim indices to flat batch indices
    if len(batch_dims) == 1:
        batch_indices = target_indices[0].long()
    else:
        # Compute batch strides once
        batch_strides = []
        for i in range(len(batch_dims)):
            stride = 1
            for j in range(i + 1, len(batch_dims)):
                stride *= batch_dims[j]
            batch_strides.append(stride)

        batch_strides = torch.tensor(batch_strides, device=target_indices.device, dtype=torch.long)
        batch_indices = torch.sum(target_indices[:-1].long() * batch_strides.unsqueeze(1), dim=0)

    # Calculate total batch size
    total_batch_size = math.prod(batch_dims)

    # Compute KL contributions for aligned indices
    # KL divergence: p * log(p/q) = p * (log_p - log_q)
    target_probs = target_values.exp()
    kl_contributions = target_probs * (target_values - ref_values)

    # Sum KL contributions by batch index
    kl_divergences = torch.zeros(total_batch_size, device=target_aligned.device, dtype=target_aligned.dtype)
    kl_divergences.scatter_add_(0, batch_indices, kl_contributions)

    # Add contribution from positions not in union.
    # If no defaults provided: return the union KL only, since those positions cancel out and contribute 0.
    if default_log_prob is not None and default_log_prob_ref is not None:
        # Otherwise, add the KL from coordinates outside the union:
        # Each zero entry has the same contribution (dt.exp() * (dt - dr), and we have
        # missing_per_batch = logit_size - nnz_per_batch of them, so we can simply add that to the previous KL
        nnz_per_batch = (torch.bincount(batch_indices, minlength=total_batch_size)
                         if batch_indices.numel()
                         else torch.zeros(total_batch_size, device=target_aligned.device, dtype=torch.long))
        missing_per_batch = (logits_size - nnz_per_batch).to(dtype=kl_divergences.dtype)

        # Use maybe_broadcast to get per-batch defaults of shape `batch_dims`, then flatten
        dt = maybe_broadcast(default_log_prob, batch_dims, kl_divergences.device, kl_divergences.dtype).reshape(-1)
        dr = maybe_broadcast(default_log_prob_ref, batch_dims, kl_divergences.device, kl_divergences.dtype).reshape(-1)

        kl_divergences = kl_divergences + (dt.exp() * (dt - dr)) * missing_per_batch

    if len(batch_dims) == 1:
        return kl_divergences
    else:
        return kl_divergences.view(batch_dims)
