"""Shared string constants and choice collections for the mixed-precision
quantization CLI.

All small literal strings that may repeat across modules live here to avoid
magic strings scattered throughout the codebase.
"""
from __future__ import annotations

# Model key / id identifiers --------------------------------------------------
# NOTE(review): each value is identical to its own constant name; presumably
# these are lookup keys (dict or environment) — confirm against callers.
MODEL_KEY: str = "MODEL_KEY"
MODEL_ID: str = "MODEL_ID"

# Device selection -----------------------------------------------------------
# DEVICE_AUTO and DEFAULT_DEVICE are two names bound to the same "auto"
# string; both are kept for clarity / compatibility.
DEVICE_AUTO: str = "auto"
DEFAULT_DEVICE: str = DEVICE_AUTO
DEVICE_CPU: str = "cpu"
DEVICE_CUDA: str = "cuda"
DEVICE_MPS: str = "mps"

# Quantization format identifiers -------------------------------------------
# Integer formats.
QUANT_INT8: str = "int8"
QUANT_INT4: str = "int4"
# MX floating-point formats.
QUANT_MXFP4: str = "fp4_e2m1"
QUANT_MXFP8E4M3: str = "fp8_e4m3"
QUANT_MXFP8E5M2: str = "fp8_e5m2"
# Scale format identifiers; these are not members of QUANT_FORMAT_CHOICES.
QUANT_SCALE_E8M0: str = "e8m0"
QUANT_SCALE_E4M3: str = "e4m3"

# Ordered collection of the element formats defined above.
QUANT_FORMAT_CHOICES: tuple[str, ...] = (
    QUANT_INT8, QUANT_INT4,
    QUANT_MXFP4, QUANT_MXFP8E4M3, QUANT_MXFP8E5M2,
)

# Mixed-precision algorithm identifiers -------------------------------------
MP_PER_CHANNEL: str = "mp_per_channel"
MP_GLOBAL: str = "mp"
MP_NONE: str = "none"

# All mixed-precision identifiers, in declaration order.
MP_ALG_CHOICES: tuple[str, ...] = (MP_PER_CHANNEL, MP_GLOBAL, MP_NONE)

# PTQ algorithm identifiers --------------------------------------------------
PTQ_GPTQ: str = "gptq"
PTQ_RTN: str = "rtn"
PTQ_BASIC: str = "ptq"

# NOTE(review): PTQ_BASIC is defined above but is not listed here —
# confirm whether the omission is intentional.
PTQ_ALGO_CHOICES: tuple[str, ...] = (PTQ_GPTQ, PTQ_RTN)

# Dtype canonical strings + aliases -----------------------------------------
# Each canonical name is declared next to its short alias (the short forms
# are common in CLI environments).
DTYPE_FP32: str = "float32"
DTYPE_ALIAS_FP32: str = "fp32"

DTYPE_FP16: str = "float16"
DTYPE_ALIAS_FP16: str = "fp16"

DTYPE_BF16: str = "bfloat16"
DTYPE_ALIAS_BF16: str = "bf16"

# Canonical names only.
DTYPE_CANONICAL_CHOICES: tuple[str, ...] = (DTYPE_FP32, DTYPE_FP16, DTYPE_BF16)

# Canonical names and aliases; each canonical name is followed by its alias.
DTYPE_ALL_CHOICES: tuple[str, ...] = (
    DTYPE_FP32, DTYPE_ALIAS_FP32,
    DTYPE_FP16, DTYPE_ALIAS_FP16,
    DTYPE_BF16, DTYPE_ALIAS_BF16,
)

# Model family substrings (used for architecture-specific handling) ----------
MODEL_FAMILY_QWEN: str = "qwen"
MODEL_FAMILY_LLAMA: str = "llama"

# Root/decoder attribute names for HF style models ---------------------------
MODEL_DECODER_ATTR: str = "decoder"
MODEL_ROOT_ATTR: str = "model"

# Layer attribute names (centralized strings) --------------------------------
LAYER_NORM: str = "norm"
LAYER_EMBED_TOKENS: str = "embed_tokens"
LAYER_EMBED_POSITIONS: str = "embed_positions"
LAYER_PROJECT_IN: str = "project_in"
LAYER_PROJECT_OUT: str = "project_out"

# Layer attribute names grouped for family-specific handling.
LLAMA_LAYER_ATTRS: tuple[str, ...] = (LAYER_EMBED_TOKENS, LAYER_NORM)

# Layer/Tensor attributes ---------------------------------------------------
ATTR_INPUT: str = "input"
ATTR_OUTPUT: str = "output"
ATTR_WEIGHT: str = "weight"
ATTR_SCALE: str = "scale"
ATTR_ZERO_POINT: str = "zero_point"

# Layer types ---------------------------------------------------------------
LAYER_TYPE_CONV2D: str = "Conv2d"
LAYER_TYPE_LINEAR: str = "Linear"

# Quantization schemes ------------------------------------------------------
QUANT_SCHEME_ASYMMETRIC: str = "asymmetric"
QUANT_SCHEME_SYMMETRIC: str = "symmetric"

# Granularity ---------------------------------------------------------------
GRANULARITY_PER_TENSOR: str = "per_tensor"
GRANULARITY_PER_CHANNEL: str = "per_channel"

# Evaluation prefixes -------------------------------------------------------
EVAL_PREFIX_QUANTIZED: str = "Quantized"
EVAL_PREFIX_FLOAT: str = "Float"

# Special tokens/strings ----------------------------------------------------
CALIBRATION_PHASE: str = "calibration"
THINKING_TOKEN: str = "thinking"

# LM Eval configuration keys ------------------------------------------------
LMEVAL_TASKS: str = "tasks"
LMEVAL_NUM_FEWSHOT: str = "num_fewshot"
LMEVAL_FEWSHOT_AS_MULTITURN: str = "fewshot_as_multiturn"
LMEVAL_APPLY_CHAT_TEMPLATE: str = "apply_chat_template"
LMEVAL_ADD_BOS_TOKEN: str = "add_bos_token"
LMEVAL_MAX_LENGTH: str = "max_length"
LMEVAL_GEN_KWARGS: str = "gen_kwargs"
LMEVAL_MAX_GEN_TOKS: str = "max_gen_toks"

# Tokenizer keys ------------------------------------------------------------
TOKENIZER_INPUT_IDS: str = "input_ids"
TOKENIZER_ATTENTION_MASK: str = "attention_mask"
TOKENIZER_ADD_SPECIAL_TOKENS: str = "add_special_tokens"
TOKENIZER_RETURN_TENSORS: str = "return_tensors"
# Value passed for return_tensors to select the PyTorch tensor type.
TOKENIZER_PT: str = "pt"

# Model configuration keys --------------------------------------------------
CONFIG_TORCH_DTYPE: str = "torch_dtype"
CONFIG_HIDDEN_SIZE: str = "hidden_size"
CONFIG_HEAD_DIM: str = "head_dim"
CONFIG_NUM_ATTENTION_HEADS: str = "num_attention_heads"
CONFIG_NUM_KEY_VALUE_HEADS: str = "num_key_value_heads"
CONFIG_ATTENTION_BIAS: str = "attention_bias"

# Block module map keys (for quantization module tracking) ------------------
# Quantized sub-module keys.
MODULE_QUANTIZED_MLP: str = "quantized_mlp"
MODULE_QUANTIZED_ATTN: str = "quantized_attn"
# Transform keys.
MODULE_QKV_IN_TRANSFORM: str = "qkv_in_transform"
MODULE_V_OUT_TRANSFORM: str = "v_out_transform"
MODULE_O_IN_TRANSFORM: str = "o_in_transform"
MODULE_GATE_UP_IN_TRANSFORM: str = "gate_up_in_transform"
MODULE_DOWN_IN_TRANSFORM: str = "down_in_transform"

# Dataset identifiers -------------------------------------------------------
WIKITEXT2: str = "wikitext2"
WIKITEXT: str = "wikitext"
FINEWEB_EDU: str = "fineweb-edu"
FINEWEB: str = "fineweb"


# Attribute names ------------------------------------------------------------
# Annotated as `str` to match the other constant sections in this module.
# ATTR_MODEL and ATTR_DECODER carry the same string values as
# MODEL_ROOT_ATTR and MODEL_DECODER_ATTR respectively.
ATTR_MODEL: str = "model"
ATTR_DECODER: str = "decoder"
ATTR_LAYERS: str = "layers"
ATTR_BLOCKS: str = "blocks"
ATTR_FEATURES: str = "features"
ATTR_TRANSFORMER: str = "transformer"
ATTR_H: str = "h"
ATTR_ENCODER: str = "encoder"
ATTR_LAYER: str = "layer"
ATTR_BERT: str = "bert"
ATTR_CONFIG: str = "config"
ATTR_ARCHITECTURES: str = "architectures"

# Public API of this module. Every public constant defined in the file is
# listed so that `from <module> import *` exposes all of them; MODEL_KEY,
# MODEL_ID, QUANT_SCALE_E8M0, QUANT_SCALE_E4M3 and TOKENIZER_PT were
# previously missing and therefore dropped by star-imports.
__all__ = [
    # model key / id
    "MODEL_KEY",
    "MODEL_ID",
    # devices
    "DEFAULT_DEVICE",
    "DEVICE_AUTO",
    "DEVICE_CUDA",
    "DEVICE_MPS",
    "DEVICE_CPU",
    # quant formats
    "QUANT_INT8",
    "QUANT_INT4",
    "QUANT_MXFP4",
    "QUANT_MXFP8E4M3",
    "QUANT_MXFP8E5M2",
    "QUANT_SCALE_E8M0",
    "QUANT_SCALE_E4M3",
    "QUANT_FORMAT_CHOICES",
    # mixed precision algos
    "MP_PER_CHANNEL",
    "MP_GLOBAL",
    "MP_NONE",
    "MP_ALG_CHOICES",
    # ptq algos
    "PTQ_GPTQ",
    "PTQ_RTN",
    "PTQ_BASIC",
    "PTQ_ALGO_CHOICES",
    # dtypes
    "DTYPE_FP32",
    "DTYPE_FP16",
    "DTYPE_BF16",
    "DTYPE_ALIAS_FP32",
    "DTYPE_ALIAS_FP16",
    "DTYPE_ALIAS_BF16",
    "DTYPE_CANONICAL_CHOICES",
    "DTYPE_ALL_CHOICES",
    # model families
    "MODEL_FAMILY_LLAMA",
    "MODEL_FAMILY_QWEN",
    # root/decoder attrs
    "MODEL_ROOT_ATTR",
    "MODEL_DECODER_ATTR",
    # layer attrs
    "LAYER_EMBED_TOKENS",
    "LAYER_EMBED_POSITIONS",
    "LAYER_PROJECT_OUT",
    "LAYER_PROJECT_IN",
    "LAYER_NORM",
    "LLAMA_LAYER_ATTRS",
    # datasets
    "WIKITEXT",
    "WIKITEXT2",
    "FINEWEB",
    "FINEWEB_EDU",
    # attribute names
    "ATTR_MODEL",
    "ATTR_DECODER",
    "ATTR_LAYERS",
    "ATTR_BLOCKS",
    "ATTR_FEATURES",
    "ATTR_TRANSFORMER",
    "ATTR_H",
    "ATTR_ENCODER",
    "ATTR_LAYER",
    "ATTR_BERT",
    "ATTR_CONFIG",
    "ATTR_ARCHITECTURES",
    # layer/tensor attributes
    "ATTR_WEIGHT",
    "ATTR_INPUT",
    "ATTR_OUTPUT",
    "ATTR_SCALE",
    "ATTR_ZERO_POINT",
    # layer types
    "LAYER_TYPE_LINEAR",
    "LAYER_TYPE_CONV2D",
    # quantization schemes
    "QUANT_SCHEME_SYMMETRIC",
    "QUANT_SCHEME_ASYMMETRIC",
    # granularity
    "GRANULARITY_PER_CHANNEL",
    "GRANULARITY_PER_TENSOR",
    # evaluation prefixes
    "EVAL_PREFIX_FLOAT",
    "EVAL_PREFIX_QUANTIZED",
    # special tokens/strings
    "THINKING_TOKEN",
    "CALIBRATION_PHASE",
    # LM eval configuration keys
    "LMEVAL_ADD_BOS_TOKEN",
    "LMEVAL_MAX_LENGTH",
    "LMEVAL_TASKS",
    "LMEVAL_APPLY_CHAT_TEMPLATE",
    "LMEVAL_NUM_FEWSHOT",
    "LMEVAL_FEWSHOT_AS_MULTITURN",
    "LMEVAL_GEN_KWARGS",
    "LMEVAL_MAX_GEN_TOKS",
    # tokenizer keys
    "TOKENIZER_INPUT_IDS",
    "TOKENIZER_ATTENTION_MASK",
    "TOKENIZER_RETURN_TENSORS",
    "TOKENIZER_ADD_SPECIAL_TOKENS",
    "TOKENIZER_PT",
    # model configuration keys
    "CONFIG_HIDDEN_SIZE",
    "CONFIG_NUM_ATTENTION_HEADS",
    "CONFIG_NUM_KEY_VALUE_HEADS",
    "CONFIG_ATTENTION_BIAS",
    "CONFIG_HEAD_DIM",
    "CONFIG_TORCH_DTYPE",
    # module names
    "MODULE_QUANTIZED_ATTN",
    "MODULE_QUANTIZED_MLP",
    "MODULE_QKV_IN_TRANSFORM",
    "MODULE_O_IN_TRANSFORM",
    "MODULE_GATE_UP_IN_TRANSFORM",
    "MODULE_DOWN_IN_TRANSFORM",
    "MODULE_V_OUT_TRANSFORM",
]
