"""FMA attention approximation wrapper for Llama."""

import sys
from pathlib import Path

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import flax.linen as nn
from typing import Optional
from functools import partial

# Make the FMA checkout importable. The directory is resolved relative to this
# file (three directories above the one containing it, then "fma") and is
# prepended to sys.path only when it is not already listed, so re-importing
# this module never adds duplicate entries.
FMA_PATH = Path(__file__).parent.parent.parent.parent / "fma"
if str(FMA_PATH) not in sys.path:
    sys.path.insert(0, str(FMA_PATH))

# FMA attention kernel, bound to the name `fma_attn` for the rest of this module.
# NOTE: this import relies on the sys.path insertion above and must stay below it.
# TODO: Update the import below to switch to a different FMA attention variant, e.g.:
# from fma.pallas_retrieval import retrieval_attention
# from fma.single_level_attention import single_level_attention
from fma.pallas_retrieval import causal_attn as fma_attn


def _llama3_scaled_inv_freq(inv_freq, rope_scaling):
    """Rescale RoPE inverse frequencies with the llama3 three-band rule.

    Bands are defined on the wavelength ``2 * pi / inv_freq``:

    * wavelength < ``original_max_position_embeddings / high_freq_factor``:
      frequency is left unchanged;
    * wavelength > ``original_max_position_embeddings / low_freq_factor``:
      frequency is divided by ``factor``;
    * anything in between (bounds inclusive): linear blend of the two.

    Args:
        inv_freq: fp32 array of inverse frequencies, shape (head_dim // 2,).
        rope_scaling: Dict providing 'factor', 'low_freq_factor',
            'high_freq_factor' and 'original_max_position_embeddings'.

    Returns:
        Array with the same shape as ``inv_freq``.
    """
    factor = rope_scaling['factor']
    low_freq_factor = rope_scaling['low_freq_factor']
    high_freq_factor = rope_scaling['high_freq_factor']
    old_context_len = rope_scaling['original_max_position_embeddings']

    # Wavelength thresholds, see:
    # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelens = 2 * jnp.pi / inv_freq

    # Low-frequency band gets the full division by `factor`; everything else is
    # untouched at this point.
    inv_freq_llama = jnp.where(wavelens > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Medium band: smooth_factor runs from 0 (at the low-frequency edge) to 1
    # (at the high-frequency edge), blending scaled and unscaled frequencies.
    smooth_factor = (old_context_len / wavelens - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed_inv_freq = (1 - smooth_factor) * (inv_freq_llama / factor) + smooth_factor * inv_freq_llama
    is_medium_freq = (wavelens >= high_freq_wavelen) & (wavelens <= low_freq_wavelen)
    return jnp.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)


def _rotate_half(x):
    """Map the last axis ``[x1, x2]`` (two equal halves) to ``[-x2, x1]``."""
    half = x.shape[-1] // 2
    return jnp.concatenate([-x[..., half:], x[..., :half]], axis=-1)


def apply_rotary_pos_emb(q, k, position_ids, rope_theta=500000.0, rope_scaling=None):
    """Apply RoPE (Rotary Position Embeddings) to query and key tensors.

    Angles are computed in fp32; the resulting cos/sin tables are cast to the
    dtype of ``q`` before the rotation is applied to both ``q`` and ``k``.

    Args:
        q: Query tensor of shape (batch, seq_len, num_heads, head_dim)
        k: Key tensor of shape (batch, seq_len, num_kv_heads, head_dim)
        position_ids: Position IDs of shape (batch, seq_len)
        rope_theta: Base for rotary embeddings
        rope_scaling: Optional dict with scaling config. llama3 scaling is
            selected by ``rope_scaling['rope_type'] == 'llama3'``; the legacy
            key ``'type'`` is honoured when ``'rope_type'`` is absent. Any
            other rope type leaves the frequencies unscaled.

    Returns:
        Tuple of (q_rotated, k_rotated)

    Raises:
        KeyError: If llama3 scaling is selected but a required entry
            (e.g. 'original_max_position_embeddings') is missing.
    """
    head_dim = q.shape[-1]
    output_dtype = q.dtype  # cos/sin are cast to this dtype before rotating

    # Base inverse frequencies in fp32 for precision
    dim_indices = jnp.arange(0, head_dim, 2, dtype=jnp.float32)
    inv_freq = 1.0 / (rope_theta ** (dim_indices / head_dim))

    if rope_scaling is not None:
        # Backward compatibility: older configs store the rope type under 'type'.
        rope_type = rope_scaling.get('rope_type', rope_scaling.get('type'))
        if rope_type == 'llama3':
            # Scaling is only enabled when the (static) sequence length exceeds
            # the original context length.
            # NOTE(review): the Hugging Face implementation linked in
            # _llama3_scaled_inv_freq applies llama3 scaling unconditionally;
            # confirm this gate is intended when loading pretrained weights.
            seq_len = position_ids.shape[1]
            if seq_len > rope_scaling['original_max_position_embeddings']:
                inv_freq = _llama3_scaled_inv_freq(inv_freq, rope_scaling)

    # position_ids: (batch, seq_len), inv_freq: (head_dim // 2,)
    # -> freqs: (batch, seq_len, head_dim // 2), computed in fp32
    freqs = jnp.einsum('bi,j->bij', position_ids.astype(jnp.float32), inv_freq)

    # Duplicate the angles so they line up with the two halves used by _rotate_half
    emb = jnp.concatenate([freqs, freqs], axis=-1)  # (batch, seq_len, head_dim)

    # Insert a singleton heads axis so the tables broadcast over q and k heads:
    # (batch, seq_len, 1, head_dim)
    cos = jnp.cos(emb).astype(output_dtype)[:, :, None, :]
    sin = jnp.sin(emb).astype(output_dtype)[:, :, None, :]

    q_embed = (q * cos) + (_rotate_half(q) * sin)
    k_embed = (k * cos) + (_rotate_half(k) * sin)

    return q_embed, k_embed


class FMAAttention(nn.Module):
    """Self-attention layer with Grouped Query Attention (GQA) and optional FMA kernel.

    The layer projects the input to queries/keys/values, applies RoPE, repeats
    the key/value heads to match the query heads, and then runs either the
    external FMA kernel (``config.use_fma_attention`` truthy) or
    ``jax.nn.dot_product_attention`` with the cuDNN implementation, followed
    by an output projection.

    Attributes read from ``config``: hidden_size, num_attention_heads,
    num_key_value_heads, initializer_range, dtype, param_dtype, rope_theta,
    rope_scaling, use_fma_attention and, on the FMA path, fma_num_clusters,
    fma_block_size, fma_num_retrievals, fma_bidiagonal, fma_dipole.
    """

    config: object  # LlamaConfig

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        position_ids: Optional[jnp.ndarray] = None,
        is_causal: bool = True,
    ) -> jnp.ndarray:
        """Apply attention to ``hidden_states``.

        Args:
            hidden_states: Input of shape (batch, seq_len, hidden_size)
            position_ids: Optional position IDs of shape (batch, seq_len);
                defaults to 0..seq_len-1 for every batch row.
            is_causal: Whether to use causal masking (default: True for LM).
                Only the non-FMA path can honour ``False``.

        Returns:
            Output of shape (batch, seq_len, hidden_size)

        Raises:
            NotImplementedError: If ``is_causal`` is False while
                ``config.use_fma_attention`` is enabled; the FMA kernel is
                called without any masking argument, so the request could
                otherwise be silently ignored.
            ValueError: If ``num_attention_heads`` is not a multiple of
                ``num_key_value_heads``.
        """
        cfg = self.config
        batch_size, seq_len, _ = hidden_states.shape

        num_q_heads = cfg.num_attention_heads
        num_kv_heads = cfg.num_key_value_heads
        head_dim = cfg.hidden_size // num_q_heads

        if cfg.use_fma_attention and not is_causal:
            raise NotImplementedError(
                "FMA attention (fma.pallas_retrieval.causal_attn) is causal-only; "
                "got is_causal=False with config.use_fma_attention enabled."
            )
        if num_q_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_q_heads}) must be a multiple of "
                f"num_key_value_heads ({num_kv_heads}) for grouped query attention."
            )

        def make_proj(features, kernel_spec):
            """Build a bias-free Dense layer whose kernel is partitioned by kernel_spec."""
            return nn.Dense(
                features,
                use_bias=False,
                kernel_init=nn.with_partitioning(
                    nn.initializers.normal(cfg.initializer_range),
                    kernel_spec,
                ),
                dtype=cfg.dtype,
                param_dtype=cfg.param_dtype,
            )

        # Flax auto-names compact submodules by creation order, so the q, k, v, o
        # order below must be kept to remain compatible with existing parameters.
        # Q/K/V kernels are split on the output axis, the output projection on
        # the input axis, both along the "model" mesh axis.
        q_proj = make_proj(num_q_heads * head_dim, P(None, "model"))
        k_proj = make_proj(num_kv_heads * head_dim, P(None, "model"))
        v_proj = make_proj(num_kv_heads * head_dim, P(None, "model"))
        o_proj = make_proj(cfg.hidden_size, P("model", None))

        query_states = q_proj(hidden_states)
        key_states = k_proj(hidden_states)
        value_states = v_proj(hidden_states)

        heads_spec = P("data", None, "model", None)

        def split_heads(x, num_heads):
            """Reshape to (batch, seq_len, num_heads, head_dim) and pin the sharding."""
            x = x.reshape(batch_size, seq_len, num_heads, head_dim)
            return jax.lax.with_sharding_constraint(x, heads_spec)

        query_states = split_heads(query_states, num_q_heads)
        key_states = split_heads(key_states, num_kv_heads)
        value_states = split_heads(value_states, num_kv_heads)

        # Apply RoPE
        if position_ids is None:
            position_ids = jnp.arange(seq_len)[None, :].repeat(batch_size, axis=0)

        query_states, key_states = apply_rotary_pos_emb(
            query_states,
            key_states,
            position_ids,
            rope_theta=cfg.rope_theta,
            rope_scaling=cfg.rope_scaling,
        )

        # Handle GQA: repeat key/value heads to match query heads
        if num_kv_heads != num_q_heads:
            num_repeats = num_q_heads // num_kv_heads
            key_states = jnp.repeat(key_states, num_repeats, axis=2)
            value_states = jnp.repeat(value_states, num_repeats, axis=2)

        # query/key/value all have shape (batch, seq_len, num_heads, head_dim),
        # the layout jax.nn.dot_product_attention takes, so no transpose is needed.
        if cfg.use_fma_attention:
            # The leading positional arguments are bound in the order the kernel
            # is called with in this file; fma_attn's signature lives in the
            # external FMA package — TODO confirm against that package.
            fma_kernel = partial(
                fma_attn,
                cfg.fma_num_clusters,
                cfg.fma_num_clusters,
                cfg.fma_block_size,
                cfg.fma_num_retrievals,
                cfg.fma_bidiagonal,
                cfg.fma_dipole,
            )
            # Inner vmap maps over heads (axis 1 once the batch axis is stripped),
            # outer vmap maps over the batch axis.
            per_head = jax.vmap(fma_kernel, in_axes=1, out_axes=1)
            batched = jax.jit(jax.vmap(per_head))
            sharded = jax.shard_map(
                batched,
                in_specs=P("data", None, "model", None),
                out_specs=(P("data", None, "model"), P("data", None, "model", None)),
                check_vma=False,
            )
            # The kernel returns a pair; only the second element is used here.
            _, attn_output = sharded(query_states, key_states, value_states)
        else:
            # Baseline: JAX attention through the cuDNN (flash attention) backend.
            attn_output = jax.nn.dot_product_attention(
                query_states,
                key_states,
                value_states,
                is_causal=is_causal,
                implementation='cudnn',
            )

        # attn_output shape: (batch, seq_len, num_heads, head_dim)
        attn_output = attn_output.reshape(batch_size, seq_len, cfg.hidden_size)
        attn_output = jax.lax.with_sharding_constraint(attn_output, P("data", None, "model"))

        # Output projection
        return o_proj(attn_output)
