"""Llama 3.2 1B model implementation in Flax."""

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import flax.linen as nn
from typing import Optional, Tuple

from fma_llama.model.config import LlamaConfig
from fma_llama.attention.fma_attention import FMAAttention


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization over the last axis.

    Computes ``weight * x / sqrt(mean(x**2, axis=-1) + eps)`` with a learned
    per-feature ``weight`` initialised to ones.

    The mean-of-squares reduction and the rsqrt are evaluated in at least
    float32, even when ``x`` is bfloat16/float16. The normalized value is cast
    back to ``x.dtype`` before the weight multiply, so the result dtype follows
    the usual promotion of ``x.dtype`` with the weight dtype.

    Attributes:
        eps: Constant added to the mean square before the rsqrt, for stability.
    """

    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        weight = self.param('weight', nn.initializers.ones, (x.shape[-1],))
        # Squaring and averaging in a 16-bit dtype loses precision (and float16
        # can overflow) for large activations, so do the statistics in >= fp32.
        compute_dtype = jnp.promote_types(x.dtype, jnp.float32)
        x_wide = x.astype(compute_dtype)
        variance = jnp.mean(jnp.square(x_wide), axis=-1, keepdims=True)
        normed = (x_wide * jax.lax.rsqrt(variance + self.eps)).astype(x.dtype)
        # Standard dtype promotion here (e.g. bf16 input with fp32 weight -> fp32).
        return weight * normed


class LlamaMLP(nn.Module):
    """SwiGLU feed-forward block: ``down(silu(gate(x)) * up(x))``.

    All three projections are bias-free ``nn.Dense`` layers with normal-initialised,
    partitioned kernels. The gate/up kernels are split along their output
    features on the 'model' mesh axis, and the down kernel along its input
    features.
    """

    config: LlamaConfig

    @nn.compact
    def __call__(self, x):
        cfg = self.config

        def make_projection(out_features, kernel_spec):
            # Bias-free dense layer whose kernel carries `kernel_spec` sharding.
            return nn.Dense(
                out_features,
                use_bias=False,
                kernel_init=nn.with_partitioning(
                    nn.initializers.normal(cfg.initializer_range),
                    kernel_spec,
                ),
                dtype=cfg.dtype,
                param_dtype=cfg.param_dtype,
            )

        # Construction order determines the auto-generated submodule names
        # (Dense_0 = gate, Dense_1 = up, Dense_2 = down); keep it stable.
        gate_layer = make_projection(cfg.intermediate_size, P(None, 'model'))
        up_layer = make_projection(cfg.intermediate_size, P(None, 'model'))
        down_layer = make_projection(cfg.hidden_size, P('model', None))

        gate_out = gate_layer(x)
        up_out = up_layer(x)
        swiglu = nn.silu(gate_out) * up_out
        # Assumes a rank-3 (batch, seq, intermediate) activation — TODO confirm
        # callers never pass other ranks; the spec below fixes rank 3.
        swiglu = jax.lax.with_sharding_constraint(swiglu, P('data', None, 'model'))
        return down_layer(swiglu)


class LlamaDecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer.

    Applies two residual sub-blocks, each of the form ``x + f(RMSNorm(x))``:
    first ``FMAAttention``, then ``LlamaMLP``. Each sub-block has its own
    RMSNorm. The residual stream (and the MLP output) is constrained to the
    sharding ``P('data', 'model', None)``.
    """

    config: LlamaConfig

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        position_ids: Optional[jnp.ndarray] = None,
        is_causal: bool = True,
    ) -> jnp.ndarray:
        norm_eps = self.config.rms_norm_eps
        stream_spec = P('data', 'model', None)

        def pin(t):
            # Re-assert the residual-stream layout at each boundary.
            return jax.lax.with_sharding_constraint(t, stream_spec)

        stream = pin(hidden_states)

        # Attention sub-block. Submodules are created in the order
        # RMSNorm_0, FMAAttention_0, RMSNorm_1, LlamaMLP_0 so that parameter
        # names stay stable.
        attn_input = RMSNorm(eps=norm_eps)(stream)
        attn_result = FMAAttention(self.config)(
            attn_input,
            position_ids=position_ids,
            is_causal=is_causal,
        )
        stream = pin(stream + attn_result)

        # Feed-forward sub-block.
        mlp_input = RMSNorm(eps=norm_eps)(stream)
        mlp_result = pin(LlamaMLP(self.config)(mlp_input))
        return pin(stream + mlp_result)


class LlamaModel(nn.Module):
    """Llama 3.2 1B decoder stack with FMA attention (no LM head).

    Pipeline: token embedding, then ``config.num_hidden_layers`` decoder layers,
    then a final RMSNorm.
    """

    config: LlamaConfig

    @nn.compact
    def __call__(
        self,
        input_ids: jnp.ndarray,
        position_ids: Optional[jnp.ndarray] = None,
        is_causal: bool = True,
    ) -> jnp.ndarray:
        """Run the decoder stack.

        Args:
            input_ids: Input token IDs of shape (batch_size, seq_len)
            position_ids: Optional position IDs of shape (batch_size, seq_len)
            is_causal: Whether to use causal masking (default: True for LM)

        Returns:
            hidden_states: Final hidden states of shape (batch_size, seq_len, hidden_size)
        """
        cfg = self.config

        # Embedding table of shape (vocab_size, hidden_size). It is sharded
        # along hidden_size on the 'model' axis, stored in param_dtype, and
        # looked up in cfg.dtype.
        token_embedding = nn.Embed(
            num_embeddings=cfg.vocab_size,
            features=cfg.hidden_size,
            embedding_init=nn.with_partitioning(
                nn.initializers.normal(cfg.initializer_range),
                P(None, 'model'),
            ),
            dtype=cfg.dtype,
            param_dtype=cfg.param_dtype,
        )
        activations = token_embedding(input_ids)

        # Layers are constructed in index order, giving the auto-names
        # LlamaDecoderLayer_0 .. LlamaDecoderLayer_{N-1}.
        decoder_layers = [
            LlamaDecoderLayer(cfg) for _ in range(cfg.num_hidden_layers)
        ]
        for layer in decoder_layers:
            activations = layer(
                activations,
                position_ids=position_ids,
                is_causal=is_causal,
            )

        final_norm = RMSNorm(eps=cfg.rms_norm_eps)
        return final_norm(activations)


class LlamaForCausalLM(nn.Module):
    """Llama model with a language-modeling head.

    If ``config.tie_word_embeddings`` is true, logits are computed by projecting
    onto the token-embedding table owned by the inner ``LlamaModel``, so no
    separate LM-head parameter is created. Otherwise a dedicated bias-free
    ``nn.Dense`` head is used. Logits are computed in float32.
    """

    config: LlamaConfig

    @nn.compact
    def __call__(
        self,
        input_ids: jnp.ndarray,
        position_ids: Optional[jnp.ndarray] = None,
        is_causal: bool = True,
    ) -> jnp.ndarray:
        """Forward pass with LM head.

        Args:
            input_ids: Input token IDs of shape (batch_size, seq_len)
            position_ids: Optional position IDs of shape (batch_size, seq_len)
            is_causal: Whether to use causal masking (default: True for LM)

        Returns:
            logits: Token logits of shape (batch_size, seq_len, vocab_size)
        """
        # Keep a handle on the (auto-named 'LlamaModel_0') submodule so the
        # tied head can read its embedding table after it has been called.
        backbone = LlamaModel(self.config)
        hidden_states = backbone(
            input_ids,
            position_ids=position_ids,
            is_causal=is_causal,
        )

        # Cast to fp32 so the LM head runs in full precision.
        hidden_states = hidden_states.astype(jnp.float32)

        if self.config.tie_word_embeddings:
            # Previously this branch built an independent, randomly initialised
            # Dense head, so the weights were never actually tied. Instead,
            # reuse the nn.Embed table created first inside LlamaModel
            # ('Embed_0'), of shape (vocab_size, hidden_size). nn.unbox strips
            # the Partitioned box added by nn.with_partitioning (a no-op for
            # plain arrays). This mirrors nn.Embed.attend: x @ embedding.T.
            embedding = nn.unbox(
                backbone.variables['params']['Embed_0']['embedding']
            )
            logits = jnp.dot(hidden_states, embedding.astype(jnp.float32).T)
        else:
            lm_head = nn.Dense(
                self.config.vocab_size,
                use_bias=False,
                kernel_init=nn.with_partitioning(
                    nn.initializers.normal(self.config.initializer_range),
                    P(None, 'model'),
                ),
                dtype=jnp.float32,
                param_dtype=self.config.param_dtype,
            )
            logits = lm_head(hidden_states)

        logits = jax.lax.with_sharding_constraint(
            logits,
            P('data', 'model', None),
        )
        return logits
