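"""Enum definitions for CLI argument normalization.

Where a matching entry exists, values are sourced from `constants` to avoid
magic strings; the remaining identifiers (observer types, loss functions,
distance metrics, and `Granularity.GROUP`) are string literals defined here.
"""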
from __future__ import annotations

from enum import Enum
import constants


class QuantFormat(Enum):
    """Supported weight / activation quantization formats.

    Members map to string identifiers used in the CLI and downstream logic.
    Each value is taken from the corresponding ``constants.QUANT_*`` entry.
    If two of those constants ever share a value, ``Enum`` silently makes the
    later name an alias of the earlier member rather than raising.
    """
    INT8 = constants.QUANT_INT8
    INT4 = constants.QUANT_INT4
    MXFP4 = constants.QUANT_MXFP4
    # NOTE(review): the MXFP8 variants below are intentionally disabled,
    # presumably pending downstream support — confirm before re-enabling, and
    # make sure the referenced constants exist (otherwise import fails).
    # MXFP8E4M3 = constants.QUANT_MXFP8E4M3
    # MXFP8E5M2 = constants.QUANT_MXFP8E5M2

class ObserverType(Enum):
    """Supported observer types for quantization calibration.

    Values are literal string identifiers defined here (not taken from
    ``constants``); look a member up by value with e.g. ``ObserverType("mse")``.
    """
    MINMAX = "minmax"
    PERCENTILE = "percentile"
    KL_DIVERGENCE = "kl_divergence"
    MSE = "mse"

class Granularity(Enum):
    """Granularity identifiers: per-tensor, per-channel, or group-wise."""
    PER_TENSOR = constants.GRANULARITY_PER_TENSOR
    PER_CHANNEL = constants.GRANULARITY_PER_CHANNEL
    # NOTE(review): unlike its siblings, GROUP uses a string literal rather than
    # a `constants` entry — consider adding one there for consistency.
    GROUP = "group"


class PTQAlg(Enum):
    """Post-training quantization algorithm identifiers.

    Values are taken from the corresponding ``constants.PTQ_*`` entries.
    """
    GPTQ = constants.PTQ_GPTQ
    RTN = constants.PTQ_RTN
    # Member name is PTQ, but its value is the "basic" PTQ constant.
    PTQ = constants.PTQ_BASIC


class LossFunction(Enum):
    """Loss function identifiers for transform optimization.

    Values are literal string identifiers defined here (not taken from
    ``constants``).
    """
    OUTPUT_DISTILLATION = "output_distill"
    UNEMBED_DISTILLATION = "unembed_distill"
    FLAT_Q_DISTILLATION = "flat_q_distill"


class DistanceMetric(Enum):
    """Distance metric identifiers for transform optimization.

    Values are literal string identifiers defined here (not taken from
    ``constants``).
    """
    KL = "kl"
    MSE = "mse"
    FROBENIUS = "frobenius"
    # Future options can be added here (e.g., L1, COSINE, etc.)


# Public API of this module: every enum defined above. Keep this list in sync
# with the class definitions — names missing here are silently skipped by
# `from <module> import *`.
__all__ = [
    "QuantFormat",
    "ObserverType",
    "Granularity",
    "PTQAlg",
    "LossFunction",
    "DistanceMetric",
]
