import torch
import torch.nn as nn


class AttentionSinks(nn.Module):
    """Prepend learnable "sink" positions to key and value tensors.

    Two parameters of shape ``[num_sinks, dim]`` are held, one for keys and
    one for values. On every forward pass they are reshaped to per-head form,
    broadcast over the batch, and concatenated in front of the incoming
    keys/values along the sequence axis.

    Args:
        dim: Width of each sink embedding. ``forward`` requires it to equal
            ``num_heads * head_dim`` of the tensors it receives.
        num_sinks: Number of sink positions prepended to the sequence.
    """

    def __init__(self, dim: int, num_sinks: int = 1):
        super().__init__()
        self.num_sinks = num_sinks
        # Learnable sink embeddings for keys and values
        self.sink_k = nn.Parameter(torch.randn(num_sinks, dim))
        self.sink_v = nn.Parameter(torch.randn(num_sinks, dim))

    def forward(self, k: torch.Tensor, v: torch.Tensor):
        """Add attention sinks to both k and v tensors.

        The sinks are cast to the dtype of the tensor they are joined with, so
        half-precision inputs yield half-precision outputs instead of being
        promoted to the parameters' dtype by ``torch.cat``.

        Args:
            k: Tensor of shape [batch_size, num_heads, seq_len, head_dim]
            v: Tensor of shape [batch_size, num_heads, seq_len, head_dim]

        Returns:
            k_with_sinks: Tensor of shape [batch_size, num_heads, seq_len + num_sinks, head_dim]
            v_with_sinks: Tensor of shape [batch_size, num_heads, seq_len + num_sinks, head_dim]

        Raises:
            ValueError: If ``k`` is not 4-dimensional, or if the sink width
                does not equal ``num_heads * head_dim``.
        """
        batch_size, num_heads, _, head_dim = k.shape

        sink_dim = self.sink_k.shape[-1]
        if sink_dim != num_heads * head_dim:
            raise ValueError(
                f"sink dim ({sink_dim}) must equal num_heads * head_dim "
                f"({num_heads} * {head_dim} = {num_heads * head_dim})"
            )

        # NOTE: the [num_sinks, dim] parameter is reinterpreted in flat memory
        # order as [num_heads, num_sinks, head_dim]. For num_sinks > 1 each
        # head therefore reads a contiguous run of the flattened parameter,
        # not the same column slice of every sink row. The layout is kept
        # as-is so previously saved parameters keep their meaning.
        sink_shape = (1, num_heads, self.num_sinks, head_dim)

        # Cast before expanding: the cast touches only the small parameter and
        # is a no-op when dtypes already match; expand() itself copies nothing.
        sinks_k = (
            self.sink_k.to(dtype=k.dtype)
            .view(sink_shape)
            .expand(batch_size, -1, -1, -1)
        )
        sinks_v = (
            self.sink_v.to(dtype=v.dtype)
            .view(sink_shape)
            .expand(batch_size, -1, -1, -1)
        )

        # Sinks go first along the sequence axis (dim=2).
        k_with_sinks = torch.cat([sinks_k, k], dim=2)
        v_with_sinks = torch.cat([sinks_v, v], dim=2)
        return k_with_sinks, v_with_sinks

    