import math
from dataclasses import dataclass
from typing import Optional

# Per-model hyperparameter overrides, keyed by model name. Each value holds
# keyword arguments for ModelArgs; any field left out falls back to the
# ModelArgs default.
transformer_configs = {
    "llama-2-7b": {"block_size": 4096, "n_layer": 32, "n_head": 32, "dim": 4096},
    "llama-2-7b-32k": {"block_size": 32768, "n_layer": 32, "dim": 4096, "vocab_size": 32000, "scaling_factor": 8},
    "llama-2-13b": {"block_size": 4096, "n_layer": 40, "n_head": 40, "dim": 5120},
    "llama-2-70b": {"block_size": 4096, "n_layer": 80, "n_head": 64, "dim": 8192, "n_local_heads": 8, "intermediate_size": 28672},
    "llama-3-8b": {"block_size": 8192, "n_layer": 32, "n_head": 32, "n_local_heads": 8, "dim": 4096, "intermediate_size": 14336, "vocab_size": 128256, "rope_base": 500000},
    "llama-3-70b": {"block_size": 8192, "n_layer": 80, "n_head": 64, "n_local_heads": 8, "dim": 8192, "intermediate_size": 28672, "vocab_size": 128256, "rope_base": 500000},
    "68m": {"block_size": 2048, "n_layer": 2, "n_head": 12, "n_local_heads": 12, "dim": 768, "intermediate_size": 3072, "vocab_size": 32000},
    "tinyllama": {"block_size": 2048, "n_layer": 22, "n_head": 32, "n_local_heads": 4, "dim": 2048, "intermediate_size": 5632, "vocab_size": 32000},
    "llama-3.1-8b": {"block_size": 131072, "n_layer": 32, "n_head": 32, "n_local_heads": 8, "dim": 4096, "intermediate_size": 14336, "vocab_size": 128256, "rope_base": 500000.0, "scaling_factor": 8, "high_freq_factor": 4, "low_freq_factor": 1, "original_max_position_embeddings": 8192},
    "llama-3.1-70b": {"block_size": 131072, "n_layer": 80, "n_head": 64, "n_local_heads": 8, "dim": 8192, "intermediate_size": 28672, "vocab_size": 128256, "rope_base": 500000.0, "scaling_factor": 8, "high_freq_factor": 4, "low_freq_factor": 1, "original_max_position_embeddings": 8192},
    "llama-3.2-1b": {"block_size": 131072, "n_layer": 16, "n_head": 32, "n_local_heads": 8, "dim": 2048, "intermediate_size": 8192, "vocab_size": 128256, "rope_base": 500000.0, "scaling_factor": 32, "high_freq_factor": 4, "low_freq_factor": 1, "original_max_position_embeddings": 8192},
    "Qwen2.5-7b": {"block_size": 131072, "n_layer": 28, "n_head": 28, "n_local_heads": 4, "dim": 3584, "intermediate_size": 18944, "vocab_size": 152064, "rope_base": 1000000.0, "qkv_bias": True, "norm_eps": 1e-6},
    "Qwen2.5-14b": {"block_size": 131072, "n_layer": 48, "n_head": 40, "n_local_heads": 8, "dim": 5120, "intermediate_size": 13824, "vocab_size": 152064, "rope_base": 1000000.0, "qkv_bias": True, "norm_eps": 1e-6},
    "Qwen2.5-32b": {"block_size": 131072, "n_layer": 64, "n_head": 40, "n_local_heads": 8, "dim": 5120, "intermediate_size": 27648, "vocab_size": 152064, "rope_base": 1000000.0, "qkv_bias": True, "norm_eps": 1e-6},
    "Yi-1.5-6b": {"block_size": 4096, "n_layer": 32, "n_head": 32, "n_local_heads": 4, "dim": 4096, "intermediate_size": 11008, "vocab_size": 64000, "rope_base": 500000.0},
    "Yi-1.5-34b-32k": {"block_size": 32768, "n_layer": 60, "n_head": 56, "n_local_heads": 8, "dim": 7168, "intermediate_size": 20480, "vocab_size": 64000, "rope_base": 500000.0},
    "Mistral-7B-v0.1": {"n_layer": 32, "n_head": 32, "n_local_heads": 8, "dim": 4096, "intermediate_size": 14336, "vocab_size": 32000},
    "Mistral-7B-v0.3": {"n_layer": 32, "n_head": 32, "n_local_heads": 8, "dim": 4096, "intermediate_size": 14336, "vocab_size": 32768, "rope_base": 1000000.0},

    # New models
    "Llama-3-8B-Instruct-262k": {"block_size": 262144, "n_layer": 32, "n_head": 32, "n_local_heads": 8, "dim": 4096, "intermediate_size": 14336, "vocab_size": 128256, "rope_base": 283461213.0, "norm_eps": 1e-5},
    "DeepSeek-R1-Distill-Qwen-1.5B": {"block_size": 131072, "n_layer": 28, "n_head": 12, "n_local_heads": 2, "dim": 1536, "intermediate_size": 8960, "vocab_size": 151936, "rope_base": 10000.0, "qkv_bias": True, "norm_eps": 1e-6},
    "DeepSeek-R1-Distill-Qwen-7B": {"block_size": 131072, "n_layer": 28, "n_head": 28, "n_local_heads": 4, "dim": 3584, "intermediate_size": 18944, "vocab_size": 152064, "rope_base": 10000.0, "qkv_bias": True, "norm_eps": 1e-6},
    "DeepSeek-R1-Distill-Llama-8B": {"block_size": 131072, "n_layer": 32, "n_head": 32, "n_local_heads": 8, "dim": 4096, "intermediate_size": 14336, "vocab_size": 128256, "rope_base": 500000.0, "scaling_factor": 8, "high_freq_factor": 4, "low_freq_factor": 1, "original_max_position_embeddings": 8192, "norm_eps": 1e-5},
    "Qwen3-0.6B": {"block_size": 40960, "n_layer": 28, "n_head": 16, "n_local_heads": 8, "head_dim": 128, "dim": 1024, "intermediate_size": 3072, "vocab_size": 151936, "rope_base": 1000000.0, "norm_eps": 1e-6, "qk_norm": True},
    "Qwen3-1.7B": {"block_size": 40960, "n_layer": 28, "n_head": 16, "n_local_heads": 8, "dim": 2048, "intermediate_size": 6144, "vocab_size": 151936, "rope_base": 1000000.0, "norm_eps": 1e-6, "qk_norm": True},
    "Qwen3-8B": {"block_size": 40960, "n_layer": 36, "n_head": 32, "n_local_heads": 8, "dim": 4096, "intermediate_size": 12288, "vocab_size": 151936, "rope_base": 1000000.0, "norm_eps": 1e-6, "qk_norm": True},

    # Drafter models
    "longspec-Llama-3-8B-Instruct-262k": {"block_size": 262144, "n_layer": 32, "n_head": 32, "n_local_heads": 8, "dim": 4096, "intermediate_size": 14336, "vocab_size": 128256, "rope_base": 283461213.0, "norm_eps": 1e-5},
    "EAGLE3-DeepSeek-R1-Distill-LLaMA-8B": {"block_size": 2048, "n_layer": 1, "n_head": 32, "n_local_heads": 8, "dim": 4096, "intermediate_size": 14336, "vocab_size": 128256, "draft_vocab_size": 32000, "norm_eps": 1e-5},
}

@dataclass
class ModelArgs:
    """Hyperparameters of a transformer model.

    Three fields are derived in ``__post_init__`` when left at their sentinel
    default: ``n_local_heads`` (-1), ``intermediate_size`` (None) and
    ``head_dim`` (-1). All other fields are stored exactly as given; the code
    that consumes them lives outside this class.
    """

    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    # None -> derived from ``dim`` in __post_init__.
    intermediate_size: Optional[int] = None
    # -1 -> same as ``n_head``.
    n_local_heads: int = -1
    # -1 -> ``dim // n_head``.
    head_dim: int = -1
    rope_base: float = 10000
    norm_eps: float = 1e-5
    scaling_factor: float = 1.0
    low_freq_factor: Optional[int] = None
    high_freq_factor: Optional[int] = None
    original_max_position_embeddings: Optional[int] = None
    qkv_bias: bool = False
    qk_norm: bool = False
    draft_vocab_size: int = 32000

    def __post_init__(self):
        """Fill in the fields that were left at their sentinel default."""
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            # Round *up* to the next multiple of 256 (unchanged if already one).
            remainder = n_hidden % 256
            if remainder:
                n_hidden += 256 - remainder
            self.intermediate_size = n_hidden
        if self.head_dim == -1:
            self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str) -> "ModelArgs":
        """Build a ModelArgs from an entry of ``transformer_configs``.

        An exact key match is used when available. Otherwise every key that
        occurs (case-insensitively) as a substring of ``name`` is a candidate
        and the longest candidate wins.

        Args:
            name: A config key, or a longer string (e.g. a checkpoint path)
                that contains one.

        Returns:
            A ModelArgs built from the selected config.

        Raises:
            ValueError: If no key matches ``name``, or if the two longest
                candidates have the same length so no single best match
                exists.
        """
        if name in transformer_configs:
            return cls(**transformer_configs[name])

        # Fuzzy search: keep every key contained in the requested name.
        lowered = str(name).lower()
        matches = [key for key in transformer_configs if key.lower() in lowered]
        if not matches:
            raise ValueError(
                f"Unknown model name {name!r}; "
                f"known configs: {sorted(transformer_configs)}"
            )

        # Several keys may match (e.g. "7B" and "Mistral-7B"); the longer key
        # covers more of the name, so prefer it.
        matches.sort(key=len, reverse=True)
        if len(matches) > 1 and len(matches[0]) == len(matches[1]):
            # An explicit raise (not an assert) so the check survives `python -O`.
            raise ValueError(
                f"Ambiguous model name {name!r}: "
                f"{matches[0]!r} and {matches[1]!r} match equally well"
            )
        # Report the candidates (best first) on stdout, as before.
        print(matches)
        return cls(**transformer_configs[matches[0]])

@dataclass
class LoRAConfig:
    """LoRA settings plus the scaling factor derived from them.

    ``lora_scaling`` is not an init field; it is set in ``__post_init__`` to
    ``alpha / sqrt(rank)`` when ``use_rslora`` is true and to ``alpha / rank``
    otherwise.
    """

    rank: int = 16
    alpha: float = 32
    # Stored as given; not read anywhere in this class.
    lora_bias: bool = False
    # Selects the sqrt(rank) divisor for ``lora_scaling``.
    use_rslora: bool = False

    def __post_init__(self):
        """Compute ``lora_scaling`` from ``alpha``, ``rank`` and ``use_rslora``."""
        divisor = math.sqrt(self.rank) if self.use_rslora else self.rank
        self.lora_scaling = self.alpha / divisor

    @classmethod
    def from_dict(cls, config: dict):
        """Create a LoRAConfig by passing ``config``'s items as keyword arguments."""
        return cls(**config)