from dataclasses import dataclass
from typing import Optional, Union, List

import torch


@dataclass(slots=True)
class RunConfig:
    """Container for all configuration options derived from CLI arguments.

    Each field corresponds to a user-provided (or defaulted) CLI flag. Enums
    provide normalized representations for algorithm / format choices.
    """
    model_name: str
    model_path: Optional[str]
    use_model_path: bool
    calibration_dataset: str
    dataset_path: Optional[str]
    use_dataset_path: bool
    seed: int
    n_samples: int
    shuffle_calibration: bool
    device: torch.device
    base_dtype: torch.dtype
    run_float_eval: bool
    batch_size: int
    num_workers: int
    max_samples: int
    calib_sequence_length: int = 2048,
    eval_tasks: Union[str, List[str]] = "wikitext2"
    num_fewshot: int = 0
    disable_thinking: bool = False


    def __str__(self) -> str:
        """Pretty-print the run configuration."""
        lines = [
            "╔═══════════════════════════════════════════════════════════════╗",
            "║                      RUN CONFIGURATION                        ║",
            "╠═══════════════════════════════════════════════════════════════╣",
            f"║ Model:                {self.model_name:<40} ║",
            f"║ Model Path:           {self.model_path if self.model_path else 'None':<40} ║",
            f"║ Use Model Path:       {str(self.use_model_path):<40} ║",
            f"║ Device:               {str(self.device):<40} ║",
            f"║ Base dtype:           {str(self.base_dtype):<40} ║",
            "╠═══════════════════════════════════════════════════════════════╣",
            f"║ Calibration Dataset:  {self.calibration_dataset:<40} ║",
            f"║ Dataset Path:         {self.dataset_path if self.dataset_path else 'None':<40} ║",
            f"║ Use Dataset Path:     {str(self.use_dataset_path):<40} ║",
            f"║ Calibration Samples:  {self.n_samples:<40} ║",
            f"║ Shuffle Calibration:  {str(self.shuffle_calibration):<40} ║",
            f"║ Run Float Eval:       {str(self.run_float_eval):<40} ║",
            "╠═══════════════════════════════════════════════════════════════╣",
            f"║ Batch Size:           {self.batch_size:<40} ║",
            f"║ Max Samples:          {self.max_samples if self.max_samples else 'All':<40} ║",
            f"║ Num Workers:          {self.num_workers:<40} ║",
        ]

        lines.extend([
            "╠═══════════════════════════════════════════════════════════════╣",
            f"║ Seed:                 {self.seed:<40} ║",
            "╚═══════════════════════════════════════════════════════════════╝",
        ])

        return "\n".join(lines)