import torch_optimizer as optim
import torch


class NewLookahead(optim.Lookahead):
    """Lookahead over an internally constructed Adam optimizer.

    Unlike ``optim.Lookahead``, which is handed an already-built optimizer,
    this class builds a ``torch.optim.Adam`` from ``params`` itself and then
    passes that instance to the Lookahead base class.

    Adam's coefficients are given through their complements:
    ``beta1 = 1 - new_beta1`` and ``beta2 = 1 - new_beta2``. With the default
    arguments this yields ``betas=(0.9, 0.999)``.

    No argument validation is done here; checking is left to the wrapped
    ``torch.optim.Adam`` and ``optim.Lookahead`` constructors.

    Args:
        params: Parameters (or parameter groups) passed to ``torch.optim.Adam``.
        lr: Learning rate passed to Adam.
        new_beta1: Complement of Adam's first beta (``beta1 = 1 - new_beta1``).
        new_beta2: Complement of Adam's second beta (``beta2 = 1 - new_beta2``).
        eps: ``eps`` passed to Adam.
        weight_decay: ``weight_decay`` passed to Adam.
        k: Forwarded unchanged to ``optim.Lookahead``.
        alpha: Forwarded unchanged to ``optim.Lookahead``.
    """

    def __init__(
        self,
        params,
        lr=0.001,
        new_beta1=0.1,
        new_beta2=0.001,
        eps=1e-08,
        weight_decay=0,
        k=5,
        alpha=0.5,
    ):
        # Turn the complement-style hyperparameters back into Adam's betas.
        adam_betas = (1 - new_beta1, 1 - new_beta2)
        inner_adam = torch.optim.Adam(
            params,
            lr=lr,
            betas=adam_betas,
            eps=eps,
            weight_decay=weight_decay,
        )
        # The Lookahead base class wraps the freshly built Adam instance.
        super().__init__(inner_adam, k=k, alpha=alpha)
