from __future__ import annotations

import math
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaMLP,
)


# ---------------------------------------------------------------------------
# Helper: build an arbitrary relation matrix R from a predicate
# ---------------------------------------------------------------------------

def build_relation_matrix(
    max_positions: int,
    predicate: Callable[[int, int], bool],
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Build a matrix R ∈ {0,1}^{max_positions × max_positions} such that

        R[i-1, j-1] = 1  ⇔  predicate(i, j) is truthy,

    where i, j are 1-based positions in {1,…,max_positions}.

    This gives you the general "R ⊆ ℕ × ℕ" from the paper.
    You can encode local windows, periodic patterns, etc.

    The predicate is still evaluated in Python for all P² pairs (row-major
    order), but the tensor is filled one row at a time on CPU and moved to
    ``device`` once at the end. This avoids P² separate element writes; on a
    GPU each of those would be its own host→device transfer.

    Args:
        max_positions: P, the side length of the square matrix.
        predicate: called as ``predicate(i, j)`` with 1-based positions.
        device: target device; if None the result stays on CPU.
        dtype: dtype of the returned matrix (1 / 0 are cast to it).

    Returns:
        A [P, P] tensor with 1 where the predicate holds and 0 elsewhere.

    Example:
        # local window of size W around j
        def local_R(i, j, W=3):
            return abs(i - j) <= W

        R = build_relation_matrix(512, lambda i, j: local_R(i, j, W=3))
    """
    R = torch.zeros((max_positions, max_positions), dtype=dtype, device="cpu")
    positions = range(1, max_positions + 1)
    for i in positions:
        # One tensor op per row instead of one per element.
        row = [1.0 if predicate(i, j) else 0.0 for j in positions]
        R[i - 1] = torch.tensor(row, dtype=dtype)
    return R if device is None else R.to(device)


# ---------------------------------------------------------------------------
# LLaMA Attention with RPE as in Eq. (2)
# ---------------------------------------------------------------------------

class LlamaAttentionWithRPE(LlamaAttention):
    r"""
    LLaMA self-attention with generic Relative Positional Encoding:

        scores_ij = log n · (⟨q_i, k_j⟩ / √d + λ [[R]](i, j)) + mask_ij

    where:
      * n is the context length (kv_len),
      * R is a binary matrix [max_positions, max_positions], indexed by
        0-based (query position, key position),
      * λ is a learnable scalar.

    Compared to vanilla LLaMA:
      * We keep the 1/√d scaling.
      * We multiply the whole (dot + λR) by log n, following the paper.
      * We apply the attention mask (causal / padding) separately, after the
        log n scaling.
      * No rotary embedding is applied to q/k in this forward: ``position_ids``
        and any ``position_embeddings`` kwarg are ignored, so positional
        information enters only through λR.

    Args:
        config: LLaMA config, forwarded to ``LlamaAttention``.
        relation: square [P, P] tensor; registered as a float32 buffer.
        layer_idx: forwarded to ``LlamaAttention``.
        lambda_init: initial value of λ.
        use_rpe: if False, the λR term is skipped.
        use_log_n_scaling: if False, the log n factor is skipped.

    Raises:
        ValueError: if ``relation`` is not a square 2D tensor.
    """

    def __init__(
        self,
        config,
        relation: torch.Tensor,
        layer_idx: int | None = None,
        lambda_init: float = 1.0,
        use_rpe: bool = True,
        use_log_n_scaling: bool = True,
    ):
        super().__init__(config, layer_idx=layer_idx)

        self.use_rpe = bool(use_rpe)
        self.use_log_n_scaling = bool(use_log_n_scaling)

        # R: [P, P] relation matrix. As a registered buffer it follows the
        # module through .to(...) and is saved in the state dict.
        if relation.dim() != 2 or relation.size(0) != relation.size(1):
            raise ValueError(
                f"relation must be a square 2D tensor; got shape {tuple(relation.shape)}"
            )
        self.max_positions = relation.size(0)
        self.register_buffer("relation", relation.to(torch.float32), persistent=True)

        # λ is a learnable scalar bias
        self.rpe_lambda = nn.Parameter(torch.tensor(float(lambda_init)))

        # Ensure these attributes exist; some transformers versions do not
        # define all of them on LlamaAttention.
        if not hasattr(self, "num_heads"):
            self.num_heads = config.num_attention_heads
        if not hasattr(self, "head_dim"):
            self.head_dim = getattr(config, "head_dim", None) or (
                config.hidden_size // config.num_attention_heads
            )
        if not hasattr(self, "num_key_value_heads"):
            self.num_key_value_heads = getattr(config, "num_key_value_heads", self.num_heads)
        if not hasattr(self, "num_key_value_groups"):
            self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # attention_dropout may be stored as a float probability; normalize it
        # to an nn.Dropout so it honours train()/eval().
        if not isinstance(getattr(self, "attention_dropout", None), nn.Module):
            p = float(getattr(config, "attention_dropout", 0.0))
            self.attention_dropout = nn.Dropout(p)

    def _rpe_bias(self, q_len: int, kv_len: int, like: torch.Tensor) -> torch.Tensor:
        """Return R[:q_len, :kv_len] as [1, 1, q_len, kv_len] on ``like``'s device/dtype.

        Raises:
            ValueError: if either length exceeds the relation size.
        """
        if q_len > self.max_positions or kv_len > self.max_positions:
            raise ValueError(
                f"Sequence length {max(q_len, kv_len)} exceeds relation size {self.max_positions}. "
                "Either increase max_positions or build a larger relation."
            )
        bias = self.relation[:q_len, :kv_len].to(device=like.device, dtype=like.dtype)
        return bias[None, None]

    def _apply_attention_mask(
        self, scores: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Mask ``scores`` [B, H, T_q, T_k] with padding and causal information.

        Depending on the transformers version and ``attn_implementation``, the
        mask handed to an attention module may be None (causality left to the
        kernel), a 2D [B, T] padding mask (nonzero = keep), a boolean 4D mask
        (True = keep) or an additive float 4D mask. All four are handled. When
        the mask does not itself encode causality (None or 2D), a causal mask
        is applied here unless ``self.is_causal`` is False, so this eager path
        never attends to future tokens.
        """
        q_len, kv_len = scores.shape[-2], scores.shape[-1]
        min_value = torch.finfo(scores.dtype).min
        needs_causal = attention_mask is None or attention_mask.dim() == 2

        if attention_mask is not None:
            # Some callers pass masks sized for a longer (cache) length.
            mask = attention_mask[..., :kv_len]
            if mask.dim() == 2:
                keep = mask[:, None, None, :].bool()
                scores = scores.masked_fill(~keep, min_value)
            elif mask.dtype == torch.bool:
                scores = scores.masked_fill(~mask, min_value)
            else:
                scores = scores + mask  # additive mask, broadcast over heads

        if needs_causal and getattr(self, "is_causal", True) and q_len > 1:
            future = torch.ones(
                q_len, kv_len, dtype=torch.bool, device=scores.device
            ).triu(diagonal=kv_len - q_len + 1)
            scores = scores.masked_fill(future, min_value)
        return scores

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Compute RPE self-attention over the full sequence in ``hidden_states``.

        Arguments follow HuggingFace's LlamaAttention, but past_key_value /
        use_cache are ignored: keys and values always come from
        ``hidden_states`` alone. This is meant for offline training on full
        sequences, not cached incremental decoding.

        Returns:
            ``(out, attn_weights, past_key_value)`` if ``output_attentions``,
            else ``(out, None)``. ``out`` is [B, T, H*D] after ``o_proj``;
            ``attn_weights`` is the post-dropout [B, H, T, T] matrix;
            ``past_key_value`` is passed through, or set to None when
            ``use_cache`` is True.
        """
        bsz, q_len, _ = hidden_states.size()

        # Projections, then [B, T, H*D] -> [B, H, T, D]
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Grouped / multi-query attention
        if self.num_key_value_groups != 1:
            k = torch.repeat_interleave(k, self.num_key_value_groups, dim=1)
            v = torch.repeat_interleave(v, self.num_key_value_groups, dim=1)

        # Keys come from hidden_states only, so kv_len == q_len (no cache).
        kv_len = k.size(-2)

        # Dot-product logits with standard 1/√d scaling: [B, H, T, T]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Add λ·[[R]](i,j) term, if enabled
        if self.use_rpe:
            scores = scores + self.rpe_lambda * self._rpe_bias(q_len, kv_len, scores)

        # Multiply everything by log n as in the paper
        if self.use_log_n_scaling and kv_len > 0:
            scores = scores * math.log(float(kv_len))

        scores = self._apply_attention_mask(scores, attention_mask)

        # Softmax (in fp32 for stability) + dropout
        attn = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        attn = self.attention_dropout(attn)

        # [B, H, T, D] -> [B, T, H*D]
        out = torch.matmul(attn, v).transpose(1, 2).reshape(bsz, q_len, -1)
        out = self.o_proj(out)

        if use_cache:
            past_key_value = None  # caching not implemented

        if output_attentions:
            return out, attn, past_key_value
        return out, None


# ---------------------------------------------------------------------------
# Utility: patch a LLaMA model with generic RPE attention
# ---------------------------------------------------------------------------

def replace_llama_attn_with_rpe(
    model: nn.Module,
    relation: torch.Tensor,
    lambda_init: float = 1.0,
    use_rpe: bool = True,
    use_log_n_scaling: bool = True,
):
    """
    In-place replace all LlamaAttention modules in a LLaMA model with
    LlamaAttentionWithRPE using a fixed relation matrix R.

    Each replacement:
      * is placed on the device of the module it replaces;
      * is cast to that module's floating-point dtype, so fp16/bf16 models
        keep working;
      * copies the Q/K/V/O projection weights;
      * keeps the original train/eval mode.

    Args:
        model: a LlamaForCausalLM or LlamaModel instance.
        relation: [P, P] tensor with entries 0/1 (or floats in {0,1}).
        lambda_init: initial value for λ.
        use_rpe: if False, relation is ignored (reduces to log n·⟨q,k⟩).
        use_log_n_scaling: if False, skip the log n factor.

    Raises:
        ValueError: if ``relation`` is not a square 2D tensor.
    """

    if relation.dim() != 2 or relation.size(0) != relation.size(1):
        raise ValueError(
            f"relation must be a square 2D tensor; got shape {tuple(relation.shape)}"
        )

    def _replace(mod: nn.Module, layer_idx: int = 0) -> int:
        """Recursively swap attention modules; return the next layer counter."""
        for child_name, child in list(mod.named_children()):
            if not isinstance(child, LlamaAttention):
                layer_idx = _replace(child, layer_idx)
                continue

            # Reference parameter for device/dtype (or model fallback)
            try:
                ref = next(child.parameters())
            except StopIteration:
                ref = next(model.parameters())

            # Prefer the layer index already assigned to the module.
            child_idx = getattr(child, "layer_idx", None)

            new_attn = LlamaAttentionWithRPE(
                config=child.config,
                relation=relation.to(ref.device),
                layer_idx=layer_idx if child_idx is None else child_idx,
                lambda_init=lambda_init,
                use_rpe=use_rpe,
                use_log_n_scaling=use_log_n_scaling,
            )
            # Match dtype too: otherwise load_state_dict copies half-precision
            # weights into fp32 params, and the forward fails on a dtype mismatch.
            to_kwargs = {"device": ref.device}
            if ref.dtype.is_floating_point:
                to_kwargs["dtype"] = ref.dtype
            new_attn = new_attn.to(**to_kwargs)

            # copy Q,K,V,O projections
            for proj in ("q_proj", "k_proj", "v_proj", "o_proj"):
                getattr(new_attn, proj).load_state_dict(getattr(child, proj).state_dict())

            # New modules start in train mode; keep eval mode (no dropout).
            new_attn.train(child.training)

            setattr(mod, child_name, new_attn)
            layer_idx += 1
        return layer_idx

    _replace(model)


# ---------------------------------------------------------------------------
# Utility: replace all LlamaMLP blocks to use ReLU activation
# ---------------------------------------------------------------------------

class LlamaMLPWithReLU(LlamaMLP):
    """
    LlamaMLP variant whose activation function is ReLU.

    Everything except ``act_fn`` (projection layers, sizes, forward) is
    inherited from HuggingFace's LlamaMLP. Only the activation callable that
    the parent constructor installed (typically SiLU) is overwritten.

    NOTE(review): if LlamaMLP.forward keeps its gated form, the result is a
    ReLU-gated unit rather than a plain ReLU MLP; confirm that is intended.
    """

    def __init__(self, config):
        super().__init__(config)
        relu_activation = nn.ReLU()
        # Overwrite the activation set up by LlamaMLP.__init__.
        self.act_fn = relu_activation


def replace_llama_mlp_activation_with_relu(model: nn.Module):
    """
    In-place replacement of every LlamaMLP in `model` by LlamaMLPWithReLU.

    This changes only the feedforward activation; attention remains unchanged.
    Each new MLP is built from the old module's config and receives its
    weights. It is placed on the old module's device and cast to its
    floating-point dtype, so fp16/bf16 models keep working, and it keeps the
    old module's train/eval mode.

    Raises:
        RuntimeError: if any parameter of the new MLP could not be filled from
            the old module's state dict. This would otherwise silently leave
            randomly initialised weights in the model.
    """

    def _replace_mlp(module: nn.Module):
        """Recursively swap LlamaMLP children of ``module``."""
        for name, child in list(module.named_children()):
            # Check by type rather than name to be robust.
            if not isinstance(child, LlamaMLP):
                _replace_mlp(child)
                continue

            try:
                ref = next(child.parameters())
            except StopIteration:
                ref = next(model.parameters())

            # Create new MLP with ReLU, matching device and (float) dtype.
            to_kwargs = {"device": ref.device}
            if ref.dtype.is_floating_point:
                to_kwargs["dtype"] = ref.dtype
            new_mlp = LlamaMLPWithReLU(child.config).to(**to_kwargs)

            # strict=False tolerates extra keys in the old state dict, but a
            # missing key means a weight was not transferred, so reject it.
            result = new_mlp.load_state_dict(child.state_dict(), strict=False)
            if result.missing_keys:
                raise RuntimeError(
                    f"Could not transfer weights into LlamaMLPWithReLU for "
                    f"'{name}': missing keys {result.missing_keys}"
                )

            # New modules start in train mode; keep the original mode.
            new_mlp.train(child.training)

            setattr(module, name, new_mlp)

    _replace_mlp(model)
