from .larc import LARC
import torch

class NewLARS(LARC):
    """LARS variant that builds its own inner SGD optimizer.

    The parent ``LARC`` is constructed from an existing optimizer instance.
    This subclass instead accepts model parameters and hyperparameters
    directly. It creates a ``torch.optim.SGD`` from them and hands that
    optimizer to ``LARC.__init__`` along with the LARC-specific settings.

    Args:
        params: Parameters (or parameter-group dicts), forwarded unchanged
            to ``torch.optim.SGD``.
        lr: Learning rate for the inner SGD optimizer.
        momentum: Momentum factor for the inner SGD optimizer.
        trust_coefficient: Forwarded to ``LARC`` as ``trust_coefficient``.
        clip: Forwarded to ``LARC`` as ``clip``.
        eps: Forwarded to ``LARC`` as ``eps``.
        weight_decay: Weight-decay factor for the inner SGD optimizer.
    """

    def __init__(
        self,
        params,
        lr=0.001,
        momentum=0.9,
        trust_coefficient=0.02,
        clip=True,
        eps=1e-08,
        weight_decay=0,
    ):
        # Hyperparameters consumed by the wrapped SGD optimizer.
        sgd_kwargs = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        base_optimizer = torch.optim.SGD(params, **sgd_kwargs)

        # Settings consumed by the LARC wrapper itself.
        larc_kwargs = dict(
            trust_coefficient=trust_coefficient,
            clip=clip,
            eps=eps,
        )
        super().__init__(base_optimizer, **larc_kwargs)