import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaAttention


# Module-level constants.
# Sentinel integer; presumably the label value skipped by the loss
# (-100 matches torch's default ``ignore_index``) — TODO confirm with caller.
IGNORE_INDEX = -100
# Symbolic name -> literal token string for the special tokens.
SPECIAL = dict(BOS="$", EOS="#", PAD="<PAD>", SEP="|")


def beta_encode(x: int) -> Optional[str]:
    """Return the β-encoding of position ``x``, or ``None`` when it is undefined.

    The encoding is computed as follows:

    1. Write ``x`` in binary, most significant bit first (no ``0b`` prefix).
    2. Locate the first ``'0'`` when reading from the left.
    3. Return every bit that follows that ``'0'`` (the ``'0'`` itself is
       excluded).

    β(x) is undefined (``None``) when ``x`` is not a positive integer or when
    its binary form contains no ``'0'`` (i.e. ``x = 2**k - 1``).  When the only
    ``'0'`` is the last bit, the result is the empty string, which is distinct
    from ``None``.

    NOTE(review): confirm against the paper whether the located ``'0'`` should
    be part of the returned suffix; this implementation excludes it.

    Args:
        x: Position to encode.

    Returns:
        The suffix of the binary representation after its leftmost ``'0'``,
        or ``None`` if β(x) is undefined.

    Examples:
        >>> beta_encode(42)   # "101010"
        '1010'
        >>> beta_encode(18)   # "10010"
        '010'
        >>> beta_encode(17)   # "10001"
        '001'
        >>> beta_encode(2)    # "10"
        ''
        >>> beta_encode(7) is None   # "111" has no '0'
        True
    """
    # Non-positive inputs have no encoding.  Negative values must be rejected
    # explicitly: bin(-5) is '-0b101', so slicing off two characters would
    # leave 'b101' and yield a meaningless result.
    if x <= 0:
        return None
    bits = bin(x)[2:]  # binary string without the '0b' prefix, MSB first
    # partition() splits around the first '0'; the separator slot is empty
    # when no '0' exists (all bits are 1), in which case β(x) is undefined.
    _, zero, suffix = bits.partition('0')
    return suffix if zero else None

def build_beta_relation(max_positions: int) -> torch.Tensor:
    """
    Materialise the β-relation as a dense 0/1 matrix.

    For every 1-indexed position j in [1, max_positions], β(j) is computed
    with ``beta_encode``.  Whenever the i-th character (1-indexed) of β(j) is
    '1', the entry at 0-indexed storage location [i-1, j-1] is set to 1.0.
    Positions whose β-encoding is undefined or empty contribute nothing, so
    their columns stay all-zero.

    Args:
        max_positions: Side length of the square matrix, i.e. the largest
            position (sequence length) covered.

    Returns:
        A float32 tensor of shape (max_positions, max_positions) holding
        only 0.0 and 1.0.
    """
    relation = torch.zeros((max_positions, max_positions), dtype=torch.float32)

    for col in range(max_positions):  # col is the 0-indexed storage column
        code = beta_encode(col + 1)   # the paper's positions are 1-indexed
        if not code:
            # None (undefined) and "" (empty suffix) both set no entries.
            continue

        # 0-indexed rows of the '1' bits.  No bounds check is needed: a row
        # index is smaller than the bit length of (col + 1), which never
        # exceeds col + 1 <= max_positions.
        rows = [row for row, bit in enumerate(code) if bit == '1']
        if rows:
            relation[rows, col] = 1.0

    return relation

class LlamaAttentionWithBetaPos(LlamaAttention):
    """
    LLaMA self-attention with β-Relative Positional Encoding (β-RPE).

    The forward pass computes plain scaled dot-product attention and adds a
    fixed 0/1 positional bias scaled by a learnable scalar:

    - scores_ij = (q_i · k_j) / √d
    - scores_ij += attention_mask_ij        (if a mask is given)
    - scores_ij += λ * R[i, j]              (if ``use_beta``)

    where R is the matrix returned by ``build_beta_relation(max_positions)``
    (R[i, j] = 1 iff β(j) has '1' at position i, else 0) and λ is
    ``self.beta_scale``, initialised to ``beta_scale``.

    Notes on what this ``forward`` does NOT do (all visible in the body):

    - ``position_ids``, ``past_key_value``, ``use_cache`` and ``**kwargs`` are
      accepted for call-compatibility but ignored; no KV cache is read or
      updated, so keys/values always come from the current ``hidden_states``.
    - No rotary position embedding is applied to q/k.
    - No causal mask is created here; causality must come from
      ``attention_mask``.

    ``beta_rel`` is registered as a persistent buffer, so it is part of the
    module's ``state_dict``.
    """

    def __init__(self, config, max_positions: int, layer_idx: Optional[int] = None,
                 beta_scale: float = 1.0, use_beta: bool = True):
        """
        Args:
            config: Model config forwarded to ``LlamaAttention.__init__``.
            max_positions: Largest sequence length supported by the β bias;
                the bias buffer has shape (max_positions, max_positions).
            layer_idx: Layer index forwarded to the base class.
            beta_scale: Initial value of the learnable bias scale λ.
            use_beta: When False, the β bias is never added in ``forward``.
        """
        super().__init__(config, layer_idx=layer_idx)
        self.use_beta = bool(use_beta)
        self.register_buffer("beta_rel", build_beta_relation(max_positions), persistent=True)
        self.beta_scale = nn.Parameter(torch.tensor(float(beta_scale)))
        # Ensure compatibility across different HuggingFace transformers versions:
        # fill in any attribute the base class did not define, deriving it from config.
        if not hasattr(self, 'num_heads'):
            self.num_heads = config.num_attention_heads
        if not hasattr(self, 'head_dim'):
            self.head_dim = config.hidden_size // config.num_attention_heads
        if not hasattr(self, 'hidden_size'):
            self.hidden_size = config.hidden_size
        if not hasattr(self, 'num_key_value_heads'):
            self.num_key_value_heads = getattr(config, 'num_key_value_heads', self.num_heads)
        if not hasattr(self, 'num_key_value_groups'):
            self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # forward() calls self.attention_dropout(...), so a missing or
        # non-callable value (e.g. a bare probability) is replaced by a module.
        if not hasattr(self, 'attention_dropout') or not callable(self.attention_dropout):
            self.attention_dropout = nn.Dropout(getattr(config, 'attention_dropout', 0.0))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        **kwargs,
    ):
        """
        Compute self-attention over ``hidden_states`` with the β-RPE bias.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden).
            attention_mask: Optional additive mask broadcastable to
                (batch, heads, seq_len, kv_len).  If its last dimension is
                longer than kv_len it is trimmed to kv_len.
            position_ids: Ignored.
            past_key_value: Ignored (no KV caching is performed).
            output_attentions: If True, also return the attention weights.
            use_cache: Ignored.
            **kwargs: Ignored.

        Returns:
            ``(out, attn)`` if ``output_attentions`` else ``(out, None)``,
            where ``out`` has shape (batch, seq_len, heads * head_dim) passed
            through ``o_proj`` and ``attn`` is the post-dropout attention
            tensor of shape (batch, heads, seq_len, kv_len).

        Raises:
            ValueError: If the β bias is enabled and the sequence is longer
                than the ``max_positions`` the module was built with.
        """
        bsz, q_len, _ = hidden_states.size()

        # projections
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # [B,T,HD] -> [B,H,T,D]
        q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # GQA: repeat KV heads so they line up with the query heads
        if self.num_key_value_groups != 1:
            k = torch.repeat_interleave(k, self.num_key_value_groups, dim=1)
            v = torch.repeat_interleave(v, self.num_key_value_groups, dim=1)

        # attention logits
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [B,H,q_len,kv_len]
        kv_len = scores.size(-1)

        # add mask first; trim a mask that was padded beyond the key length
        # (a no-op when the mask already matches kv_len)
        if attention_mask is not None:
            scores = scores + attention_mask[..., :kv_len]

        # β-RPE bias (after scores are computed)
        if self.use_beta and (self.beta_scale is not None):
            max_q, max_k = self.beta_rel.shape
            if q_len > max_q or kv_len > max_k:
                # Slicing would silently return a smaller matrix and then fail
                # (or mis-broadcast) when added to the scores.
                raise ValueError(
                    f"Sequence length (q_len={q_len}, kv_len={kv_len}) exceeds the "
                    f"β-relation size {(max_q, max_k)}; increase max_positions."
                )
            beta_bias = self.beta_rel[:q_len, :kv_len].to(dtype=scores.dtype, device=scores.device)
            scores = scores + self.beta_scale * beta_bias.unsqueeze(0).unsqueeze(0)

        # softmax in float32 for numerical stability, then back to the working dtype (+ dropout)
        attn = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        attn = self.attention_dropout(attn)

        # output
        out = torch.matmul(attn, v)  # [B, H, T, D]
        out = out.transpose(1, 2).contiguous()  # [B, T, H, D]
        # Merge heads using the tensor's actual dimensions rather than self.hidden_size
        actual_hidden_size = out.size(-2) * out.size(-1)  # H * D
        out = out.view(bsz, q_len, actual_hidden_size)
        out = self.o_proj(out)
        return (out, attn) if output_attentions else (out, None)

