# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Facilities for reporting and collecting training statistics across
multiple processes and devices. The interface is designed to minimize
synchronization overhead as well as the amount of boilerplate in user
code."""

import re
import numpy as np
import torch

from distrl.models.diffusion_models.edm2.original import dnnlib
from . import misc

#----------------------------------------------------------------------------

_num_moments    = 3             # Number of moments tracked per statistic: [num_scalars, sum_of_scalars, sum_of_squares].
_reduce_dtype   = torch.float32 # Data type used by report() for the initial per-tensor reduction.
_counter_dtype  = torch.float64 # Data type used for the internal and cumulative counters.
_rank           = 0             # Rank of the current process; set by init_multiprocessing().
_sync_device    = None          # Device to use for multiprocess communication. None = single-process.
_sync_called    = False         # Has _sync() been called yet with a non-empty list of names?
_counters       = dict()        # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative     = dict()        # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor

#----------------------------------------------------------------------------

def init_multiprocessing(rank, sync_device):
    r"""Configures this module for collecting statistics across multiple
    processes.

    Call this after `torch.distributed.init_process_group()` and before
    the first `Collector.update()` that has any statistics to synchronize;
    the function asserts that no such synchronization has happened yet.
    Single-process users do not need to call it at all.

    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device used for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    # Changing the configuration after counters have been synchronized is not supported.
    assert not _sync_called
    _rank, _sync_device = rank, sync_device

#----------------------------------------------------------------------------

@misc.profiled_function
def report(name, value):
    r"""Accumulates the given set of scalars under `name` so that every
    interested `Collector` can pick them up, across device and process
    boundaries. Non-finite entries (NaN, Inf, or values whose square
    overflows the reduction dtype) are excluded from the statistics.

    The function is meant to be cheap enough to call from anywhere in the
    training loop, a loss function, or inside a `torch.nn.Module`.

    Warning: The current implementation expects the set of unique names to
    be consistent across processes. Make sure that each process calls
    `report()` at least once for every unique name, and in the same order.
    A process with nothing to contribute can call `report(name, [])`
    (empty list).

    Args:
        name:   Arbitrary string specifying the name of the statistic.
                Averages are accumulated separately for each unique name.
        value:  Arbitrary set of scalars. Can be a list, tuple,
                NumPy array, PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    # Register the name first, so that it is known even if `value` is empty.
    per_device = _counters.setdefault(name, dict())

    scalars = torch.as_tensor(value)
    if scalars.numel() == 0:
        return value

    # Reduce the input to [count, sum, sum of squares] over the finite entries.
    scalars = scalars.detach().flatten().to(_reduce_dtype)
    squares = scalars.square()
    valid = squares.isfinite()
    count = valid.sum(dtype=_reduce_dtype)
    total = torch.where(valid, scalars, 0).sum()
    total_sq = torch.where(valid, squares, 0).sum()
    moments = torch.stack([count, total, total_sq])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    # Add the result to the running counter that lives on the same device.
    dev = moments.device
    running = per_device.get(dev)
    if running is None:
        running = per_device[dev] = torch.zeros_like(moments)
    running.add_(moments)
    return value

#----------------------------------------------------------------------------

def report0(name, value):
    r"""Same as `report()`, except that only the first process
    (`rank = 0`) contributes its scalars. Every other process reports an
    empty list, which still registers `name` on that process.

    Returns:
        The same `value` that was passed in.
    """
    contribution = value if _rank == 0 else []
    report(name, contribution)
    return value

#----------------------------------------------------------------------------

class Collector:
    r"""Collects the scalars broadcasted by `report()` and `report0()` and
    computes their long-term averages (mean and standard deviation) over
    user-defined periods of time.

    The averages are first collected into internal counters that are not
    directly visible to the user. They are then copied to the user-visible
    state as a result of calling `update()` and can then be queried using
    `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
    internal counters for the next round, so that the user-visible state
    effectively reflects averages collected between the last two calls to
    `update()`.

    Args:
        regex:          Regular expression defining which statistics to
                        collect. The default is to collect everything.
        keep_previous:  Whether to retain the previous averages if no
                        scalars were collected on a given round
                        (default: False).
    """
    def __init__(self, regex='.*', keep_previous=False):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()   # name => cumulative moments seen at the last update().
        self._moments = dict()      # name => moments accumulated between the last two updates.
        # Record the current cumulative totals as the baseline and discard
        # the resulting deltas, so that scalars reported before this
        # collector was constructed do not show up in its averages.
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics broadcasted so far that
        match the regular expression specified at construction time.
        """
        return [name for name in _counters if self._regex.fullmatch(name)]

    def update(self):
        r"""Copies current values of the internal counters to the
        user-visible state and resets them for the next round.

        If `keep_previous=True` was specified at construction time, the
        operation is skipped for statistics that have received no scalars
        since the last update, retaining their previous averages.

        This method performs a number of GPU-to-CPU transfers and one
        `torch.distributed.all_reduce()`. It is intended to be called
        periodically in the main training loop, typically once every
        N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for name, cumulative in _sync(self.names()):
            if name not in self._cumulative:
                self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
            # The per-round moments are the growth of the cumulative totals
            # since this collector last looked at them.
            delta = cumulative - self._cumulative[name]
            self._cumulative[name].copy_(cumulative)
            # Only overwrite the visible state if something was actually reported.
            if float(delta[0]) != 0:
                self._moments[name] = delta

    def _get_delta(self, name):
        r"""Returns the raw moments that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.

        `name` must match the regular expression given at construction
        time. A zero entry is stored for names that have no moments yet.
        """
        assert self._regex.fullmatch(name)
        if name not in self._moments:
            self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        return self._moments[name]

    def num(self, name):
        r"""Returns the number of scalars that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        return int(delta[0])

    def mean(self, name):
        r"""Returns the mean of the scalars that were accumulated for the
        given statistic between the last two calls to `update()`, or NaN if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0:
            return float('nan')
        return float(delta[1] / delta[0])

    def std(self, name):
        r"""Returns the standard deviation of the scalars that were
        accumulated for the given statistic between the last two calls to
        `update()`, or NaN if no scalars were collected.

        NaN is also returned if the accumulated sum is not finite. The
        result is always a built-in `float`, matching `mean()`.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
            return float('nan')
        if int(delta[0]) == 1:
            return 0.0
        mean = float(delta[1] / delta[0])
        raw_var = float(delta[2] / delta[0])
        # E[x^2] - E[x]^2 can come out slightly negative due to rounding, so
        # clamp it at zero. Convert the NumPy scalar to a built-in float so
        # that every branch of this method returns the same type.
        return float(np.sqrt(max(raw_var - np.square(mean), 0)))

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict`. The contents are as follows:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=INT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        stats = dnnlib.EasyDict()
        for name in self.names():
            stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
        return stats

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)

#----------------------------------------------------------------------------

def _sync(names):
    r"""Folds the running per-device counters of the given statistics into
    the global cumulative counters, summing across devices and, when a
    sync device is configured, across processes. Called internally by
    `Collector.update()`.

    Args:
        names:  List of statistic names to synchronize. Each name must
                already have an entry in `_counters`.

    Returns:
        List of `(name, cumulative_moments)` pairs in the order of `names`,
        or an empty list if `names` is empty.

    Raises:
        ValueError: If the list of names on this rank does not hash to the
                    same value as the list on rank 0.
    """
    global _sync_called
    # Nothing to do; return before touching any collective operation.
    if not len(names):
        return []
    _sync_called = True
    multiprocess = _sync_device is not None

    # Check that this rank has the same list of names as rank 0.
    if multiprocess:
        local_hash = hash(tuple(tuple(map(ord, name)) for name in names))
        reference = torch.as_tensor(local_hash, dtype=torch.int64, device=_sync_device)
        torch.distributed.broadcast(tensor=reference, src=0)
        if int(reference.cpu()) != local_hash:
            raise ValueError('Training statistics are inconsistent between ranks')

    # Sum the running counters over all devices of this rank, then reset them.
    work_device = _sync_device if multiprocess else torch.device('cpu')
    per_name = []
    for name in names:
        total = torch.zeros([_num_moments], dtype=_counter_dtype, device=work_device)
        for counter in _counters[name].values():
            total += counter.to(work_device)
            counter.zero_()
        per_name.append(total)
    stacked = torch.stack(per_name)

    # Sum the per-rank totals across ranks.
    if multiprocess:
        torch.distributed.all_reduce(stacked)

    # Accumulate the result into the CPU-side cumulative counters.
    stacked = stacked.cpu()
    for name, row in zip(names, stacked):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(row)

    # Return name-value pairs.
    return [(name, _cumulative[name]) for name in names]

#----------------------------------------------------------------------------
# Convenience.

# Module-level collector created at import time with the default arguments,
# i.e. it matches every statistic name and does not keep previous averages.
default_collector = Collector()

#----------------------------------------------------------------------------
