import torch
from typeguard import check_argument_types


class SGD(torch.optim.SGD):
    """torch.optim.SGD with a default value bound to every hyperparameter.

    Per the original author's note, optimizers invoked by AbsTask.main()
    must provide a default for every argument except 'params', and
    torch.optim.SGD's 'lr' was the one argument without a default.
    This subclass only overrides the constructor to fill that gap
    (lr defaults to 0.1); everything else is inherited from torch.optim.SGD.

    Args:
        params: Forwarded positionally, unchanged, to torch.optim.SGD.
        lr: Forwarded as the 'lr' keyword; defaults to 0.1 here.
        momentum: Forwarded as the 'momentum' keyword.
        dampening: Forwarded as the 'dampening' keyword.
        weight_decay: Forwarded as the 'weight_decay' keyword.
        nesterov: Forwarded as the 'nesterov' keyword.

    Note:
        The annotated arguments are validated with typeguard's
        check_argument_types() inside an ``assert``, so the validation
        does not run when Python is started with -O.
    """

    def __init__(
        self,
        params,
        lr: float = 0.1,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False,
    ):
        # Keep this as the first statement, before any extra locals exist;
        # presumably typeguard inspects this frame's arguments -- confirm
        # against the installed typeguard version before moving it.
        assert check_argument_types()

        # Gather the hyperparameters once and hand them to the parent
        # constructor by keyword.
        hyperparams = {
            "lr": lr,
            "momentum": momentum,
            "dampening": dampening,
            "weight_decay": weight_decay,
            "nesterov": nesterov,
        }
        super().__init__(params, **hyperparams)
