"""Configuration for Llama 3.2 1B model."""

from dataclasses import dataclass
from typing import Optional
import jax.numpy as jnp


@dataclass
class LlamaConfig:
    """Configuration class for Llama 3.2 1B model.

    Based on the official Llama 3.2 1B architecture:
    - https://huggingface.co/meta-llama/Llama-3.2-1B

    The field defaults describe Llama 3.2 1B. Use ``from_pretrained`` (or
    ``from_hf_config`` with an already-loaded HuggingFace config object) to
    populate the architecture fields from a checkpoint. The FMA and precision
    options are specific to this project and can be supplied as keyword
    overrides to either constructor.
    """

    # Model architecture
    vocab_size: int = 128256  # Llama 3.2 vocabulary size
    hidden_size: int = 2048
    intermediate_size: int = 8192
    num_hidden_layers: int = 16
    num_attention_heads: int = 32
    num_key_value_heads: int = 8  # GQA (Grouped Query Attention)
    max_position_embeddings: int = 131072  # 128K context length

    # RoPE configuration
    rope_theta: float = 500000.0
    rope_scaling: Optional[dict] = None  # Will be populated from HF config

    # Normalization
    rms_norm_eps: float = 1e-5

    # Dropout (typically 0 for inference)
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.0

    # Activation function
    hidden_act: str = "silu"

    # FMA-specific configuration
    use_fma_attention: bool = True
    # Block size for FMA approximation: 2**13 == 8192.
    # NOTE: written as ``**`` on purpose; ``2^13`` is bitwise XOR in Python (== 15).
    fma_block_size: int = 2**13
    fma_num_clusters: int = 128  # Number of clusters for approximation
    fma_num_retrievals: int = 8  # Number of retrievals for FMA approximation
    fma_bidiagonal: bool = False  # Use bidiagonal approximation for FMA
    fma_dipole: bool = False  # Use dipole approximation for FMA

    # Tie word embeddings
    tie_word_embeddings: bool = True

    # Initialization
    initializer_range: float = 0.02

    # Precision configuration
    dtype: object = jnp.bfloat16  # Computation dtype (activations)
    param_dtype: object = jnp.float32  # Parameter dtype (weights)

    @classmethod
    def from_pretrained(
        cls, model_name: str = "meta-llama/Llama-3.2-1B", **overrides: object
    ) -> "LlamaConfig":
        """Load config from pretrained model on HuggingFace.

        Args:
            model_name: Hub model id or local path, passed to
                ``transformers.AutoConfig.from_pretrained``.
            **overrides: Field values that take precedence over anything read
                from the HF config (e.g. ``fma_block_size=4096`` or
                ``dtype=jnp.float32``).

        Returns:
            A new ``LlamaConfig`` built by ``from_hf_config``.
        """
        # Imported lazily so ``transformers`` is only required when loading
        # from the Hub, not when constructing a config directly.
        from transformers import AutoConfig

        hf_config = AutoConfig.from_pretrained(model_name)
        return cls.from_hf_config(hf_config, **overrides)

    @classmethod
    def from_hf_config(cls, hf_config: object, **overrides: object) -> "LlamaConfig":
        """Build a config from an already-loaded HuggingFace config object.

        Args:
            hf_config: Object exposing the HF Llama config attributes read
                below (e.g. a ``transformers`` ``LlamaConfig``).
            **overrides: Field values that take precedence over the values
                read from ``hf_config``.

        Returns:
            A new ``LlamaConfig``.

        Raises:
            AttributeError: If ``hf_config`` lacks a required architecture
                attribute.
            TypeError: If an override names a field this class does not have.
        """
        # Architecture attributes every HF Llama config is expected to expose.
        required = (
            "vocab_size",
            "hidden_size",
            "intermediate_size",
            "num_hidden_layers",
            "num_attention_heads",
            "num_key_value_heads",
            "max_position_embeddings",
            "rope_theta",
            "rope_scaling",
            "rms_norm_eps",
            "hidden_act",
        )
        # Attributes copied only when present (and not None); otherwise the
        # dataclass defaults above apply. Reading tie_word_embeddings matters:
        # checkpoints with untied embeddings must not inherit the default True.
        optional = ("tie_word_embeddings", "attention_dropout", "initializer_range")

        kwargs = {name: getattr(hf_config, name) for name in required}
        for name in optional:
            value = getattr(hf_config, name, None)
            if value is not None:
                kwargs[name] = value

        # Shallow-copy so mutating this config never aliases the HF config's dict.
        if isinstance(kwargs["rope_scaling"], dict):
            kwargs["rope_scaling"] = dict(kwargs["rope_scaling"])

        kwargs.update(overrides)
        return cls(**kwargs)
