# Based on implementations from the 4M repo: https://github.com/apple/ml-4m/
# --------------------------------------------------------
# Based on timm code base
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------

import torch

class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self, enabled=True):
        self._scaler = torch.amp.GradScaler(enabled=enabled)

    def __call__(self, loss, optimizer, clip_grad=None, skip_grad=None, parameters=None, create_graph=False, update_grad=True, compute_grad_norm=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            elif skip_grad is not None:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
                if norm >= skip_grad:
                    self._scaler.update()
                    return norm
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters) if compute_grad_norm else None
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the overall ``norm_type``-norm of the gradients of ``parameters``.

    Args:
        parameters: A single tensor or an iterable of tensors. Entries whose
            ``.grad`` is None are ignored.
        norm_type: Order of the norm; converted with ``float()``.

    Returns:
        A scalar tensor: the norm of the per-parameter gradient norms, placed
        on the device of the first gradient found, or ``torch.tensor(0.)`` when
        no parameter has a gradient.
    """
    candidates = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    grads = [param.grad for param in candidates if param.grad is not None]
    order = float(norm_type)
    if not grads:
        return torch.tensor(0.)

    # Gather every per-parameter norm on one device so they can be stacked.
    target_device = grads[0].device
    per_param_norms = []
    for grad in grads:
        per_param_norms.append(torch.norm(grad.detach(), order).to(target_device))
    return torch.norm(torch.stack(per_param_norms), order)