from dataclasses import dataclass
from typing import Union

from omegaconf import MISSING


@dataclass
class OptimizerConfig:
    """Base structured config shared by all optimizer configs in this module.

    ``name`` and ``lr`` default to OmegaConf's ``MISSING`` marker, i.e. they
    have no usable value at this level; the concrete subclasses below
    override both with fixed defaults.
    """

    # Optimizer identifier string; each subclass pins its own value.
    name: str = MISSING
    # Config group label attached to every optimizer node.
    group: str = "optimizer"
    # Learning rate; intentionally left unset on the base class.
    lr: float = MISSING
    # Optional gradient-clipping value. NOTE(review): None presumably means
    # "no clipping" — confirm against the code that consumes this config.
    grad_clip: Union[float, None] = None
    # Scheduler name, or None. NOTE(review): "cos" is assumed to select a
    # cosine schedule — verify with the scheduler factory.
    scheduler: Union[str, None] = "cos"


@dataclass
class AdamW(OptimizerConfig):
    """Default hyperparameters for the optimizer identified as ``"adamw"``.

    Inherits ``group``, ``grad_clip`` and ``scheduler`` from the base config
    and adds a ``weight_decay`` setting.
    """

    name: str = "adamw"
    lr: float = 1e-3  # same value as 0.001
    weight_decay: float = 1e-3  # same value as 0.001


@dataclass
class SGD(OptimizerConfig):
    """Default hyperparameters for the optimizer identified as ``"sgd"``.

    Extends the base config with ``weight_decay`` and ``momentum``.
    """

    name: str = "sgd"
    lr: float = 1e-2  # same value as 0.01
    weight_decay: float = 5e-5  # same value as 0.00005
    momentum: float = 0.9


@dataclass
class Adam(OptimizerConfig):
    """Default hyperparameters for the optimizer identified as ``"adam"``.

    Weight decay is off by default (0.0).
    """

    name: str = "adam"
    lr: float = 1e-4  # same value as 0.0001
    weight_decay: float = 0.0
