"""Constants for Flower Language Model (LLM) configurations."""

from enum import Enum
from types import MappingProxyType


def _tokenizer_template(name: str, max_seq_len: int = 2048) -> MappingProxyType:
    """Build one read-only tokenizer template.

    Parameters
    ----------
    name : str
        Stored verbatim under the ``"name"`` key.
    max_seq_len : int
        Stored under ``"kwargs" -> "max_seq_len"``.

    Returns
    -------
    MappingProxyType
        Read-only view with the keys ``"name"`` and ``"kwargs"``, in that order.
        Only the outer mapping is read-only: ``"kwargs"`` is a plain ``dict``,
        created fresh on every call so templates never share it.
    """
    template = {
        "name": name,
        "kwargs": {"max_seq_len": max_seq_len},
    }
    return MappingProxyType(template)


class Tokenizers(Enum):
    """Tokenizer templates.

    Every member's value is a read-only mapping holding a ``"name"`` string and
    a ``"kwargs"`` dict with a single ``"max_seq_len"`` entry.
    """

    SMOLLM_SHARED = _tokenizer_template("HuggingFaceTB/SmolLM-1.7B")

    # NOTE(review): the two names below start with a literal "~"; nothing in this
    # module expands it, so presumably the consumer does -- confirm.
    COSMO = _tokenizer_template(
        "~/anonymous/projects/repo/trained_tokenizers/tokenizer_cosmo_",
    )

    INFINIMATH = _tokenizer_template(
        "~/anonymous/projects/repo/trained_tokenizers/tokenizer_infiwebmath-4plus_",
    )


def _mpt_template(*, d_model: int, n_heads: int, n_layers: int) -> MappingProxyType:
    """Build a read-only template for the ``MPT_CAUSAL_LM_*`` family.

    Parameters
    ----------
    d_model : int
        Stored under ``"d_model"``.
    n_heads : int
        Stored under ``"n_heads"``.
    n_layers : int
        Stored under ``"n_layers"``.

    Returns
    -------
    MappingProxyType
        Read-only view of the template. Every other entry is the same for the
        whole family. Only the outer mapping is read-only: ``"attn_config"`` is
        a plain ``dict``, created fresh on every call.
    """
    template = {
        "name": "mpt_causal_lm",
        "init_device": "cpu",
        "d_model": d_model,
        "n_heads": n_heads,
        "n_layers": n_layers,
        "expansion_ratio": 4,
        "max_seq_len": 2048,
        "vocab_size": 50368,
        "attn_config": {"attn_impl": "torch"},
        "output_hidden_states": True,
    }
    return MappingProxyType(template)


def _smollm_template(
    *,
    d_model: int,
    n_heads: int,
    n_layers: int,
    ffn_hidden_size: int,
    init_std: float,
) -> MappingProxyType:
    """Build a read-only template for the ``SMOLLM_*`` family.

    Parameters
    ----------
    d_model : int
        Stored under ``"d_model"``.
    n_heads : int
        Stored under ``"n_heads"``.
    n_layers : int
        Stored under ``"n_layers"``.
    ffn_hidden_size : int
        Stored under ``"ffn_hidden_size"``.
    init_std : float
        Stored under ``"init_config" -> "init_std"``.

    Returns
    -------
    MappingProxyType
        Read-only view of the template. Every other entry is the same for the
        whole family. Only the outer mapping is read-only: the nested
        ``"attn_config"``, ``"ffn_config"`` and ``"init_config"`` values are
        plain ``dict`` objects, created fresh on every call.
    """
    attention = {
        "attn_impl": "flash",
        "rope": True,
        "alibi": False,
        "rope_theta": 10000,
    }
    template = {
        "name": "mpt_causal_lm",
        "d_model": d_model,
        "n_heads": n_heads,
        "n_layers": n_layers,
        "ffn_hidden_size": ffn_hidden_size,
        "max_seq_len": 2048,
        "vocab_size": 50368,
        "attn_config": attention,
        "ffn_config": {"ffn_act_fn": {"name": "silu"}},
        "init_config": {"init_std": init_std},
    }
    return MappingProxyType(template)


class ModelConfig(Enum):
    """Model configuration templates.

    Every member's value is a read-only mapping. Members come in two families
    that differ only in the arguments listed below; all of them use the name
    ``"mpt_causal_lm"``, a ``max_seq_len`` of 2048 and a ``vocab_size`` of 50368.
    """

    # MPT-style templates: "torch" attention, CPU init device, expansion ratio 4.
    MPT_CAUSAL_LM_125M = _mpt_template(d_model=768, n_heads=12, n_layers=12)
    MPT_CAUSAL_LM_3B = _mpt_template(d_model=2560, n_heads=20, n_layers=32)
    MPT_CAUSAL_LM_350M = _mpt_template(d_model=1024, n_heads=16, n_layers=24)

    # SMOLLM-style templates: "flash" attention with rope, "silu" FFN activation.
    SMOLLM_135M = _smollm_template(
        d_model=576,
        n_heads=9,
        n_layers=30,
        ffn_hidden_size=1536,
        init_std=0.041666666666666664,
    )
    SMOLLM_1B = _smollm_template(
        d_model=2048,
        n_heads=32,
        n_layers=24,
        ffn_hidden_size=8192,
        init_std=0.02,
    )
    SMOLLM_13B = _smollm_template(
        d_model=5120,
        n_heads=40,
        n_layers=40,
        ffn_hidden_size=20480,
        init_std=0.014,
    )
