# -*- coding: UTF-8 -*-

import torch
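# NOTE: `inf` used to be imported via `from torch._six import inf`; that private module is gone in newer PyTorch releases, so import it from `torch` directly.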
from torch import inf


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """
    Return the total norm of the gradients attached to ``parameters``.

    Parameters without a gradient (``p.grad is None``) are ignored. The result
    is the ``norm_type``-norm of the per-parameter gradient norms, which equals
    the norm of all gradients viewed as one flat vector.

    Args:
        parameters (torch.Tensor or Iterable[torch.Tensor]): A single tensor or
            an iterable of tensors whose ``.grad`` fields are measured.
        norm_type (float, optional): Order of the norm; ``inf`` selects the
            largest absolute gradient entry. Defaults to 2.0.

    Returns:
        torch.Tensor: A scalar tensor holding the total norm, placed on the
        device of the first gradient. A CPU ``tensor(0.)`` is returned when no
        parameter carries a gradient.
    """

    # Accept a bare tensor as a one-element collection.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    grads = [p.grad for p in parameters if p.grad is not None]
    norm_type = float(norm_type)

    if not grads:
        return torch.tensor(0.)

    # Every partial result is moved here so they can be combined.
    target_device = grads[0].device

    if norm_type == inf:
        # Infinity norm: the largest absolute entry over all gradients.
        return max(g.detach().abs().max().to(target_device) for g in grads)

    per_param_norms = [torch.norm(g.detach(), norm_type).to(target_device) for g in grads]
    return torch.norm(torch.stack(per_param_norms), norm_type)


class NativeScalerWithGradNormCount:
    """
    Training-step helper that bundles backward, optional AMP loss scaling,
    optional gradient clipping, the optimizer step and gradient reset.

    When ``amp`` is True the loss is scaled with ``torch.cuda.amp.GradScaler``
    before the backward pass, and the scaler handles unscaling, stepping and
    updating. When ``amp`` is False the loss is back-propagated and the
    optimizer is stepped directly.

    The class provides the following methods:
    - `__call__(loss, optimizer=None, clip_grad=None, parameters=None, update_grad=True, backward_kwargs=None)`
      - Performs the backward pass, gradient unscaling/clipping, and optimizer update.
      - Returns the gradient norm when it was computed (see `__call__`).
    - `state_dict()` and `load_state_dict(state_dict)`
      - Saves and loads the state of the internal gradient scaler.
    """

    state_dict_key = "amp_scaler"

    def __init__(self,
                 optimizer=None,
                 amp=False,
                 clip_grad=None):
        """
        Stores the default optimizer, AMP flag and clipping threshold, and creates the internal gradient scaler.

        Args:
            optimizer (torch.optim.Optimizer, optional): Default optimizer used by `__call__` when none is passed there.
            amp (bool, optional): Whether `__call__` routes the loss and optimizer step through the gradient scaler. Defaults to False.
            clip_grad (float, optional): Default maximum gradient norm used by `__call__` when none is passed there. None disables clipping.
        """

        # The scaler is created regardless of `amp` so that state_dict() /
        # load_state_dict() always have an object to delegate to.
        self._scaler = torch.cuda.amp.GradScaler()
        self.clip_grad = clip_grad
        self.optimizer = optimizer
        self.amp = amp

    def __call__(self, loss, optimizer=None, clip_grad=None, parameters=None, update_grad=True, backward_kwargs=None):
        """
        Back-propagates the loss and, if requested, clips gradients, steps the optimizer and resets the gradients.

        Args:
            loss (torch.Tensor): The loss to backpropagate.
            optimizer (torch.optim.Optimizer, optional): The optimizer to use. Defaults to the optimizer set in the constructor.
            clip_grad (float, optional): Maximum gradient norm; gradients are clipped to it. Defaults to the clip_grad set in the constructor.
            parameters (Iterable[torch.nn.Parameter], optional): Parameters whose gradients are clipped / measured. Required when clipping is active and update_grad is True.
            update_grad (bool, optional): Whether to run the clip / step / zero_grad phase. When False only the backward pass runs (e.g. for gradient accumulation). Defaults to True.
            backward_kwargs (dict, optional): Additional keyword arguments to pass to the `backward()` call. Defaults to no extra arguments.

        Returns:
            torch.Tensor or None: The total gradient norm (measured before clipping when clipping is active).
            None if update_grad is False, or if clipping is disabled and no parameters were given.

        Raises:
            ValueError: If update_grad is True, clipping is active and `parameters` is None.
        """

        if optimizer is None:
            optimizer = self.optimizer
        if clip_grad is None:
            clip_grad = self.clip_grad
        # A None sentinel avoids sharing one mutable dict between calls.
        if backward_kwargs is None:
            backward_kwargs = {}

        # Validate before touching autograd state so a bad call has no side
        # effects; an explicit raise also survives `python -O`, unlike assert.
        if update_grad and clip_grad is not None and parameters is None:
            raise ValueError("`parameters` must be provided when `clip_grad` is set")

        if self.amp:
            self._scaler.scale(loss).backward(**backward_kwargs)
        else:
            loss.backward(**backward_kwargs)

        norm = None
        if update_grad:
            if self.amp:
                # Unscale the gradients of the optimizer's params in-place so
                # clipping / norm measurement operate on the true gradients.
                self._scaler.unscale_(optimizer)
            if clip_grad is not None:
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            elif parameters is not None:
                # No clipping requested: only report the norm.
                norm = get_grad_norm_(parameters)
            if self.amp:
                self._scaler.step(optimizer)
                self._scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()
        return norm

    def state_dict(self):
        """
        Returns the state dictionary of the internal gradient scaler.
        """

        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """
        Loads the state dictionary of the gradient scaler.

        Args:
            state_dict (dict): The state dictionary to load.
        """

        self._scaler.load_state_dict(state_dict)
