# sparse_attention.py
# Implements Tsparse: sparse attention combining local window and selected global memory positions.
# The function tsparse_forward computes attention outputs where each query attends only to
# positions in S_i = local_window_{i} ∪ mem_positions (global positions).
#
# This file provides a batched implementation for typical transformer-like Q/K/V tensors.

from typing import Optional, Union, List
import torch
import torch.nn.functional as F


def tsparse_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    local_window: int = 5,
    mem_positions: Optional[Union[List[int], torch.Tensor]] = None,
    mask: Optional[torch.Tensor] = None,
    dropout: float = 0.0,
) -> torch.Tensor:
    """
    Compute sparse attention outputs over a local band plus global memory slots.

    Query position i may attend to key position j when
    ``|i - j| <= local_window`` or ``j`` is listed in ``mem_positions``, and
    (if ``mask`` is given) ``mask`` also allows the pair. The computation is
    dense: full (B, Tq, Tk) scores are formed and disallowed entries are
    masked before the softmax, so memory use is quadratic in sequence length.

    Args:
      query: (B, Tq, D)
      key:   (B, Tk, D)
      value: (B, Tk, Dv)
      local_window: radius of local attention; query i sees key positions
                    [i - local_window, i + local_window] clipped to [0, Tk).
                    A negative radius yields an empty local band.
      mem_positions: optional list/tuple/tensor of key indices that every
                     query may attend to. Any integer dtype is accepted;
                     entries outside [0, Tk) are silently dropped.
      mask: optional mask broadcastable to (B, Tq, Tk); truthy means allowed.
            It is ANDed with the computed sparsity pattern.
      dropout: dropout probability on the attention weights. It is applied
               whenever ``dropout > 0`` (there is no train/eval switch here),
               so pass 0.0 at inference time.
    Returns:
      output: (B, Tq, Dv). Rows whose query has no allowed key at all are
              returned as zeros rather than an average over disallowed keys.
    """
    assert query.dim() == 3 and key.dim() == 3 and value.dim() == 3, (
        "query, key and value must all be 3-D (batch, time, features)"
    )
    B, Tq, D = query.shape
    Tk = key.shape[1]
    assert Tk == value.shape[1], "key and value must have the same sequence length"

    device = query.device

    # Local band, built once as (Tq, Tk) and shared across the batch.
    # Assumes query index i is aligned with key index i.
    q_idx = torch.arange(Tq, device=device).unsqueeze(1)
    k_idx = torch.arange(Tk, device=device).unsqueeze(0)
    allowed = (k_idx - q_idx).abs() <= local_window

    # Global memory columns are visible to every query.
    if mem_positions is not None:
        # index_fill_ needs an int64 index; coerce so int32 tensors etc. work.
        mem_pos = torch.as_tensor(
            mem_positions, dtype=torch.long, device=device
        ).reshape(-1)
        # Drop (rather than clamp) out-of-range indices.
        mem_pos = mem_pos[(mem_pos >= 0) & (mem_pos < Tk)]
        if mem_pos.numel() > 0:
            allowed.index_fill_(1, mem_pos, True)

    allowed = allowed.unsqueeze(0).expand(B, Tq, Tk)

    # Intersect with the caller-supplied mask (True == allowed).
    if mask is not None:
        allowed = allowed & mask.to(device=device, dtype=torch.bool)

    # Queries with at least one visible key; the rest must produce zeros.
    has_key = allowed.any(dim=-1, keepdim=True)

    scale = 1.0 / (D ** 0.5)
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale  # (B, Tq, Tk)

    # A true -inf guarantees exactly zero weight on disallowed keys in every
    # float dtype (a large finite sentinel overflows in float16).
    scores = scores.masked_fill(~allowed, float("-inf"))
    # A row of all -inf would make softmax return NaN; give such rows finite
    # placeholder scores and zero their weights afterwards.
    scores = scores.masked_fill(~has_key, 0.0)

    attn = F.softmax(scores, dim=-1)  # (B, Tq, Tk)
    attn = attn.masked_fill(~has_key, 0.0)

    if dropout > 0.0:
        attn = F.dropout(attn, p=dropout, training=True)

    return torch.matmul(attn, value)  # (B, Tq, Dv)
