from collections import defaultdict
import math
import warnings
from typing import Tuple, Dict, Optional

import torch


def sparsify_normalized_logits(norm_logits: torch.Tensor,
                               selected_token_mask: Optional[torch.Tensor] = None,
                               threshold: float = 1e-5,
                               total_default_mass: bool = False,
                               total_default_keep_maxnum: int = 64,
                               default: float = 1e-8,
                               return_sparse: bool = True) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Sparsify log-probability distributions along the vocab dimension and renormalize them.

    Entries that are not kept are treated as having probability `default`; the kept
    entries are rescaled so that kept mass plus the default mass of all dropped entries
    equals 1 for every token. This function only dispatches to one of two strategies:

    * ``total_default_mass=False``: keep every entry whose probability exceeds `threshold`.
    * ``total_default_mass=True``: keep the smallest prefix of the top-probability entries
      (at most `total_default_keep_maxnum`) such that the dropped mass is below `threshold`.

    Args:
        norm_logits (torch.Tensor): Log-probabilities (e.g. log_softmax output) of shape
            (batch_size, seq_len, vocab_size) or (total_tokens, vocab_size).
        selected_token_mask (torch.Tensor, optional): Tensor with the shape of `norm_logits`
            minus the last dimension. It is forwarded unchanged; the strategies use it as
            a vocab index per token that is always kept.
            NOTE(review): despite the name it is consumed as an index tensor, not a
            boolean mask -- confirm with callers.
        threshold (float, optional): Per-entry probability threshold, or (when
            `total_default_mass` is True) the upper bound on the total dropped mass.
        total_default_mass (bool, optional): Selects the strategy, see above.
        total_default_keep_maxnum (int, optional): Maximum number of entries kept per token.
            Only used when `total_default_mass` is True.
        default (float, optional): Probability assigned to every dropped entry.
        return_sparse (bool, optional): If True the first return value is a sparse COO
            tensor, otherwise a dense tensor in which dropped entries hold log(default).

    Returns:
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        1.
            torch.Tensor: Renormalized log-probabilities, sparse or dense depending on
                `return_sparse`, with the same shape as `norm_logits`.
        2.
            Dict[str, torch.Tensor]: Information dictionary containing:
                - "remaining_mass": probability mass (not its log) of the kept entries
                  before renormalization, shape (batch_size, seq_len) or (total_tokens,).
                - "num_selected_tokens": number of kept entries per token, as float.

    Example:
        logits = torch.randn((2, 3, 100))
        norm_logits = torch.log_softmax(logits, dim=-1)

        # Note that a (tensor, info) tuple is returned.
        dense_result, info = sparsify_normalized_logits(
            norm_logits,
            threshold=1e-4,
            default=1e-6,
            return_sparse=False,
        )
        probs = dense_result.exp()
        print(torch.allclose(probs.sum(dim=-1), torch.ones(2, 3)))  # Should be True
    """
    # Arguments shared by both strategies.
    common_kwargs = dict(
        norm_logits=norm_logits,
        selected_token_mask=selected_token_mask,
        threshold=threshold,
        default=default,
        return_sparse=return_sparse,
    )

    if not total_default_mass:
        # Independent per-entry thresholding.
        return _sparsify_normalized_logits(**common_kwargs)

    # Budget on the total dropped mass, capped at a maximum number of kept entries.
    return _sparsify_normalized_logits_mass(
        total_default_keep_maxnum=total_default_keep_maxnum,
        **common_kwargs,
    )


def sparsify_and_pad_response_logits(
        logits_rmpad: torch.Tensor,
        indices: torch.Tensor,
        batch_size: int,
        seqlen: int,
        response_length: int,
        selected_token_mask: Optional[torch.Tensor] = None,
        threshold: float = 1e-5,
        total_default_mass: bool = False,
        total_default_keep_maxnum: int = 64,
        default: float = 1e-8,
        chunk_size: int = -1,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Efficiently sparsify response logits without first restoring full padding.

    This function performs the following operations in an optimized order:
    1. Extract response tokens from flattened tensor
    2. Apply log_softmax normalization (in float32)
    3. Sparsify the normalized logits along vocab dimension
    4. Transform indices to match padded output format

    Args:
        logits_rmpad: Flattened logits tensor of shape (total_nnz, vocab_size)
        indices: Flat positions (batch * seqlen + seq) of the rows of `logits_rmpad`
            in the padded layout, e.g. the indices returned by unpad_input
        batch_size: Original batch size
        seqlen: Original sequence length
        response_length: Length of response tokens
        selected_token_mask: Optional tensor of shape (total_nnz,). It is filtered to the
            response rows and forwarded to `sparsify_normalized_logits`, which always keeps
            the entry it designates per token. May be None.
            NOTE(review): downstream it is consumed as a vocab index tensor rather than a
            boolean mask -- confirm with callers.
        threshold: Sparsification threshold in probability space. If total_default_mass == True,
            then threshold for probability sum of sparsified entries
        total_default_mass: Don't threshold every logit independently but select the smallest amount
            of large logits such that remaining probability of the sparsified logits is less than threshold.
        total_default_keep_maxnum: The maximum number of top probability logits to keep, even if more mass is sparsified.
        default: Default probability for sparsified entries
        chunk_size: If > 0, process response tokens in chunks of this size to reduce memory usage.

    Returns:
        1.
            torch.Tensor: Coalesced sparse COO tensor of shape
            (batch_size, response_length, vocab_size) containing the sparsified and
            renormalized log-probabilities for response tokens.
        2.
            Dict[str, torch.Tensor]: Information dictionary from the sparsification process.
            Its values are flat over the extracted response tokens (not padded). When there
            are no response tokens at all, the known keys map to empty tensors.
    """
    assert logits_rmpad.ndim == 2, "Efficient sparsify assumes flattened logits with single token dim"
    device = logits_rmpad.device
    vocab_size = logits_rmpad.shape[1]
    # Absolutely crucial upcasting. Don't you dare change this.
    # Having this bf16 makes the probabilities ill-posed and leads to crashes.
    # The upcast is delayed until after the response_mask filter to reduce memory usage.
    dtype = torch.float32
    padded_shape = (batch_size, response_length, vocab_size)

    def _empty_padded_sparse() -> torch.Tensor:
        # Sparse tensor in the padded output layout without any stored entry.
        return torch.sparse_coo_tensor(
            torch.zeros((3, 0), dtype=torch.int64, device=device),
            torch.zeros(0, dtype=dtype, device=device),
            padded_shape,
            device=device,
        ).coalesce()

    # Step 1: Extract response tokens from flattened tensor.
    # The window is shifted by one position to the left: logits at position t predict token t + 1.
    seq_indices = indices % seqlen
    response_start = seqlen - response_length - 1
    response_end = seqlen - 1
    response_mask = (seq_indices >= response_start) & (seq_indices < response_end)

    if not response_mask.any():
        # No response tokens found. Still honour the (tensor, info_dict) contract so that
        # callers can unpack the result; info tensors are empty because no token was processed.
        empty_info = {
            "remaining_mass": torch.zeros(0, dtype=dtype, device=device),
            "num_selected_tokens": torch.zeros(0, dtype=dtype, device=device),
        }
        return _empty_padded_sparse(), empty_info

    # filter by response mask
    response_indices = indices[response_mask]
    if selected_token_mask is not None:
        selected_token_mask = selected_token_mask[response_mask]

    # Steps 2 + 3: log_softmax normalization followed by sparsification
    if chunk_size <= 0:
        # Critical upcast here before log_softmax.
        response_logits_rmpad = logits_rmpad[response_mask].to(dtype)
        normalized_response_logits_rmpad = response_logits_rmpad.log_softmax(dim=1)

        sparse_logits_rmpad, info_dict = sparsify_normalized_logits(
            normalized_response_logits_rmpad,
            selected_token_mask=selected_token_mask,
            threshold=threshold,
            default=default,
            total_default_mass=total_default_mass,
            total_default_keep_maxnum=total_default_keep_maxnum,
            return_sparse=True,
        )
    else:
        # Mask the response rows but do not upcast everything: the upcast happens chunk by chunk.
        response_logits_rmpad = logits_rmpad[response_mask]
        num_response_tokens = response_logits_rmpad.shape[0]

        chunk_index_list = []
        chunk_value_list = []
        info_dicts = []
        row_offset = 0  # number of token rows covered by all previous chunks

        for start in range(0, num_response_tokens, chunk_size):
            end = min(start + chunk_size, num_response_tokens)
            # Critical upcast, restricted to the chunk to limit memory.
            chunk = response_logits_rmpad[start:end].to(dtype=dtype)  # (rows_in_chunk, vocab_size)
            normalized_chunk = torch.nn.functional.log_softmax(chunk, dim=-1)

            # The selection tensor is optional; only slice it when it was provided.
            chunk_selected = None if selected_token_mask is None else selected_token_mask[start:end]

            sparse_chunk, chunk_info_dict = sparsify_normalized_logits(
                normalized_chunk,
                selected_token_mask=chunk_selected,
                threshold=threshold,
                default=default,
                total_default_mass=total_default_mass,
                total_default_keep_maxnum=total_default_keep_maxnum,
                return_sparse=True,
            )
            # Make sure that the chunk is coalesced, for free if it already is
            sparse_chunk = sparse_chunk.coalesce()

            # Shift the token (row) index by the rows of prior chunks, keep the vocab index.
            # Done out-of-place on purpose; in-place would cause broken grad.
            index_offset = torch.tensor([[row_offset], [0]], dtype=torch.int64, device=device)
            chunk_index_list.append(sparse_chunk.indices() + index_offset)
            chunk_value_list.append(sparse_chunk.values())
            row_offset += sparse_chunk.shape[0]

            info_dicts.append(chunk_info_dict)

        # concat the [2, nnz_i] chunk tensors into one larger [2, total_nnz]
        idx = torch.cat(chunk_index_list, dim=1)
        vals = torch.cat(chunk_value_list, dim=0)
        # After the last chunk the offset equals the number of response tokens.
        assert row_offset == num_response_tokens
        sparse_logits_rmpad = torch.sparse_coo_tensor(
            idx, vals, (row_offset, vocab_size), dtype=dtype, device=device
        ).coalesce()

        # Combine the chunks' info dicts into one by concatenating each value along its
        # first dimension. The logits are flattened 2D here, so the info values are flat too.
        grouped = defaultdict(list)
        for chunk_info in info_dicts:
            for key, value in chunk_info.items():
                grouped[key].append(value)
        info_dict = {key: torch.cat(parts, dim=0) for key, parts in grouped.items()}

    # Step 4: Transform indices to match padded output format (batch_size, response_length, vocab_size)
    if sparse_logits_rmpad._nnz() == 0:
        # No non-zero elements after sparsification
        return _empty_padded_sparse(), info_dict

    sparse_indices = sparse_logits_rmpad.indices()  # (2, nnz) for 2D sparse tensor
    sparse_values = sparse_logits_rmpad.values()  # (nnz,)

    rmpad_row_indices = sparse_indices[0]  # Which row among the response rows
    vocab_indices = sparse_indices[1]  # Which vocabulary index

    # Flat padded position of every stored entry, split into (batch, seq) coordinates
    original_indices = response_indices[rmpad_row_indices]
    batch_indices = original_indices // seqlen
    seq_indices_full = original_indices % seqlen

    # Convert sequence indices to response-relative indices
    response_seq_indices = seq_indices_full - response_start

    # Create new 3D sparse tensor indices (batch, response_seq, vocab)
    new_indices = torch.stack([batch_indices, response_seq_indices, vocab_indices], dim=0)

    sparse_result = torch.sparse_coo_tensor(
        new_indices, sparse_values, padded_shape, device=device, dtype=dtype
    ).coalesce()

    return sparse_result, info_dict



def _fix_empty_rows_top1(norm_logits: torch.Tensor, sparse_logits: torch.Tensor) -> torch.Tensor:
    """
    Ensure that every token row of `sparse_logits` stores at least one entry.

    For each row without any stored entry, the argmax token of the corresponding row of
    `norm_logits` is inserted together with its (pre-threshold) value.
    Works for 2D (N, V) and 3D (B, S, V).

    Args:
        norm_logits: Dense logits of shape (N, V) or (B, S, V).
        sparse_logits: Sparse COO tensor with the same shape as `norm_logits`.

    Returns:
        torch.Tensor: A coalesced sparse COO tensor. It is the (coalesced) input itself when
        no row was empty, otherwise a new tensor with the additional entries. The same type
        is returned on both paths so callers can use the result directly.

    Raises:
        ValueError: If `norm_logits` is neither 2D nor 3D.
    """
    assert sparse_logits.is_sparse, "sparse_logits must be a sparse COO tensor."
    norm_logits_shape = norm_logits.shape
    device = norm_logits.device
    dtype = norm_logits.dtype

    # indices() is only defined for coalesced tensors; this is free if already coalesced.
    sparse_logits = sparse_logits.coalesce()
    indices = sparse_logits.indices()
    values = sparse_logits.values().to(dtype)

    is_3d = len(norm_logits_shape) == 3
    if is_3d:
        batch_size, seq_len, vocab_size = norm_logits_shape
        total_tokens = batch_size * seq_len
        row_ids = indices[0] * seq_len + indices[1]  # flatten (batch, seq) to one row id
    elif len(norm_logits_shape) == 2:
        total_tokens, vocab_size = norm_logits_shape
        row_ids = indices[0]
    else:
        raise ValueError("norm_logits must be 2D or 3D (…, V).")

    stored_per_row = torch.bincount(row_ids, minlength=total_tokens)
    empty_rows = torch.nonzero(stored_per_row == 0, as_tuple=False).squeeze(1)
    if empty_rows.numel() == 0:
        return sparse_logits  # nothing to fix

    # Argmax on original (pre-threshold) logits; reshape also handles non-contiguous input.
    flat_logits = norm_logits.reshape(total_tokens, vocab_size)
    argmax_vocab_indices = torch.argmax(flat_logits[empty_rows], dim=-1)
    new_values = flat_logits[empty_rows, argmax_vocab_indices]

    if is_3d:
        add_indices = torch.stack(
            [empty_rows // seq_len, empty_rows % seq_len, argmax_vocab_indices], dim=0
        )
    else:
        add_indices = torch.stack([empty_rows, argmax_vocab_indices], dim=0)

    # Append new indices and values, then coalesce
    merged_indices = torch.cat([indices, add_indices.to(indices.dtype)], dim=1)
    merged_values = torch.cat([values, new_values.to(dtype)], dim=0)

    return torch.sparse_coo_tensor(
        merged_indices, merged_values, norm_logits_shape, device=device, dtype=dtype
    ).coalesce()


def _sparsify_normalized_logits(norm_logits: torch.Tensor,
                               selected_token_mask: Optional[torch.Tensor] = None,
                               threshold: float = 1e-5,
                               default: float = 1e-8,
                               return_sparse: bool = True) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Keep every entry whose probability exceeds `threshold` and renormalize the kept entries.

    Dropped entries are treated as having probability `default`. The kept log-probabilities
    of each token are shifted so that their total mass equals 1 - (#dropped * default).
    Rows in which nothing passes the threshold get their argmax entry inserted (with a warning).

    Args:
        norm_logits: Log-probabilities of shape (B, S, V) or (N, V).
        selected_token_mask: Optional tensor with the shape of `norm_logits` minus the last
            dimension. It is used as an index along the vocab dimension (via scatter_), i.e.
            the vocab entry it designates is always kept for that token.
        threshold: Probability threshold; entries with exp(norm_logits) <= threshold are dropped.
        default: Probability assigned to each dropped entry.
        return_sparse: Return a sparse COO tensor if True, otherwise a dense tensor where
            dropped entries hold log(default).

    Returns:
        Tuple of the renormalized log-probabilities (sparse or dense, shape of `norm_logits`)
        and an info dict with
            - "remaining_mass": probability mass of the kept entries before renormalization,
            - "num_selected_tokens": number of kept entries per token (float),
        both shaped like `norm_logits` without its last dimension.

    Raises:
        ValueError: If `threshold`/`default` are out of range or `selected_token_mask`
            has a mismatching shape.
    """
    if threshold <= 0.0 or default <= 0.0 or default >= 1.0:
        raise ValueError("Require 0 < default < 1 and threshold > 0.")

    # Convert thresholds from probability space to log space
    log_threshold = math.log(threshold)
    log_default = math.log(default)

    device = norm_logits.device
    dtype = norm_logits.dtype
    norm_logits_shape = norm_logits.shape

    is_3d = len(norm_logits_shape) == 3
    if is_3d:  # (B, S, V)
        batch_size, seq_len, vocab_size = norm_logits_shape
        total_tokens = batch_size * seq_len
    else:  # (N, V)
        seq_len = None  # flattened, so not used
        total_tokens, vocab_size = norm_logits_shape

    def _rows_and_counts(idx: torch.Tensor):
        # Flat token id of every stored entry and the number of stored entries per token.
        with torch.no_grad():
            rows = idx[0] * seq_len + idx[1] if is_3d else idx[0]
            counts = torch.bincount(rows, minlength=total_tokens).to(torch.int32)
        return rows, counts.view(norm_logits_shape[:-1])

    # Build mask of entries above threshold
    with torch.no_grad():  # No need to track this in autograd
        mask = norm_logits > log_threshold  # bool tensor, same shape

        if selected_token_mask is not None:
            if selected_token_mask.shape != mask.shape[:-1]:
                raise ValueError("selected_token_mask must have shape matching norm_logits except for last dim.")
            # Always keep the selected logits for each token
            mask.scatter_(-1, selected_token_mask.unsqueeze(-1), True)

    indices = mask.nonzero(as_tuple=False).T  # [ndim, nnz]
    values = norm_logits[mask]  # [nnz], picked values
    values = values.clamp(max=-1e-12)  # keep the picked log-probs strictly negative

    batch_seq_indices, non_zeros = _rows_and_counts(indices)

    if torch.any(non_zeros == 0):
        warnings.warn(
            "Some rows have zero non-zeros after thresholding. These will be filled with the argmax token, this is potentially inefficient")
        sparse_logits = torch.sparse_coo_tensor(
            indices,
            values,
            norm_logits_shape,
            device=device,
            dtype=dtype
        ).coalesce()
        sparse_logits = _fix_empty_rows_top1(norm_logits, sparse_logits)
        indices = sparse_logits.indices()
        values = sparse_logits.values()

        # Row ids and counts changed because entries were inserted
        batch_seq_indices, non_zeros = _rows_and_counts(indices)
        assert torch.all(non_zeros > 0), "There should be no empty rows after fixing"

    # Probability mass that all dropped entries of a token receive in total
    zeros_mass = (vocab_size - non_zeros) * default

    # Compute logsumexp for each token using scatter operations
    # We use the numerically stable logsumexp: log(sum(exp(x))) = max + log(sum(exp(x - max)))
    max_vals = torch.full((total_tokens,), -float("inf"), device=device, dtype=dtype)
    max_vals = max_vals.scatter_reduce(0, batch_seq_indices, values, reduce="amax")

    exp_shifted = torch.exp(values - max_vals[batch_seq_indices])
    sum_exp = torch.zeros(total_tokens, device=device, dtype=dtype)
    sum_exp = sum_exp.scatter_add(0, batch_seq_indices, exp_shifted)

    # Final logsumexp result (log of remaining mass)
    log_remaining_mass = max_vals + torch.log(sum_exp)
    log_remaining_mass = log_remaining_mass.view(*norm_logits_shape[:-1])

    # Compute target mass (1 - zeros_mass) in log space
    log_target_mass = torch.log(1 - zeros_mass)

    # new_logit = old_logit + log(target_mass) - log(remaining_mass)
    normalized_values = (
        values
        + log_target_mass.view(-1)[batch_seq_indices]
        - log_remaining_mass.view(-1)[batch_seq_indices]
    )

    with torch.no_grad():
        info_dict = {"remaining_mass": log_remaining_mass.exp(),
                     "num_selected_tokens": non_zeros.float(),
                     }

    if return_sparse:
        normalized_sparse = torch.sparse_coo_tensor(
            indices.contiguous(), normalized_values.contiguous(), norm_logits_shape, device=device,
            dtype=dtype
        ).coalesce()
        return normalized_sparse, info_dict

    # Dense output: start from log(default) everywhere and write the kept entries by index.
    # Using "value == 0" as a marker for dropped entries would be wrong, because a kept
    # entry can legitimately renormalize to exactly 0.0 (probability 1 within precision).
    dense_result = torch.full(norm_logits_shape, log_default, device=device, dtype=dtype)
    dense_result[tuple(indices.unbind(0))] = normalized_values
    return dense_result, info_dict



def _sparsify_normalized_logits_mass(norm_logits: torch.Tensor,
                               selected_token_mask: Optional[torch.Tensor] = None,
                               threshold: float = 1e-5,
                               total_default_keep_maxnum: int = 64,
                               default: float = 1e-8,
                               return_sparse: bool = True) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Keep the fewest top-probability entries whose dropped mass is below `threshold`, then renormalize.

    Per token, the largest entries are accumulated until the remaining (dropped) probability
    mass is at most `threshold`, but never more than `total_default_keep_maxnum` entries.
    If that cap is reached first, all capped entries are kept and the mass target is missed.
    Dropped entries are treated as having probability `default`; the kept log-probabilities
    are shifted so that their total mass equals 1 - (#dropped * default).

    Args:
        norm_logits: Log-probabilities of shape (B, S, V) or (N, V).
        selected_token_mask: Optional tensor with the shape of `norm_logits` minus the last
            dimension. It is used as an index along the vocab dimension (via scatter_), i.e.
            the vocab entry it designates is always kept for that token.
        threshold: Upper bound on the total probability mass of the dropped entries.
        total_default_keep_maxnum: Maximum number of top entries to keep per token. Values
            larger than the vocab size are clamped to the vocab size.
        default: Probability assigned to each dropped entry.
        return_sparse: Return a sparse COO tensor if True, otherwise a dense tensor where
            dropped entries hold log(default).

    Returns:
        Tuple of the renormalized log-probabilities (sparse or dense, shape of `norm_logits`)
        and an info dict with
            - "remaining_mass": probability mass of the kept entries before renormalization,
            - "num_selected_tokens": number of kept entries per token (float),
        both shaped like `norm_logits` without its last dimension.

    Raises:
        ValueError: If `threshold`/`default` are out of range or `selected_token_mask`
            has a mismatching shape.
    """
    if threshold <= 0.0 or default <= 0.0 or default >= 1.0:
        raise ValueError("Require 0 < default < 1 and threshold > 0.")

    log_default = math.log(default)

    device = norm_logits.device
    dtype = norm_logits.dtype
    norm_logits_shape = norm_logits.shape
    vocab_size = norm_logits_shape[-1]  # Logit dimension

    # topk raises if k exceeds the size of the dimension, so never ask for more than exists.
    keep_k = min(total_default_keep_maxnum, vocab_size)

    # TODO: Can get rid of mask here if we need efficiency.
    with torch.no_grad():  # No need to track this in autograd
        mask = torch.zeros(norm_logits_shape, device=device, dtype=torch.bool)
        # Assume very positive skew (i.e. few large prob logits dominating) and use topk.
        # In the worst case some mass is dropped if the model is high entropy with many low-prob options.
        topk_logits, topk_indices = torch.topk(norm_logits, k=keep_k, dim=-1, largest=True, sorted=True)
        # Inputs are normalized log-probs close to zero, so no lse-shift trick is needed.
        topk_totalprob = topk_logits.exp().cumsum(dim=-1)
        # n = number of prefixes (top-1, top-2, ...) whose mass is still too small
        num_top_totalprob_too_small = (topk_totalprob + threshold < 1.0).sum(dim=-1)
        # Then the first n + 1 entries barely reach the desired mass. Example: if the top 1, 2
        # and 3 are too small, the fourth largest must be included as well, i.e. ranks 0..3.
        # If all top-k are too small, all of them are taken and the mass target is compromised.
        keep_first = torch.arange(keep_k, device=device) <= num_top_totalprob_too_small.unsqueeze(-1)
        mask.scatter_(dim=-1, index=topk_indices, src=keep_first)

        if selected_token_mask is not None:
            if selected_token_mask.shape != mask.shape[:-1]:
                raise ValueError("selected_token_mask must have shape matching norm_logits except for last dim.")
            # Always keep the selected logits for each token
            mask.scatter_(-1, selected_token_mask.unsqueeze(-1), True)

    indices = mask.nonzero(as_tuple=False).T  # [ndim, nnz]
    values = norm_logits[mask]  # [nnz], picked values
    values = values.clamp(max=-1e-12)  # keep the picked log-probs strictly negative

    with torch.no_grad():  # No need to track this in autograd
        if len(norm_logits_shape) == 3:  # (B, S, V)
            batch_size, seq_len, vocab_size = norm_logits_shape
            total_tokens = batch_size * seq_len
            batch_seq_indices = indices[0] * seq_len + indices[1]
        else:  # (N, V)
            total_tokens, vocab_size = norm_logits_shape
            batch_seq_indices = indices[0]
        non_zeros = torch.bincount(batch_seq_indices, minlength=total_tokens).to(torch.int32)
        non_zeros = non_zeros.view(norm_logits_shape[:-1])

    # Probability mass that all dropped entries of a token receive in total
    zeros_mass = (vocab_size - non_zeros) * default

    # Compute logsumexp for each token using scatter operations
    # We use the numerically stable logsumexp: log(sum(exp(x))) = max + log(sum(exp(x - max)))
    max_vals = torch.full((total_tokens,), -float("inf"), device=device, dtype=dtype)
    max_vals = max_vals.scatter_reduce(0, batch_seq_indices, values, reduce="amax")

    exp_shifted = torch.exp(values - max_vals[batch_seq_indices])
    sum_exp = torch.zeros(total_tokens, device=device, dtype=dtype)
    sum_exp = sum_exp.scatter_add(0, batch_seq_indices, exp_shifted)

    # Final logsumexp result (log of remaining mass)
    log_remaining_mass = max_vals + torch.log(sum_exp)
    log_remaining_mass = log_remaining_mass.view(*norm_logits_shape[:-1])

    # Compute target mass (1 - zeros_mass) in log space
    log_target_mass = torch.log(1 - zeros_mass)

    # new_logit = old_logit + log(target_mass) - log(remaining_mass)
    normalized_values = (
        values
        + log_target_mass.view(-1)[batch_seq_indices]
        - log_remaining_mass.view(-1)[batch_seq_indices]
    )

    with torch.no_grad():
        info_dict = {"remaining_mass": log_remaining_mass.exp(),
                     "num_selected_tokens": non_zeros.float(),
                     }

    if return_sparse:
        normalized_sparse = torch.sparse_coo_tensor(
            indices.contiguous(), normalized_values.contiguous(), norm_logits_shape, device=device,
            dtype=dtype
        ).coalesce()
        return normalized_sparse, info_dict

    # Dense output: start from log(default) everywhere and write the kept entries by index.
    # Using "value == 0" as a marker for dropped entries would be wrong, because a kept
    # entry can legitimately renormalize to exactly 0.0 (probability 1 within precision).
    dense_result = torch.full(norm_logits_shape, log_default, device=device, dtype=dtype)
    dense_result[tuple(indices.unbind(0))] = normalized_values
    return dense_result, info_dict

