

from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class OptimizerConfig:
    """Configuration container for optimizer, mixed-precision and grad settings.

    This dataclass only stores values; it performs no validation and has no
    behavior of its own. How each field is interpreted is decided by the code
    that consumes the config, which is not part of this module. The per-field
    notes below describe the intent suggested by the field names and should be
    confirmed against that consumer.

    Field order is part of the public interface: the generated ``__init__``
    accepts these fields positionally in the order they are declared.
    """

    # --- Optimizer choice and learning-rate settings ----------------------

    # Name of the optimizer to build; default is 'adam'. The accepted names
    # are decided by the consumer (presumably 'adam' / 'sgd', given the
    # adam_* and sgd_* fields below — confirm).
    optimizer: str = 'adam'

    # Initial learning rate. None means "not set here"; the consumer must
    # supply or derive a value.
    lr: Optional[float] = None

    # Lower bound for the learning rate (presumably used by a LR scheduler to
    # clip decayed values — confirm). None means "not set".
    min_lr: Optional[float] = None

    # Separate learning rate for a subset of parameters that is decoupled from
    # `lr` (which parameters is decided by the consumer). None means "not set".
    decoupled_lr: Optional[float] = None

    # Lower bound counterpart of `decoupled_lr`. None means "not set".
    decoupled_min_lr: Optional[float] = None

    # Weight-decay coefficient.
    weight_decay: float = 0.01

    # --- Precision ---------------------------------------------------------

    # Flags requesting fp16 / bf16 mixed precision. Both default to False.
    # Nothing here prevents both being True; presumably the consumer treats
    # them as mutually exclusive — verify there.
    fp16: bool = False
    bf16: bool = False

    # dtype for parameters; default torch.float32.
    params_dtype: torch.dtype = torch.float32

    # --- Loss scaling (relevant to reduced-precision training) --------------

    # Fixed loss scale. None presumably selects dynamic loss scaling driven by
    # the fields below — confirm in the consumer.
    loss_scale: Optional[float] = None

    # Starting value for dynamic loss scaling (2**32).
    initial_loss_scale: float = 2**32

    # Floor for the dynamic loss scale.
    min_loss_scale: float = 1.0

    # Window (in steps, presumably) over which the dynamic scale is adjusted.
    # Annotated float, but the default is the int 1000; kept as-is so the
    # stored value's type does not change for existing consumers.
    loss_scale_window: float = 1000

    # Hysteresis count for dynamic loss scaling.
    hysteresis: int = 2

    # --- Adam hyper-parameters ---------------------------------------------

    # Adam beta coefficients and epsilon (numerical-stability term).
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_eps: float = 1e-08

    # --- SGD hyper-parameters ----------------------------------------------

    # Momentum factor for SGD.
    sgd_momentum: float = 0.9

    # --- Distributed-optimizer settings ------------------------------------

    # Request a distributed (sharded) optimizer implementation.
    use_distributed_optimizer: bool = False

    # Request overlapping gradient reduction / parameter gathering with
    # compute. Whether these apply without `use_distributed_optimizer` is up
    # to the consumer.
    overlap_grad_reduce: bool = False
    overlap_param_gather: bool = False

    # --- Gradient handling, logging and instrumentation --------------------

    # Gradient-clipping threshold (presumably a global-norm clip — confirm).
    clip_grad: float = 1.0

    # Request logging of the number of zero entries in gradients.
    log_num_zeros_in_grad: bool = False

    # Request a barrier around level-1 timer measurements (name kept
    # verbatim; "L1" refers to a timer level, not an L1 norm or cache).
    barrier_with_L1_time: bool = False

    # Optional callable used to obtain timers; None disables/omits it. The
    # annotation is Optional[...] because the default is None (PEP 484 no
    # longer permits an implicit Optional).
    timers: Optional[Callable] = None

