from dataclasses import dataclass


@dataclass
class TradTrainingArgs:
    """
    Traditional (gradient-descent-based) training arguments.

    Args:
        n_steps         Number of 'gradient-descent' steps in training.
        batch_size
        device          either 'cuda' or 'cpu'
        weight_init     Weight initialization method, can be 'kaiming_normal', 'xavier_normal' or uniform versions.
        lr_start        Learning rate at the start of training.
        lr_end          Learning rate at the end of training.
        weight_decay    L2 regularization
        patience        For early-stopping
        optim_type      'adam', 'sgd'..
        sched_type      'linear' or 'exponential' to have basic learning rate scheduling. Set it to None if not desired.
    """
    n_steps: int = 10_000
    batch_size: int | None = 256
    device: str = "cpu"
    weight_init: str = "kaiming_normal"

    lr_start: float = 5e-3
    lr_end: float = 1e-4
    weight_decay: float = 1e-06
    patience: int = 500
    optim_type: str = "adam"
    sched_type: str | None = "exponential"

    def __post_init__(self):
        assert self.device in {'cpu', 'cuda'}, "Unknown device"
        assert self.lr_end <= self.lr_start
