from dataclasses import dataclass
from typing import Any, Iterable

from hydra.core.config_store import ConfigStore
from omegaconf import MISSING


@dataclass
class OptimizerConfig:
    """Base structured config shared by the optimizer configs in this module.

    Concrete subclasses fill in ``_target_`` with the optimizer class path and
    declare their own hyperparameters on top of ``params`` and ``lr``.
    """

    # Dotted path of the optimizer class to build (Hydra's ``_target_`` key);
    # left MISSING in the base, concrete subclasses override it.
    _target_: str = MISSING
    # Required input with no default -- presumably the parameters to optimize,
    # supplied when the optimizer is instantiated (confirm with caller).
    params: Any = MISSING
    # Learning rate.
    lr: float = 1e-3


@dataclass
class AdamConfig(OptimizerConfig):
    """Structured config for ``torch.optim.Adam``.

    The hyperparameter defaults match the ``torch.optim.Adam`` constructor
    defaults. ``params`` is left as ``MISSING`` and must be provided before
    the optimizer can be built.
    """

    _target_: str = "torch.optim.adam.Adam"
    # Required input with no default -- presumably the model parameters (or
    # param groups) supplied at instantiation time; confirm with the caller.
    params: Any = MISSING
    lr: float = 0.001
    # Coefficients for the running averages of the gradient and its square.
    betas: tuple[float, float] = (0.9, 0.999)
    # Added to the denominator for numerical stability.
    eps: float = 1e-08
    # Float literal so the default's type agrees with the ``float`` annotation.
    weight_decay: float = 0.0
    amsgrad: bool = False


@dataclass
class AdamWConfig(OptimizerConfig):
    """Structured config that targets ``torch.optim.AdamW``.

    Every field except ``params`` (left ``MISSING``) carries a default. The
    hyperparameter defaults match the ``torch.optim.AdamW`` constructor,
    including its weight decay of ``1e-2``.
    """

    _target_: str = "torch.optim.adamw.AdamW"
    # Required input with no default -- presumably the parameters to optimize,
    # passed in when the optimizer is instantiated (confirm with caller).
    params: Any = MISSING
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    amsgrad: bool = False


def register_optimizer_configs() -> None:
    """Store the optimizer structured configs in Hydra's global ConfigStore.

    Both configs go into the ``optimizer_lib/optimizer`` group, under the
    names ``Adam`` (AdamConfig) and ``AdamW`` (AdamWConfig), registered in
    that order.
    """
    config_store = ConfigStore.instance()
    entries = (
        ("Adam", AdamConfig),
        ("AdamW", AdamWConfig),
    )
    for entry_name, entry_node in entries:
        config_store.store(
            group="optimizer_lib/optimizer",
            name=entry_name,
            node=entry_node,
        )
