"""
Shared attention helpers
"""
import torch


# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Same result as torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1):
       (batch, num_key_value_heads, seqlen, head_dim) ->
       (batch, num_key_value_heads * n_rep, seqlen, head_dim)

    When n_rep == 1 the input tensor itself is returned (no copy).
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a singleton "repeat" axis after the head axis and broadcast it
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    # Fold the repeat axis into the head axis so copies of a head are adjacent
    merged = expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return merged.contiguous()


def mask_attention(qk_dot: torch.Tensor, attn_mask: torch.Tensor,
                   mask_value: float = -10000) -> torch.Tensor:
    """
    Apply attention mask (e.g., for padding)

    Positions where `attn_mask` is zero / False are overwritten with
    `mask_value`; all other entries of `qk_dot` are kept. A new tensor is
    returned (`qk_dot` is not modified in place).

    -> attn_mask with 4 dims is used as-is and must be broadcastable to
       qk_dot
    -> any other attn_mask (expected: (batch, key_len)) is indexed as
       attn_mask[:, None, None, :] before being applied, i.e. a 2D mask
       becomes (batch, 1, 1, key_len)
    """
    keep = attn_mask.bool()
    if keep.dim() != 4:
        # Add singleton axes so the mask broadcasts over the middle dims
        keep = keep[:, None, None, :]
    return qk_dot.masked_fill(~keep, mask_value)
    

# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def rotate_half(x):
    """
    Split the last dim of `x` at its midpoint and return the concatenation
    (-second_half, first_half) along that same dim.
    """
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Applies Rotary Position Embedding to the query and key tensors.

    If `position_ids` is given, `cos` and `sin` are first indexed with it.
    Both are then unsqueezed at `unsqueeze_dim` so they can broadcast
    against `q` and `k`. Each of q / k is mapped to
        x * cos + rotate_half(x) * sin
    and the pair (q_embed, k_embed) is returned.
    """
    if position_ids is not None:
        cos = cos[position_ids]
        sin = sin[position_ids]
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same rotation formula is applied to queries and keys
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)

def get_causal_mask(q_len: int, k_len: int,
                    device: torch.device) -> torch.Tensor:
    """
    Return a single boolean causal mask of shape (q_len, k_len)
    -> True is include, False is ignore

    Entry (i, j) is True when j <= i + (k_len - q_len), so when
    k_len > q_len the queries line up with the last q_len key positions.
    """
    # Diagonal offset shifts the lower triangle right by the extra keys
    offset = k_len - q_len
    all_true = torch.ones((q_len, k_len), device=device, dtype=torch.bool)
    return all_true.tril(offset)

def get_chunk_mask(seq_len: int, chunk_size: int,
                   device: torch.device) -> torch.Tensor:
    """
    Return a boolean block-diagonal causal mask of shape (seq_len, seq_len)
    -> True is include, False is ignore

    The sequence is split into seq_len // chunk_size consecutive chunks.
    Entry (i, j) is True only when i and j fall in the same chunk and
    j <= i; everything across chunks is False.

    `seq_len` must be a multiple of `chunk_size` (checked with an assert).
    """
    assert seq_len % chunk_size == 0, "seq_len must be divisible by chunk_size"
    num_chunks = seq_len // chunk_size
    kwargs = {'device': device, 'dtype': torch.bool}

    # Causal mask for one chunk
    chunk_mask = torch.tril(torch.ones(chunk_size, chunk_size, **kwargs))  # [chunk_size, chunk_size]

    # Selects which (row_chunk, col_chunk) pairs are on the block diagonal
    eye = torch.eye(num_chunks, **kwargs)  # [num_chunks, num_chunks]
    block_mask = eye[:, :, None, None] & chunk_mask[None, None, :, :]  # [num_chunks, num_chunks, chunk_size, chunk_size]
    # Reorder to [row_chunk, row_in_chunk, col_chunk, col_in_chunk] before flattening
    full_mask = block_mask.permute(0, 2, 1, 3).reshape(seq_len, seq_len)  # [seq_len, seq_len]

    return full_mask