"""Tests for model configuration."""

import pytest
from fma_llama.model.config import LlamaConfig


def test_default_config():
    """Verify that a bare ``LlamaConfig()`` exposes the expected default values."""
    cfg = LlamaConfig()

    # Integer-valued defaults, checked in a fixed order so the first
    # mismatch is the one reported.
    expected_int_defaults = {
        "vocab_size": 128256,
        "hidden_size": 2048,
        "num_hidden_layers": 16,
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
    }
    for field_name, expected in expected_int_defaults.items():
        assert getattr(cfg, field_name) == expected

    # Identity check on purpose: the flag must be the bool ``True``,
    # not merely a truthy value such as ``1``.
    assert cfg.use_fma_attention is True


def test_custom_config():
    """Verify that keyword overrides passed to ``LlamaConfig`` are stored as given."""
    overrides = {
        "hidden_size": 1024,
        "num_hidden_layers": 8,
        "use_fma_attention": False,
    }
    cfg = LlamaConfig(**overrides)

    assert cfg.hidden_size == overrides["hidden_size"]
    assert cfg.num_hidden_layers == overrides["num_hidden_layers"]
    # Identity check: the override must round-trip as the bool ``False``.
    assert cfg.use_fma_attention is False


def test_fma_config():
    """Verify that the FMA-specific fields accept and retain custom values."""
    block_size = 256
    num_clusters = 64

    cfg = LlamaConfig(fma_block_size=block_size, fma_num_clusters=num_clusters)

    # Kept as two separate asserts so a failure pinpoints the offending field.
    assert cfg.fma_block_size == block_size
    assert cfg.fma_num_clusters == num_clusters
