import torch_optimizer as optim


class NewRAdam(optim.RAdam):
    """RAdam whose momentum coefficients are given as their complements.

    Instead of passing ``betas=(beta1, beta2)`` directly, callers supply
    ``new_beta1`` and ``new_beta2``. The constructor forwards
    ``betas=(1 - new_beta1, 1 - new_beta2)`` to ``optim.RAdam``. The defaults
    ``new_beta1=0.1`` and ``new_beta2=0.001`` therefore correspond to
    ``betas=(0.9, 0.999)``.

    Args:
        params: Parameters (or parameter groups) to optimise. Passed through
            unchanged to the parent constructor.
        lr: Learning rate, forwarded as-is.
        new_beta1: Complement of the first beta coefficient (``1 - beta1``).
        new_beta2: Complement of the second beta coefficient (``1 - beta2``).
        eps: Numerical-stability term, forwarded as-is.
        weight_decay: Weight-decay factor, forwarded as-is.

    Any validation of the resulting betas (or of ``lr``/``eps``/
    ``weight_decay``) is left to ``optim.RAdam``. This class performs no
    checks of its own.
    """

    def __init__(
        self,
        params,
        lr=0.001,
        new_beta1=0.1,
        new_beta2=0.001,
        eps=1e-08,
        weight_decay=0,
    ):
        # Convert the complement form back into the conventional betas tuple.
        first_beta = 1 - new_beta1
        second_beta = 1 - new_beta2
        forwarded = dict(
            lr=lr,
            betas=(first_beta, second_beta),
            eps=eps,
            weight_decay=weight_decay,
        )
        super().__init__(params, **forwarded)
