import torch.optim as optim


class NewAdam(optim.Adam):
    """Adam optimizer whose moment decay rates are given as complements.

    Behaves exactly like ``optim.Adam``, except that the two decay
    coefficients are passed as their complements. ``new_beta1`` and
    ``new_beta2`` are converted to ``betas=(1 - new_beta1, 1 - new_beta2)``
    before the parent constructor is called. With the defaults this gives
    ``betas=(0.9, 0.999)``.

    Args:
        params: Parameters (or parameter-group dicts) to optimize; forwarded
            unchanged to ``optim.Adam``.
        lr: Learning rate.
        new_beta1: ``1 - beta1``, the complement of the first-moment decay.
        new_beta2: ``1 - beta2``, the complement of the second-moment decay.
        eps: Numerical-stability term, forwarded to ``optim.Adam``.
        weight_decay: Weight decay coefficient, forwarded to ``optim.Adam``.
        amsgrad: Whether to enable the AMSGrad variant.
    """

    def __init__(
        self,
        params,
        lr=0.001,
        new_beta1=0.1,
        new_beta2=0.001,
        eps=1e-08,
        weight_decay=0,
        amsgrad=False,
    ):
        # Map the complement form back onto the conventional Adam betas;
        # range validation is left to optim.Adam itself.
        first_moment_decay = 1 - new_beta1
        second_moment_decay = 1 - new_beta2
        adam_kwargs = dict(
            lr=lr,
            betas=(first_moment_decay, second_moment_decay),
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super().__init__(params, **adam_kwargs)
