import torch


class GradScaler(torch.amp.GradScaler):

    def __init__(
        self, 
        enabled: bool = True,
        scale_init: float = 2.**16,
        scale_min: float = 1.,
        growth_interval: int = 2000
    ):
        super().__init__(enabled=enabled, device="cpu", init_scale=scale_init, growth_interval=growth_interval) # type: ignore
        self._enabled = enabled
        self.scale_min = scale_min

        if not self._enabled:
            # We write scale=1 to log if the scaler is disabled
            self._scale = torch.tensor((1,), dtype=torch.float32, device='cuda')


    def update(self):

        if not self._enabled:
            return

        super().update()

        if self._scale < self.scale_min:
            super().update(self.scale_min)


