import math
import torch
import torch.nn.functional as F
from torch import nn
from typing import Optional, Tuple, Callable
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, repeat_kv, apply_rotary_pos_emb

def get_alibi_biases(n_heads: int, mask: torch.Tensor) -> torch.Tensor:
    """
    Build per-head linear position biases and gate them with ``mask``.

    Head ``h`` (counted from 0) is given the slope ``2 ** (-8 * h / n_heads)``.
    The bias at ``[h, i, j]`` is that slope times the signed offset ``j - i``,
    and the result is multiplied elementwise by ``mask``.

    Args:
        n_heads: Number of attention heads; one slope is produced per head.
        mask: Tensor whose last dimension is taken as the sequence length.
            Its device and dtype are used for the slopes, and it must
            broadcast against ``[n_heads, seq_len, seq_len]``.

    Returns:
        ``slope[h] * (j - i) * mask``, with the shape obtained by broadcasting
        ``[n_heads, seq_len, seq_len]`` against ``mask``.

    NOTE(review): because the exponent starts at ``h = 0`` the first head's
    slope is exactly 1; the schedule in the ALiBi paper starts at
    ``2 ** (-8 / n_heads)``. The original docstring called this a placeholder,
    so confirm which schedule is wanted before relying on the values.
    """
    device = mask.device
    length = mask.shape[-1]

    # One slope per head, computed in Python floats and then cast to mask.dtype.
    head_slopes = torch.tensor(
        [2 ** (-8 * head / n_heads) for head in range(n_heads)],
        device=device,
        dtype=mask.dtype,
    )

    # offsets[i, j] = j - i (column index minus row index).
    idx = torch.arange(length, device=device)
    offsets = idx[None, :] - idx[:, None]

    # [n_heads, 1, 1] * [1, seq_len, seq_len] -> [n_heads, seq_len, seq_len]
    biases = head_slopes[:, None, None] * offsets[None, :, :]

    return biases * mask

def farsight_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Compute FarSight attention for a square (q_len == kv_len) attention map.

    The scaled dot-product scores are kept on the causal (lower-triangular)
    positions and multiplied by ``module.decay_sigma``; the scores returned by
    ``module.register_score`` are added on top. After the softmax the
    probabilities are multiplied by the causal mask again, so positions above
    the diagonal receive zero weight and each row may sum to less than 1.

    Args:
        module: Attention module providing ``num_key_value_groups``,
            ``register_score(seq_len, n_heads, mask)``, ``decay_sigma`` and
            ``training``.
        query: Query states, unpacked as ``[bsz, n_heads, seq_len, head_dim]``.
        key: Key states with the key/value head count on dim 1.
        value: Value states, same layout as ``key``.
        attention_mask: Accepted for signature compatibility; not used.
            NOTE(review): padding information in this mask is therefore
            ignored — confirm this is acceptable for batched, padded inputs.
        scaling: Multiplier applied to the raw ``query @ key^T`` scores.
        dropout: Dropout probability applied to the attention probabilities
            (only active when ``module.training`` is true).
        **kwargs: Accepted for signature compatibility; not used.

    Returns:
        ``(attn_output, attn_probs)`` where ``attn_output`` is
        ``[bsz, seq_len, n_heads, head_dim]`` (heads moved to dim 2, made
        contiguous) and ``attn_probs`` is the causal-masked probability tensor.

    NOTE(review): the causal mask is sized from the query length only. When
    the key length differs (e.g. a KV cache holding earlier tokens) the mask
    only lines up through broadcasting for ``seq_len == 1`` — confirm before
    using this path with longer cached contexts.
    """
    # Expand KV heads for grouped-query attention. repeat_kv returns its input
    # unchanged for a single group, so the call is skipped in that case.
    n_rep = module.num_key_value_groups
    if n_rep > 1:
        key_states = repeat_kv(key, n_rep)
        value_states = repeat_kv(value, n_rep)
    else:
        key_states, value_states = key, value

    _, n_heads, seq_len, _ = query.shape

    # Raw attention scores, scaled by the caller-supplied factor rather than a
    # locally recomputed 1/sqrt(head_dim), so the `scaling` argument is honored.
    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    # Causal mask: 1 on and below the diagonal, 0 above; shape [1, 1, S, S].
    causal = torch.tril(
        torch.ones((seq_len, seq_len), device=query.device, dtype=query.dtype)
    ).view(1, 1, seq_len, seq_len)

    register_score = module.register_score(seq_len, n_heads, causal)

    # FarSight formula: scores * C * sigma + register_score.
    scores = scores * causal * module.decay_sigma + register_score

    # Softmax over all key positions, then zero the positions above the
    # diagonal so no probability flows to future tokens.
    attn_probs = torch.softmax(scores, dim=-1) * causal

    attn_probs = nn.functional.dropout(attn_probs, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_probs, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_probs


class FarSightLlamaAttention(LlamaAttention):
    """
    Llama attention layer whose attention step is ``farsight_attention_forward``.

    The projections, rotary embedding and cache update follow the usual
    layout of this ``forward``; only the score/softmax computation is replaced.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int, max_seq_len: int = 256, alpha: int = 1024):
        """
        Args:
            config: Llama configuration forwarded to ``LlamaAttention``.
            layer_idx: Index of this layer, forwarded to ``LlamaAttention``.
            max_seq_len: Sequence length used to derive ``decay_sigma``.
            alpha: Logarithm base used to derive ``decay_sigma``.

        Raises:
            ValueError: If ``max_seq_len`` is not positive, or ``alpha`` is
                not a positive number different from 1 (the logarithms would
                be undefined or the ratio would divide by zero).
        """
        super().__init__(config, layer_idx)
        if max_seq_len <= 0:
            raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
        if alpha <= 0 or alpha == 1:
            raise ValueError(f"alpha must be positive and different from 1, got {alpha}")
        self.max_seq_len = max_seq_len
        # decay_sigma = log(max_seq_len) / log(alpha). Stored as a plain Python
        # float: a tensor created here would be placed on whatever default
        # device is active at construction and, not being a parameter or
        # buffer, would not follow the module when it is moved.
        self.decay_sigma = math.log(max_seq_len) / math.log(alpha)

    def register_score(self, seq_len: int, n_heads: int, mask: torch.Tensor) -> torch.Tensor:
        """
        Build the register scores that are added to the attention scores.

        Steps, as implemented:
          1. Build a ``[seq_len, seq_len]`` matrix that is 1 on and below the
             diagonal and 0 above it.
          2. Negate it, reverse its columns, and pass it to
             ``get_alibi_biases`` as the mask.
          3. Reverse dim 1 of the returned tensor.
          4. Multiply by ``1 - mask`` so only positions where ``mask`` is 0
             keep a value.

        Args:
            seq_len: Side length of the square register.
            n_heads: Number of heads, forwarded to ``get_alibi_biases``.
            mask: Mask tensor supplying device and dtype; ``1 - mask`` gates
                the final result.

        Returns:
            The register scores, broadcast against ``mask``.

        NOTE(review): step 2 flips the column axis of a 2-D tensor while
        step 3 flips dim 1 of the tensor returned by ``get_alibi_biases``;
        confirm against the reference implementation that these are the
        intended axes and that the resulting values decay as expected.
        """
        # Lower-triangular ones (zeros strictly above the diagonal).
        register = torch.tril(
            torch.ones((seq_len, seq_len), device=mask.device, dtype=mask.dtype)
        )

        # Generate register alibi biases from the negated, column-reversed register.
        register_score = get_alibi_biases(n_heads, -register.flip(dims=[1])).flip(dims=[1])

        # Keep values only where the supplied mask is zero.
        return register_score.contiguous() * (1 - mask)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[object],
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Project the hidden states, apply rotary embeddings, and run FarSight attention.

        Args:
            hidden_states: Input activations; every dimension except the last
                is kept as the output's leading shape.
            position_embeddings: ``(cos, sin)`` pair passed to
                ``apply_rotary_pos_emb``.
            attention_mask: Forwarded to ``farsight_attention_forward``.
            past_key_value: Optional cache; when not ``None`` its
                ``update(key, value, layer_idx, cache_kwargs)`` result replaces
                the key and value states.
            cache_position: Forwarded to the cache inside ``cache_kwargs``.
            **kwargs: Forwarded to ``farsight_attention_forward``.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` has passed
            through ``o_proj``.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Query, key, value projection; heads are moved to dim 1.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attn_output, attn_weights = farsight_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the head and head_dim axes back into one hidden axis.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights