from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class OptimizationConfig:
    """Hyperparameters and switches for transform matrix optimization training.

    Every field carries a default, so ``OptimizationConfig()`` is a valid
    configuration on its own. The declaration order below is also the
    positional-argument order of the generated ``__init__``, so fields must
    not be rearranged.
    """

    # --- Optimizer -------------------------------------------------------
    # Base learning rate.
    # NOTE(review): a comment that used to sit here described optional
    # per-transform learning rates falling back to `learning_rate`, but no
    # such field is declared in this class -- confirm whether one is missing.
    learning_rate: float = 0.001
    # Presumably the optimizer's weight-decay coefficient -- verify against
    # the training loop.
    weight_decay: float = 0.0
    # Presumably Adam-style (beta1, beta2) coefficients -- TODO confirm which
    # optimizer consumes them.
    betas: Tuple[float, float] = (0.9, 0.999)
    # Presumably the total number of optimization steps -- verify against
    # the training loop.
    max_steps: int = 1_000

    # --- Learning-rate schedule ------------------------------------------
    # Scheduler name. Options: "cosine", "linear", or None.
    lr_scheduler: Optional[str] = None
    # Presumably the number of warmup iterations (0 looks like "no warmup")
    # -- TODO confirm.
    warmup_iters: int = 0
    # Presumably the learning-rate factor warmup starts from -- TODO confirm.
    warmup_start_factor: float = 0.1

    # --- Loss ------------------------------------------------------------
    # Loss function for optimization.
    loss_function: str = "output_distillation"
    # Distance metric for loss computation.
    distance_metric: str = "kl"
    # Regularization lambda coefficient.
    reg_lambda: float = 0.0001
    # Temperature for distillation loss. Declared further down (after
    # `add_rand_noise`) to keep the positional order intact.

    # --- Transform matrix setup ------------------------------------------
    # NOTE(review): name suggests one matrix shared across transforms;
    # semantics are not visible here -- confirm with the consumer.
    single_transform_matrix: bool = False
    # Whether to use block diagonal initialization.
    block_diag_init: bool = False
    # NOTE(review): presumably adds random noise at initialization -- confirm
    # where and how the noise is applied.
    add_rand_noise: bool = False
    # Temperature for distillation loss.
    temperature: float = 1.0
    # Matrix parametrization method for learned transforms.
    mat_param: str = "learnable_inv"
