import math

import torch

from discrete_trpl.sparse.kl import sparse_kl_aligned


def sparse_flatten(x: torch.Tensor, dims: tuple[int, int]) -> torch.Tensor:
    """
    Flatten sparse COO tensor along dimensions [start, end] (inclusive) into one dimension.
    Example: sparse_flatten(x, (1,2)) turns shape (a,b,c) -> (a, b*c).

    Args:
        x: Sparse COO tensor with ``dense_dim() == 0``. It does not need to be
            coalesced; duplicate entries are summed.
        dims: Inclusive ``(start, end)`` range of dimensions to merge. Negative
            values count from the end, as in ``torch.flatten``.

    Returns:
        Coalesced sparse COO tensor in which dimensions ``start..end`` are merged
        in row-major order (the last merged dimension varies fastest).

    Raises:
        ValueError: If ``x`` is not COO, has dense dimensions, or ``dims`` does
            not describe a valid range.
    """
    if x.layout != torch.sparse_coo:
        raise ValueError("Expected a COO sparse tensor.")
    if x.dense_dim() != 0:
        raise ValueError("Only tensors with dense_dim == 0 supported.")

    old_shape = tuple(x.shape)
    ndim = len(old_shape)
    start, end = dims
    # Normalize negative dims (e.g. -1 -> last dimension) like torch.flatten.
    if start < 0:
        start += ndim
    if end < 0:
        end += ndim
    if not (0 <= start <= end < ndim):
        raise ValueError("Invalid dims range.")

    # new shape
    flat_size = math.prod(old_shape[start:end + 1])
    new_shape = old_shape[:start] + (flat_size,) + old_shape[end + 1:]

    # indices()/values() raise on uncoalesced tensors; coalesce() is differentiable
    # and returns the tensor itself when it is already coalesced.
    x = x.coalesce()
    idx = x.indices()  # (ndim, nnz)
    vals = x.values()  # (nnz,)

    # Row-major strides of the merged dims: the stride of dim d is the product of
    # the sizes of the merged dims after d, so the last merged dim has stride 1.
    strides = torch.tensor(
        [math.prod(old_shape[d + 1:end + 1]) for d in range(start, end + 1)],
        device=idx.device, dtype=torch.long,
    )  # (k,), k = end - start + 1
    # Also correct for nnz == 0: (k, 0) * (k, 1) summed over dim 0 -> (0,).
    flat_idx = (idx[start:end + 1] * strides[:, None]).sum(dim=0)  # (nnz,)

    # assemble new indices
    new_idx = torch.cat([
        idx[:start],
        flat_idx.unsqueeze(0),
        idx[end + 1:],
    ], dim=0)

    return torch.sparse_coo_tensor(new_idx, vals, size=new_shape,
                                   dtype=x.dtype, device=x.device).coalesce()


def _empty_projection_result(
        kl_div: torch.Tensor,
        bound: torch.Tensor,
        V: int,
        device: torch.device,
        dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the 'nothing to project' return tuple with consistent dtypes/devices."""
    empty_indices = torch.zeros((2, 0), dtype=torch.long, device=device)
    empty_values = torch.zeros((0,), dtype=dtype, device=device)
    empty_sparse = torch.sparse_coo_tensor(empty_indices, empty_values, (0, V),
                                           dtype=dtype, device=device)
    return (torch.empty(0, dtype=torch.long, device=device),
            empty_sparse, empty_sparse, kl_div,
            torch.empty((0, 1), dtype=bound.dtype, device=bound.device))


def subindex_required_projections(
        target_aligned: torch.Tensor,
        ref_aligned: torch.Tensor,
        bound: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Select only rows (batch elements) that need projection and build compact sparse
    tensors containing just those rows, preserving gradients.

    A row needs projection when its KL divergence (computed without gradients via
    ``sparse_kl_aligned``) is ``>=`` its bound. Rows that need projection but store
    no nonzeros are not included in the compact outputs.

    Args:
        target_aligned: COO sparse tensor (..., V) aligned with ref_aligned (same indices).
        ref_aligned:    COO sparse tensor (..., V) aligned with target_aligned.
        bound:          Tensor of shape (..., 1) broadcastable to (*batch_dims, 1).

    Returns:
        unique_batch_indices: 1D int64 tensor with the *flattened* original batch indices
                              (range [0, prod(batch_dims))) that require projection, sorted.
        target_filtered:      COO sparse tensor (K, V) with only those rows (K=len(unique)).
        ref_filtered:         COO sparse tensor (K, V) with only those rows.
        kl_div:               Tensor with KL divergence for each batch element.
        _bound:               Tensor (K, 1) with bounds for those rows.
    """
    # Compute kl divergence to figure out which logits need to be projected
    with torch.no_grad():
        kl_div = sparse_kl_aligned(target_aligned, ref_aligned)  # shape: batch_dims
        needs_projection = kl_div >= bound[..., 0]

    shape = target_aligned.shape
    batch_dims = tuple(shape[:-1])
    V = shape[-1]
    device = target_aligned.device
    dtype = target_aligned.dtype

    # Nothing to project: return empties with consistent dtypes/devices
    if not torch.any(needs_projection):
        return _empty_projection_result(kl_div, bound, V, device, dtype)

    # Build flattened batch index per nonzero column
    idx = target_aligned.indices()  # [ndim, NNZ], last dim is logits
    if len(batch_dims) == 1:
        flat_batch_idx = idx[0].long()  # [NNZ]
    else:
        # ravel multi-d batch coords (row-major)
        strides = torch.tensor(
            [int(math.prod(batch_dims[i + 1:])) for i in range(len(batch_dims))],
            device=device, dtype=torch.long
        )  # [nd_batch]
        flat_batch_idx = (idx[:-1].long() * strides[:, None]).sum(dim=0)  # [NNZ]

    # Mask nnz entries that belong to rows needing projection
    needs_proj_flat = needs_projection.reshape(-1)  # [B]
    sparse_mask = needs_proj_flat[flat_batch_idx]  # [NNZ] boolean

    # Rows needing projection may have no stored entries at all.
    if not torch.any(sparse_mask):
        return _empty_projection_result(kl_div, bound, V, device, dtype)

    # Filter nnz columns
    filt_idx = idx[:, sparse_mask]  # [ndim, NNZ']
    t_vals = target_aligned.values()[sparse_mask]
    r_vals = ref_aligned.values()[sparse_mask]

    # Compute unique original flat batches present in the filtered set and remap to 0..K-1
    filt_flat_batches = flat_batch_idx[sparse_mask]  # [NNZ']
    unique_batches, inv = torch.unique(filt_flat_batches, return_inverse=True)  # [K], [NNZ']

    # New indices: (row=inv, col=logit)
    new_indices = torch.stack([inv, filt_idx[-1].long()], dim=0)  # [2, NNZ']
    new_shape = (unique_batches.numel(), V)

    target_filtered = torch.sparse_coo_tensor(new_indices, t_vals, new_shape,
                                              dtype=dtype, device=device).coalesce()
    ref_filtered = torch.sparse_coo_tensor(new_indices, r_vals, new_shape,
                                           dtype=dtype, device=device).coalesce()

    # `bound` only has to be broadcastable to (*batch_dims, 1). Materialize that
    # shape before flat indexing, so a shared bound such as (1, 1) selects the right
    # rows instead of raising IndexError.
    flat_bound = torch.broadcast_to(bound, (*batch_dims, 1)).reshape(-1, 1)
    _bound = flat_bound[unique_batches]

    return unique_batches, target_filtered, ref_filtered, kl_div, _bound


def sparse_where(
        condition: torch.Tensor,
        input: torch.Tensor,
        other: torch.Tensor
) -> torch.Tensor:
    """
    Perform a sparse element-wise selection between two sparse tensors.

    Entries are selected by gathering ``condition`` at each stored coordinate, not by
    multiplying with a 0/1 mask. A non-finite value on the unselected side (e.g. a
    ``-inf`` logit) is therefore dropped and cannot turn into NaN via ``inf * 0``.
    Gradients flow only to the selected values.

    Args:
        condition (torch.Tensor): A boolean (dense) mask tensor indicating where to take
            values from `input` (True) or from `other` (False). Must be broadcastable to
            the shape of `input` and `other`.
        input (torch.sparse_coo_tensor): Sparse tensor to select values from when
            `condition` is True.
        other (torch.sparse_coo_tensor): Sparse tensor to select values from when
            `condition` is False.

    Returns:
        torch.sparse_coo_tensor: A coalesced sparse tensor where entries are taken
        from `input` or `other` according to `condition`, with explicit zeros removed
        to preserve sparsity.

    Notes:
        - Both `input` and `other` must have the same shape.
        - Assumes both tensors have ``dense_dim() == 0`` (indices address every
          dimension) — NOTE(review): hybrid tensors are not handled.
        - Explicit zeros in the result are removed so that the returned tensor
          reflects true sparsity.
    """
    # indices()/values() require coalesced tensors; this also merges duplicates.
    input = input.coalesce()
    other = other.coalesce()

    # Dense boolean mask at the full shape, so it can be gathered at stored coords.
    cond = torch.broadcast_to(
        condition.to(dtype=torch.bool, device=input.device), input.shape
    )

    in_idx = input.indices()
    oth_idx = other.indices()
    take_input = cond[tuple(in_idx)]       # stored entry of input is selected
    take_other = ~cond[tuple(oth_idx)]     # stored entry of other is selected

    # The selected coordinate sets are disjoint (condition True vs. False), so
    # concatenating them and coalescing only sorts; nothing gets summed.
    indices = torch.cat([in_idx[:, take_input], oth_idx[:, take_other]], dim=1)
    values = torch.cat([input.values()[take_input], other.values()[take_other]])

    # Remove explicit zeros to match the actual sparsity
    nonzero_mask = values != 0
    return torch.sparse_coo_tensor(
        indices[:, nonzero_mask], values[nonzero_mask], input.shape
    ).coalesce()


def sparse_mask_3d_to_2d(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Keep only the (batch, position) rows selected by ``mask`` and stack them into 2D.

    x: sparse_coo Tensor of shape (bsz, seq_len, vocab_size)
    mask: Dense tensor with bsz * seq_len elements (normally shape (bsz, seq_len)),
        interpreted as bool, indicating the rows to keep. It may live on a different
        device than ``x``.
    Returns: coalesced sparse_coo (num_tokens, vocab_size) with num_tokens = mask.sum().
        Kept rows appear in row-major (batch, position) order, the same order as
        ``x.to_dense()[mask]``.
        If x stores no elements in masked-out positions, then out._nnz() == x._nnz()
        (after coalescing x), i.e., all nonzeros are preserved. Otherwise, the
        nonzeros in masked-out rows are dropped.
    Raises: ValueError if mask does not have bsz * seq_len elements.
    """
    assert x.layout == torch.sparse_coo, "X must be sparse COO"
    bsz, seq_len, vocab_size = x.shape
    if mask.numel() != bsz * seq_len:
        raise ValueError(
            f"mask must have bsz * seq_len = {bsz * seq_len} elements, "
            f"got shape {tuple(mask.shape)}."
        )
    x = x.coalesce()

    idx = x.indices()  # (3, nnz): [b, t, v]
    val = x.values()  # (nnz,)

    # flatten (b,t) -> bt
    flat_position = idx[0] * seq_len + idx[1]  # (nnz,)
    # Move the mask next to the indices: cross-device indexing raises.
    flat_mask = mask.reshape(-1).to(device=idx.device, dtype=torch.bool)  # (B*T,)

    keep = flat_mask[flat_position]  # (nnz,)

    num_tokens = int(flat_mask.sum().item())

    # compact (bsz*seq_len) row ids 0..N-1 using the mask; masked-out rows stay -1
    row_ids = torch.full((bsz * seq_len,), -1, dtype=torch.long, device=idx.device)
    row_ids[flat_mask] = torch.arange(num_tokens, device=idx.device)

    new_row = row_ids[flat_position[keep]]  # (kept_nnz,)
    new_col = idx[2][keep]  # (kept_nnz,)
    new_val = val[keep]  # (kept_nnz,)

    new_idx = torch.stack([new_row, new_col], dim=0)
    return torch.sparse_coo_tensor(new_idx, new_val, (num_tokens, vocab_size), device=val.device).coalesce()

def sparse_remap(coo: torch.Tensor, row_map: torch.Tensor, out_rows: int) -> torch.Tensor:
    """Map rows of a [K, V] COO to [out_rows, V] via row_map[k] -> new row.

    Args:
        coo: 2D sparse COO tensor of shape [K, V]; it need not be coalesced.
        row_map: 1D tensor of length K giving the destination row of each input row.
            It is cast to int64 and moved to ``coo``'s device.
        out_rows: Number of rows of the output.

    Returns:
        Coalesced sparse COO tensor of shape [out_rows, V]. Entries of input rows
        mapped to the same (row, column) are summed by ``coalesce()``.
    """
    assert coo.layout == torch.sparse_coo
    # indices()/values() raise on uncoalesced tensors; coalesce() is a no-op otherwise.
    coo = coo.coalesce()
    idx = coo.indices()  # [2, nnz]
    vals = coo.values()
    # Sparse indices must be int64 and on the same device as the index tensor.
    new_rows = row_map.to(device=idx.device, dtype=torch.long)[idx[0]]  # [nnz]
    new_idx = torch.stack([new_rows, idx[1]], dim=0)
    return torch.sparse_coo_tensor(new_idx, vals, (out_rows, coo.size(1))).coalesce()


