"""
Model presets used across experiments.

Provides a small registry that maps short names to Hugging Face model IDs
and, optionally, known layer counts. When a layer count is not provided, it
is auto-detected after the model is loaded.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Dict


@dataclass
class ModelPreset:
    """A named shortcut for a Hugging Face model.

    Instances are plain mutable records; equality compares all fields.
    """

    # Short registry name, e.g. "qwen3-1.7b".
    name: str
    # Hugging Face Hub repository ID passed to the loader.
    hf_id: str
    # Known transformer layer count; None means "not recorded here" and the
    # count is expected to be auto-detected after the model is loaded.
    num_layers: Optional[int] = None


# Registry of model presets, keyed by each preset's short name. The mapping is
# built from the presets themselves so a key can never disagree with the
# preset's ``name`` field. Presets that leave ``num_layers`` unset rely on the
# layer count being auto-detected after loading.
PRESETS: Dict[str, ModelPreset] = {
    preset.name: preset
    for preset in (
        # LLaMA 3.1 8B
        ModelPreset(name="llama3.1", hf_id="meta-llama/Meta-Llama-3.1-8B"),
        # LLaMA 3.2 3B
        ModelPreset(name="llama3.2", hf_id="meta-llama/Llama-3.2-3B"),
        # Qwen 2.5 3B
        ModelPreset(name="qwen2.5-3b", hf_id="Qwen/Qwen2.5-3B"),
        # Qwen 2.5 7B -- the only preset pointing at an "-Instruct" checkpoint.
        ModelPreset(name="qwen2.5-7b", hf_id="Qwen/Qwen2.5-7B-Instruct"),
        # Qwen3 1.7B (for motivating example); its 28 layers are recorded
        # up front instead of being auto-detected.
        ModelPreset(
            name="qwen3-1.7b",
            hf_id="Qwen/Qwen3-1.7B",
            num_layers=28,
        ),
    )
}


# Alternate spellings accepted by ``get_preset``, mapped to canonical
# ``PRESETS`` keys. Alias keys must already be in normalized form (lowercase,
# no spaces) because they are matched after ``get_preset`` normalizes input.
_PRESET_ALIASES: Dict[str, str] = {
    "qwen2.5": "qwen2.5-3b",
    "qwen-2.5-3b": "qwen2.5-3b",
    "llama3.2-3b": "llama3.2",
    "llama3.1-8b": "llama3.1",
}


def get_preset(name: str) -> ModelPreset:
    """Look up a model preset by its short name or an accepted alias.

    The name is normalized before lookup: surrounding whitespace is stripped,
    it is lowercased, and interior spaces are removed. For example,
    ``" Qwen 2.5-3B "`` resolves to the ``"qwen2.5-3b"`` preset.

    Args:
        name: A key of ``PRESETS`` or of ``_PRESET_ALIASES``, in any case and
            with optional spaces.

    Returns:
        The matching ``ModelPreset`` object stored in ``PRESETS`` (the shared
        registry entry itself, not a copy).

    Raises:
        KeyError: If ``name`` matches neither a preset nor an alias whose
            target exists in ``PRESETS``. The message lists accepted names.
    """
    key = name.strip().lower().replace(" ", "")
    if key in PRESETS:
        return PRESETS[key]

    target = _PRESET_ALIASES.get(key)
    # Ignore aliases whose target preset is no longer registered.
    if target is not None and target in PRESETS:
        return PRESETS[target]

    usable_aliases = {
        alias for alias, dest in _PRESET_ALIASES.items() if dest in PRESETS
    }
    known = ", ".join(sorted(set(PRESETS) | usable_aliases))
    raise KeyError(f"Unknown model preset: {name}. Known names: {known}")
