import torch

from discrete_trpl.sparse.tensor_default import DefaultSparse
from discrete_trpl.sparse.log_softmax import sparse_log_softmax


class SparseLogProb(DefaultSparse):
    """
    Sparse COO wrapper for logits / log-probabilities with a per-row default.

    Concept
    -------
    Only some entries along the probability axis `dim` are stored explicitly
    in the wrapped sparse tensor. Every entry that is *not* stored is treated
    as having the per-row default logit `fill_value`.

    Keeping the data in this form makes it possible to run log-softmax and
    KL divergence without ever building the full dense tensor.

    Parameters
    ----------
    x : torch.Tensor
        Sparse COO tensor of logits (must satisfy `x.is_sparse`).
    fill_value : float | int | torch.Tensor
        Logit assumed for every implicit (unstored) entry along `dim`.
        Either a scalar or a dense tensor whose shape equals the batch shape
        (`x.shape[:dim] + x.shape[dim+1:]`).
    dim : int, default=-1
        The probability axis, i.e. the dimension along which log-softmax is
        taken and along which defaults are inserted. Every other axis is a
        batch axis.

    Notes
    -----
    - Construction, arithmetic and densification are inherited from
      `DefaultSparse`; this class only adds the two methods below.
    - `log_softmax()` normalizes row-wise along `dim` and returns a new
      object in which both the explicit values and the per-row default have
      been updated.
    - `kl(to)` computes the KL divergence between two `SparseLogProb`
      objects that share the same sparsity pattern.
    - `densify()` produces a dense tensor of logits in which missing entries
      are filled with the per-row default logit.

    The representation pays off for very large categorical distributions in
    which most outcomes share one baseline logit.
    """

    def log_softmax(self):
        """
        Apply log-softmax along `dim` and wrap the result in a new object.

        The work is delegated to `sparse_log_softmax`, which receives the
        wrapped sparse tensor, the stored fill value and `dim`, and hands
        back a pair: the normalized tensor and the normalized default for
        implicit entries.

        Returns
        -------
        SparseLogProb
            A new instance built from that pair, using the same `dim` as
            this object.
        """
        normalized, normalized_default = sparse_log_softmax(
            self.x, self._fill_value, self.dim
        )
        return SparseLogProb(normalized, normalized_default, self.dim)

    def kl(self, to: "SparseLogProb") -> torch.Tensor:
        """
        Compute the KL divergence KL(self || to).

        Parameters
        ----------
        to : SparseLogProb
            The distribution the divergence is measured against.

        Returns
        -------
        torch.Tensor
            Whatever `sparse_kl` returns for the two operands.

        Notes
        -----
        Only the wrapped sparse tensors and the `fill_value_tensor` of each
        operand are forwarded to `sparse_kl`; `dim` is not passed along.
        """
        # Imported at call time rather than at module level.
        # NOTE(review): presumably this avoids a circular import - confirm.
        from discrete_trpl.sparse.kl import sparse_kl

        own_log_probs = self.x
        other_log_probs = to.x
        own_default = self.fill_value_tensor
        other_default = to.fill_value_tensor
        return sparse_kl(
            log_probs1=own_log_probs,
            log_probs2=other_log_probs,
            default_log_prob=own_default,
            default_log_prob2=other_default,
        )
