from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    """Architecture hyper-parameters for one transformer model.

    Fields whose default is a sentinel (``None`` or ``-1``) are filled in by
    ``__post_init__`` from the other fields. Named presets are stored in the
    module-level ``transformer_configs`` mapping and loaded via ``from_name``.
    """

    block_size: int = 2048  # presumably the max sequence length — TODO confirm
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    # None -> derived from ``dim`` in __post_init__.
    intermediate_size: Optional[int] = None
    # -1 -> defaults to ``n_head``. Presumably the number of key/value heads
    # (presets set it below n_head) — TODO confirm against the attention code.
    n_local_heads: int = -1
    # -1 -> defaults to ``dim // n_head``.
    head_dim: int = -1
    rope_base: float = 10000
    norm_eps: float = 1e-5
    scaling_factor: float = 1.0
    # Llama 3.1 RoPE scaling parameters; only the Llama-3.1-based presets set these.
    low_freq_factor: Optional[float] = None
    high_freq_factor: Optional[float] = None
    original_max_position_embeddings: Optional[int] = None
    qkv_bias: bool = False  # For Qwen2.5 (incl. the Qwen-based DeepSeek-R1 distills)
    qk_norm: bool = False  # For Qwen3

    def __post_init__(self):
        """Fill in the fields left at their sentinel defaults."""
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            # Round UP to the next multiple of 256 (exact multiples are kept).
            self.intermediate_size = -(-n_hidden // 256) * 256
        if self.head_dim == -1:
            self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str, *, configs: Optional[dict] = None):
        """Build a ``ModelArgs`` from a named preset.

        An exact key match wins. Otherwise every preset whose key occurs
        (case-insensitively) inside ``name`` is a candidate. The longest
        candidate is chosen because it matched the most characters. For
        example, ``"Mistral-7B"`` beats a shorter ``"7B"``, so ``name`` can be
        a checkpoint path.

        Args:
            name: Preset key, or any string (or object with a useful ``str``)
                that contains one.
            configs: Mapping of preset name -> ``ModelArgs`` keyword
                arguments to search. Defaults to ``transformer_configs``.

        Returns:
            A new ``ModelArgs`` instance.

        Raises:
            ValueError: If no preset matches, or if the two longest matches
                have equal length (ambiguous).
        """
        registry = transformer_configs if configs is None else configs
        if name in registry:
            return cls(**registry[name])
        # Fuzzy search: hoist the normalized name out of the scan.
        needle = str(name).lower()
        # sorted() is stable, so equal-length keys keep registry order.
        config = sorted(
            (key for key in registry if key.lower() in needle),
            key=len,
            reverse=True,
        )
        if not config:
            raise ValueError(
                f"No config matched to {name}; known configs: {list(registry)}"
            )
        # Explicit check (not assert) so it survives `python -O`.
        if len(config) > 1 and len(config[0]) == len(config[1]):
            raise ValueError(
                f"Ambiguous config match for {name}: {config[0]!r} vs {config[1]!r}"
            )
        print(f"Found {len(config)} configs matched to {name}: {config}")
        return cls(**registry[config[0]])

# Preset architectures understood by ModelArgs.from_name. Each value holds
# keyword arguments for the ModelArgs constructor; any field not listed keeps
# its ModelArgs default (or the value derived in ModelArgs.__post_init__).
transformer_configs = {
    # Llama 3.1 presets: also set the Llama 3.1 RoPE scaling fields.
    "llama-3.1-8b": {
        "block_size": 131_072, "n_layer": 32, "n_head": 32, "n_local_heads": 8,
        "dim": 4096, "intermediate_size": 14_336, "vocab_size": 128_256,
        "rope_base": 500_000.0, "scaling_factor": 8,
        "high_freq_factor": 4, "low_freq_factor": 1,
        "original_max_position_embeddings": 8192,
    },
    "llama-3.1-70b": {
        "block_size": 131_072, "n_layer": 80, "n_head": 64, "n_local_heads": 8,
        "dim": 8192, "intermediate_size": 28_672, "vocab_size": 128_256,
        "rope_base": 500_000.0, "scaling_factor": 8,
        "high_freq_factor": 4, "low_freq_factor": 1,
        "original_max_position_embeddings": 8192,
    },
    # Qwen2.5 presets: biased QKV projections and norm_eps 1e-6.
    "Qwen2.5-7b": {
        "block_size": 131_072, "n_layer": 28, "n_head": 28, "n_local_heads": 4,
        "dim": 3584, "intermediate_size": 18_944, "vocab_size": 152_064,
        "rope_base": 1_000_000.0, "qkv_bias": True, "norm_eps": 1e-6,
    },
    "Qwen2.5-14b": {
        "block_size": 131_072, "n_layer": 48, "n_head": 40, "n_local_heads": 8,
        "dim": 5120, "intermediate_size": 13_824, "vocab_size": 152_064,
        "rope_base": 1_000_000.0, "qkv_bias": True, "norm_eps": 1e-6,
    },
    "Qwen2.5-32b": {
        "block_size": 131_072, "n_layer": 64, "n_head": 40, "n_local_heads": 8,
        "dim": 5120, "intermediate_size": 27_648, "vocab_size": 152_064,
        "rope_base": 1_000_000.0, "qkv_bias": True, "norm_eps": 1e-6,
    },
    # Mistral: block_size is not set here, so the ModelArgs default applies.
    "Mistral-7B-v0.3": {
        "n_layer": 32, "n_head": 32, "n_local_heads": 8,
        "dim": 4096, "intermediate_size": 14_336, "vocab_size": 32_768,
        "rope_base": 1_000_000.0,
    },
    # DeepSeek-R1 distills; the Qwen-based ones use rope_base 10000 and biased QKV.
    "DeepSeek-R1-Distill-Qwen-1.5B": {
        "block_size": 131_072, "n_layer": 28, "n_head": 12, "n_local_heads": 2,
        "dim": 1536, "intermediate_size": 8960, "vocab_size": 151_936,
        "rope_base": 10_000.0, "qkv_bias": True, "norm_eps": 1e-6,
    },
    "DeepSeek-R1-Distill-Qwen-7B": {
        "block_size": 131_072, "n_layer": 28, "n_head": 28, "n_local_heads": 4,
        "dim": 3584, "intermediate_size": 18_944, "vocab_size": 152_064,
        "rope_base": 10_000.0, "qkv_bias": True, "norm_eps": 1e-6,
    },
    "DeepSeek-R1-Distill-Llama-8B": {
        "block_size": 131_072, "n_layer": 32, "n_head": 32, "n_local_heads": 8,
        "dim": 4096, "intermediate_size": 14_336, "vocab_size": 128_256,
        "rope_base": 500_000.0, "scaling_factor": 8,
        "high_freq_factor": 4, "low_freq_factor": 1,
        "original_max_position_embeddings": 8192, "norm_eps": 1e-5,
    },
    # Qwen3 presets: QK normalization enabled, block_size 40960.
    "Qwen3-0.6B": {
        "block_size": 40_960, "n_layer": 28, "n_head": 16, "n_local_heads": 8,
        # Explicit head_dim: the derived dim // n_head would be 64, not 128.
        "head_dim": 128,
        "dim": 1024, "intermediate_size": 3072, "vocab_size": 151_936,
        "rope_base": 1_000_000.0, "norm_eps": 1e-6, "qk_norm": True,
    },
    "Qwen3-1.7B": {
        "block_size": 40_960, "n_layer": 28, "n_head": 16, "n_local_heads": 8,
        "dim": 2048, "intermediate_size": 6144, "vocab_size": 151_936,
        "rope_base": 1_000_000.0, "norm_eps": 1e-6, "qk_norm": True,
    },
    "Qwen3-8B": {
        "block_size": 40_960, "n_layer": 36, "n_head": 32, "n_local_heads": 8,
        "dim": 4096, "intermediate_size": 12_288, "vocab_size": 151_936,
        "rope_base": 1_000_000.0, "norm_eps": 1e-6, "qk_norm": True,
    },
}