import torch
from torch import nn


class DDPGradientStatsHook:
    def __init__(self, ddp_module):
        try:
            ddp_module.register_comm_hook(self, self._hook_fn)
        except AttributeError:
            raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
        self._clear_state()

    def _clear_state(self):
        self.bucket_sq_norms_small_batch = []
        self.bucket_sq_norms_large_batch = []

    @staticmethod
    def _hook_fn(self, bucket):
        buf = bucket.buffer()
        self.bucket_sq_norms_small_batch.append(buf.pow(2).sum(dtype=torch.float32))
        fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
        def callback(fut):
            buf = fut.value()[0]
            self.bucket_sq_norms_large_batch.append(buf.pow(2).sum(dtype=torch.float32))
            return buf
        return fut.then(callback)

    def get_stats(self):
        sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
        sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
        self._clear_state()
        stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
        torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
        return stats[0].item(), stats[1].item()


class GradientNoiseScale:
    """Calculates the gradient noise scale (1 / SNR), or critical batch size,
    from _An Empirical Model of Large-Batch Training_,
    https://arxiv.org/abs/1812.06162).

    Args:
        beta (float): The decay factor for the exponential moving averages used to
            calculate the gradient noise scale.
            Default: 0.9998
        eps (float): Added for numerical stability.
            Default: 1e-8
    """

    def __init__(self, beta=0.9998, eps=1e-8):
        self.beta = beta
        self.eps = eps
        self.ema_sq_norm = 0.
        self.ema_var = 0.
        self.beta_cumprod = 1.
        self.gradient_noise_scale = float('nan')

    def state_dict(self):
        """Returns the state of the object as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the object's state.
        Args:
            state_dict (dict): object state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
        """Updates the state with a new batch's gradient statistics, and returns the
        current gradient noise scale.

        Args:
            sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
                per sample gradients.
            sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
                per sample gradients.
            n_small_batch (int): The batch size of the individual microbatch or per sample
                gradients (1 if per sample).
            n_large_batch (int): The total batch size of the mean of the microbatch or
                per sample gradients.
        """
        est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
        est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
        self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
        self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
        self.beta_cumprod *= self.beta
        self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
        return self.gradient_noise_scale

    def get_gns(self):
        """Returns the current gradient noise scale."""
        return self.gradient_noise_scale

    def get_stats(self):
        """Returns the current (debiased) estimates of the squared mean gradient
        and gradient variance."""
        return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
