"""CLI argument parsing for the MX LLM quantization project.

This module provides command-line argument parsing for configuring LLM quantization
experiments with MX formats. The main entry point is `parse_args()`, which returns
typed configuration dataclasses for runtime, quantization, and optimization settings.

The arguments are organized into the following sections:
- Model and Dataset Configuration
- Data Type and Format Configuration
- Quantization Configuration
- Transform Configuration (for learned transforms)
- PTQ Algorithm Configuration (GPTQ-specific)
- Transform Optimization Configuration
- Evaluation Configuration
- Runtime and Hardware Configuration
- Logging and Debugging Configuration
"""
from __future__ import annotations

import argparse
from typing import Optional

import constants
from enums import QuantFormat, PTQAlg, ObserverType, LossFunction, DistanceMetric
from evaluate.benchmark_constants import WIKITEXT2, ALL_LM_EVAL_TASKS
from models import setup_device, resolve_dtype
from quantization.quant_args import QuantizationGranularity
from quantization.transforms.transforms import TRANSFORMS
from quantization.quant_config import QuantConfig
from run_config import RunConfig
from transform_optimization.opt_config import OptimizationConfig


# Transform class names that have trainable parameters and therefore need an
# OptimizationConfig.
_LEARNED_TRANSFORMS = ("learned", "learned_affine")


def _parse_bool(value: str) -> bool:
    """Convert a textual CLI boolean (``true``/``false``, ``1``/``0``, ...) to ``bool``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``,
    so every non-empty string would silently be treated as true.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is not a recognised boolean spelling (argparse turns this
        into a usage error).
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def _uses_learned_transform(*transform_classes: str) -> bool:
    """Return True if any of the given transform class names is a learned transform."""
    return any(name in _LEARNED_TRANSFORMS for name in transform_classes)


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser holding every CLI option, grouped by section."""
    parser = argparse.ArgumentParser(
        description="LATMiX: LLM MX quantization with learnable transform",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    quant_format_choices = [f.value for f in QuantFormat]
    granularity_choices = [g.value for g in QuantizationGranularity]
    transform_choices = list(TRANSFORMS.keys())

    # ========================================================================
    # Model and Dataset Configuration
    # ========================================================================
    parser.add_argument(
        "--model_name", "-m",
        type=str,
        required=True,
        help="Model identifier from HuggingFace Hub (e.g. 'meta-llama/Llama-2-7b-hf').",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        help="Path to load model from local directory (used with --use_model_path).",
    )
    parser.add_argument(
        "--use_model_path",
        action="store_true",
        help="Load model from local path specified in --model_path instead of HuggingFace Hub.",
    )
    parser.add_argument(
        "--calibration_dataset",
        type=str,
        required=True,
        help="Dataset name/path for calibration statistics collection.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=None,
        help="Path to load dataset from local directory (used with --use_dataset_path).",
    )
    parser.add_argument(
        "--use_dataset_path",
        action="store_true",
        help="Load dataset from local path specified in --dataset_path instead of HuggingFace Hub.",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=256,
        help="Number of calibration samples to use.",
    )
    parser.add_argument(
        "--shuffle_calibration",
        action="store_true",
        help="Shuffle calibration dataset before sampling.",
    )
    parser.add_argument(
        "--calib_sequence_length",
        type=int,
        default=1024,
        help="Sequence length for calibration samples.",
    )

    # ========================================================================
    # Data Type and Format Configuration
    # ========================================================================
    parser.add_argument(
        "--base_dtype",
        type=str,
        choices=constants.DTYPE_ALL_CHOICES,
        default=constants.DTYPE_BF16,
        help="Base data type for model execution.",
    )
    parser.add_argument(
        "--weight_quant_format", "-w",
        type=str,
        choices=quant_format_choices,
        default=QuantFormat.MXFP4.value,
        help="Weight quantization format.",
    )
    parser.add_argument(
        "--act_quant_format", "-a",
        type=str,
        choices=quant_format_choices,
        default=QuantFormat.MXFP4.value,
        help="Activation quantization format.",
    )
    parser.add_argument(
        "--w_bits",
        type=int,
        default=4,
        help="Weight quantization bitwidth (for non-MX formats).",
    )
    parser.add_argument(
        "--a_bits",
        type=int,
        default=4,
        help="Activation quantization bitwidth (for non-MX formats).",
    )

    # ========================================================================
    # Quantization Configuration
    # ========================================================================
    parser.add_argument(
        "--ptq_alg",
        type=str,
        choices=list(constants.PTQ_ALGO_CHOICES),
        default=constants.PTQ_RTN,
        help="Post-training quantization algorithm.",
    )
    parser.add_argument(
        "--weight_granularity",
        type=str,
        choices=granularity_choices,
        default=QuantizationGranularity.GROUP.value,
        help="Weight quantization granularity.",
    )
    parser.add_argument(
        "--activation_granularity",
        type=str,
        choices=granularity_choices,
        default=QuantizationGranularity.GROUP.value,
        help="Activation quantization granularity.",
    )
    parser.add_argument(
        # The underscore spelling matches every other option; the original
        # hyphenated spelling stays accepted. Both map to dest "mx_block_size".
        "--mx-block-size", "--mx_block_size",
        type=int,
        default=32,
        help="Block size for MX formats (grouping granularity).",
    )
    parser.add_argument(
        "--symmetric",
        type=_parse_bool,
        default=True,
        help="Use symmetric quantization (accepts true/false, 1/0, yes/no).",
    )
    parser.add_argument(
        "--disable_activation_calibration",
        action="store_false",
        dest="calibrate_activations",
        help="Skip activation calibration step.",
    )
    parser.add_argument(
        "--observer_type",
        type=str,
        default=ObserverType.MINMAX.value,
        choices=[o.value for o in ObserverType],
        help="Observer type for activation quantization calibration.",
    )

    # ========================================================================
    # Transform Configuration
    # ========================================================================
    parser.add_argument(
        "--transform_class",
        type=str,
        default="identity",
        choices=transform_choices,
        help="Transform class to apply to all layers (overridden by r1/r2 if specified).",
    )
    parser.add_argument(
        "--transform_class_r1",
        type=str,
        default="identity",
        choices=transform_choices,
        help="Transform class for the first layer/block (R1).",
    )
    parser.add_argument(
        "--transform_class_r2",
        type=str,
        default="identity",
        choices=transform_choices,
        help="Transform class for remaining layers/blocks (R2).",
    )
    parser.add_argument(
        "--matrix_init",
        type=str,
        default="orthogonal",
        # "id" is a shorthand that parse_args normalizes to "identity".
        choices=["orthogonal", "identity", "id", "hadamard", "random"],
        help="Initialization method for learned transform matrices ('id' is an alias for 'identity').",
    )
    parser.add_argument(
        "--mat_param",
        type=str,
        default="learnable_inv",
        choices=["learnable_inv", "learnable_qr", "learnable_kronecker"],
        help="Matrix parametrization method for learned transforms.",
    )
    parser.add_argument(
        "--hadamard_group_size",
        type=int,
        default=128,
        help="Group size for Hadamard transform.",
    )
    parser.add_argument(
        "--block_transform_matrix",
        action="store_true",
        help="Use block-wise transform matrices instead of a single matrix shared across all layers.",
    )
    parser.add_argument(
        "--disable_block_diag_init",
        action="store_true",
        help="Disable block-diagonal initialization of transform matrices.",
    )

    # ========================================================================
    # PTQ Algorithm Configuration (GPTQ-specific)
    # ========================================================================
    parser.add_argument(
        "--quantization_order",
        type=str,
        default="default",
        choices=["default", "activation"],
        help="Weight quantization order in GPTQ algorithm.",
    )
    parser.add_argument(
        "--rel_damp",
        type=float,
        default=1e-2,
        help="Relative dampening factor for GPTQ Hessian matrix.",
    )
    parser.add_argument(
        "--amp",
        action="store_true",
        help="Whether to enable fp16 autocasting.",
    )

    # ========================================================================
    # Transform Optimization Configuration
    # ========================================================================
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Learning rate for transform optimization.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.05,
        help="Weight decay for AdamW optimizer during transform optimization.",
    )
    parser.add_argument(
        "--betas",
        nargs=2,
        type=float,
        default=[0.90, 0.98],
        help="Beta1 and beta2 parameters for AdamW optimizer.",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=1000,
        help="Maximum number of optimization steps for transform training.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        choices=["cosine", "linear", "none"],
        default="cosine",
        help="Learning rate scheduler type.",
    )
    parser.add_argument(
        "--warmup_iters",
        type=int,
        default=10,
        help="Percentage of warmup iterations for learning rate scheduler.",
    )
    parser.add_argument(
        "--warmup_start_factor",
        type=float,
        default=0.1,
        help="Starting learning rate factor for warmup (lr_start = lr * warmup_start_factor).",
    )
    parser.add_argument(
        "--loss_function",
        type=str,
        choices=[e.value for e in LossFunction],
        default=LossFunction.OUTPUT_DISTILLATION.value,
        help="Loss function for transform optimization.",
    )
    parser.add_argument(
        "--distance_metric",
        type=str,
        choices=[e.value for e in DistanceMetric],
        default=DistanceMetric.KL.value,
        help="Distance metric for loss computation.",
    )
    parser.add_argument(
        "--loss_t",
        type=float,
        default=1.0,
        help="Temperature parameter for KL divergence loss.",
    )
    parser.add_argument(
        "--reg_lambda",
        type=float,
        default=1e-2,
        help="Regularization coefficient (lambda).",
    )
    parser.add_argument(
        "--disable_rand_noise",
        action="store_true",
        help="Disable random noise injection during transform optimization.",
    )

    # ========================================================================
    # Evaluation Configuration
    # ========================================================================
    parser.add_argument(
        "--eval_tasks",
        nargs="+",
        type=str,
        choices=ALL_LM_EVAL_TASKS,
        default=[WIKITEXT2],
        help="Evaluation tasks to run (e.g. 'wikitext2', 'mmlu').",
    )
    parser.add_argument(
        "--num_fewshot",
        type=int,
        default=0,
        help="Number of few-shot examples for evaluation tasks.",
    )
    parser.add_argument(
        "--run_float_eval",
        action="store_true",
        help="Run evaluation on the unquantized (float) model.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for evaluation.",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples for evaluation (None = all).",
    )
    parser.add_argument(
        "--disable_thinking",
        action="store_true",
        help="Disable thinking mode for Qwen3 models.",
    )

    # ========================================================================
    # Runtime and Hardware Configuration
    # ========================================================================
    parser.add_argument(
        "--device", "-d",
        type=str,
        default=constants.DEFAULT_DEVICE,
        help="Execution device ('cuda', 'cpu', or 'auto').",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of dataloader workers.",
    )

    return parser


def parse_args(argv: Optional[list[str]] = None) -> tuple[RunConfig, QuantConfig, Optional[OptimizationConfig]]:
    """Parse CLI arguments into RunConfig, QuantConfig, and OptimizationConfig instances.

    Parameters
    ----------
    argv : Optional[list[str]]
        Optional explicit argument list (mainly for testing). If omitted,
        argparse reads from sys.argv.

    Returns
    -------
    tuple[RunConfig, QuantConfig, Optional[OptimizationConfig]]
        A tuple containing:
        - RunConfig: Runtime configuration (model, dataset, device, etc.)
        - QuantConfig: Quantization configuration (formats, algorithms, etc.)
        - OptimizationConfig or None: Transform optimization config, built only
          when --transform_class, --transform_class_r1 or --transform_class_r2
          selects a learned transform ("learned" or "learned_affine").

    Raises
    ------
    SystemExit
        Raised by argparse on invalid arguments or when --help is given.
    """
    args = _build_parser().parse_args(argv)

    # "id" is accepted on the command line as a shorthand for "identity".
    matrix_init = "identity" if args.matrix_init == "id" else args.matrix_init

    quant_config = QuantConfig(
        weight_q_format=QuantFormat(args.weight_quant_format),
        activation_q_format=QuantFormat(args.act_quant_format),
        weight_granularity=QuantizationGranularity(args.weight_granularity),
        activation_granularity=QuantizationGranularity(args.activation_granularity),
        ptq_alg=PTQAlg(args.ptq_alg),
        symmetric=args.symmetric,
        group_size=args.mx_block_size,
        calibrate_activations=args.calibrate_activations,
        observer_type=args.observer_type,
        hadamard_group_size=args.hadamard_group_size,
        transform_class=args.transform_class,
        transform_class_r1=args.transform_class_r1,
        transform_class_r2=args.transform_class_r2,
        matrix_init=matrix_init,
        w_bits=args.w_bits,
        a_bits=args.a_bits,
        quantization_order=args.quantization_order,
        rel_damp=args.rel_damp,
        amp=args.amp,
    )

    run_config = RunConfig(
        model_name=args.model_name,
        model_path=args.model_path,
        use_model_path=args.use_model_path,
        calibration_dataset=args.calibration_dataset,
        dataset_path=args.dataset_path,
        use_dataset_path=args.use_dataset_path,
        seed=args.seed,
        n_samples=args.n_samples,
        shuffle_calibration=args.shuffle_calibration,
        device=setup_device(args.device),
        base_dtype=resolve_dtype(args.base_dtype),
        run_float_eval=args.run_float_eval,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        max_samples=args.max_samples,
        calib_sequence_length=args.calib_sequence_length,
        eval_tasks=args.eval_tasks,
        num_fewshot=args.num_fewshot,
        disable_thinking=args.disable_thinking,
    )

    # Optimization settings only matter when at least one transform is learned.
    opt_config = None
    if _uses_learned_transform(args.transform_class, args.transform_class_r1, args.transform_class_r2):
        opt_config = OptimizationConfig(
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            betas=tuple(args.betas),
            max_steps=args.max_steps,
            lr_scheduler=args.lr_scheduler if args.lr_scheduler != "none" else None,
            warmup_iters=args.warmup_iters,
            warmup_start_factor=args.warmup_start_factor,
            loss_function=args.loss_function,
            distance_metric=args.distance_metric,
            reg_lambda=args.reg_lambda,
            single_transform_matrix=not args.block_transform_matrix,
            block_diag_init=not args.disable_block_diag_init,
            add_rand_noise=not args.disable_rand_noise,
            temperature=args.loss_t,
            mat_param=args.mat_param,
        )

    return run_config, quant_config, opt_config


# Characters between the left and right box borders, and the width reserved
# for the label column (including its trailing padding).
_BOX_INNER_WIDTH = 63
_BOX_LABEL_WIDTH = 22


def _print_box(title: str, sections: list[list[tuple[str, object]]]) -> None:
    """Print groups of ``(label, value)`` rows inside a titled box-drawing frame.

    Every line is padded to the same inner width as the borders, so the
    right-hand edge lines up. A value too long for its column widens only its
    own row instead of being truncated.

    Args:
        title: Heading centered in the first row of the box.
        sections: Groups of rows; a horizontal divider is drawn before each group.
    """
    # One space of padding on each side of the label/value content.
    value_width = _BOX_INNER_WIDTH - _BOX_LABEL_WIDTH - 2
    rule = "═" * _BOX_INNER_WIDTH

    print(f"\n╔{rule}╗")
    print(f"║{title.center(_BOX_INNER_WIDTH)}║")
    for rows in sections:
        print(f"╠{rule}╣")
        for label, value in rows:
            # format() with an empty spec never raises (even for None), and
            # unlike a width spec it renders bools as True/False, not 1/0.
            text = format(value, "")
            print(f"║ {label:<{_BOX_LABEL_WIDTH}}{text:<{value_width}} ║")
    print(f"╚{rule}╝\n")


def print_configs(run_config: RunConfig, quant_config: QuantConfig, opt_config: Optional[OptimizationConfig] = None) -> None:
    """Print all configurations in an elegant, readable format.

    The run configuration is printed via its own ``str()``. The quantization
    configuration and, when given, the optimization configuration are each
    printed as a box of labelled rows.

    Args:
        run_config: Runtime configuration
        quant_config: Quantization configuration
        opt_config: Optional optimization configuration (for learned transforms)
    """
    print("\n" + str(run_config))

    _print_box(
        "QUANTIZATION CONFIGURATION",
        [
            [
                ("Weight Format:", quant_config.weight_q_format.value),
                ("Activation Format:", quant_config.activation_q_format.value),
                ("PTQ Algorithm:", quant_config.ptq_alg.value),
            ],
            [
                ("Weight Granularity:", quant_config.weight_granularity.value),
                ("Activation Gran.:", quant_config.activation_granularity.value),
                ("Group Size:", quant_config.group_size),
                ("Weight Bits:", quant_config.w_bits),
                ("Activation Bits:", quant_config.a_bits),
            ],
            [
                ("Transform Class:", quant_config.transform_class),
                ("Transform Class R1:", quant_config.transform_class_r1),
                ("Transform Class R2:", quant_config.transform_class_r2),
                ("Matrix Init:", quant_config.matrix_init),
                ("Hadamard Group:", quant_config.hadamard_group_size),
            ],
            [
                ("Symmetric:", quant_config.symmetric),
                ("Calibrate Acts:", quant_config.calibrate_activations),
                ("Observer Type:", quant_config.observer_type),
            ],
            [
                ("GPTQ Quant Order:", quant_config.quantization_order),
                ("GPTQ Rel Damp:", quant_config.rel_damp),
            ],
        ],
    )

    if opt_config is None:
        return

    _print_box(
        "OPTIMIZATION CONFIGURATION",
        [
            [
                ("Learning Rate:", opt_config.learning_rate),
                ("Weight Decay:", opt_config.weight_decay),
                ("Betas:", opt_config.betas),
                ("Max Steps:", opt_config.max_steps),
                # A disabled scheduler (None / empty) is shown as "None".
                ("LR Scheduler:", opt_config.lr_scheduler or "None"),
                ("Warmup Iterations:", opt_config.warmup_iters),
                ("Warmup Start Factor:", opt_config.warmup_start_factor),
            ],
            [
                ("Loss Function:", opt_config.loss_function),
                ("Distance Metric:", opt_config.distance_metric),
                ("Temperature:", opt_config.temperature),
                ("Reg Lambda:", opt_config.reg_lambda),
            ],
            [
                ("Matrix Param:", opt_config.mat_param),
                ("Single Transform Mtx:", opt_config.single_transform_matrix),
                ("Block Diag Init:", opt_config.block_diag_init),
                ("Add Random Noise:", opt_config.add_rand_noise),
            ],
        ],
    )


# Public API: the CLI entry points plus the config types that parse_args
# returns (re-exported so callers can import them from this module).
__all__ = [
    "OptimizationConfig",
    "QuantConfig",
    "RunConfig",
    "parse_args",
    "print_configs",
]


if __name__ == "__main__":  # pragma: no cover - manual invocation helper
    run_cfg, quant_cfg, opt_cfg = parse_args()
    print_configs(run_cfg, quant_cfg, opt_cfg)