# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import flashy
import torch
from torch import autograd


class Balancer:
    """Loss balancer.

    The loss balancer combines losses together to compute gradients for the backward.
    Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
    not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
    `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
    the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
    going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
    interpration of the weights even if the intrisic scale of `l1`, `l2` ... is unknown.

    Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
    (with `avg` an exponential moving average over the updates),

        G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)

    If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
    standard sum of the partial gradients with the given weights.

    A call to the backward method of the balancer will compute the the partial gradients,
    combining all the losses and potentially rescaling the gradients,
    which can help stabilize the training and reason about multiple losses with varying scales.
    The obtained gradient with respect to `y` is then back-propagated to `f(...)`.

    Expected usage:

        weights = {'loss_a': 1, 'loss_b': 4}
        balancer = Balancer(weights, ...)
        losses: dict = {}
        losses['loss_a'] = compute_loss_a(x, y)
        losses['loss_b'] = compute_loss_b(x, y)
        if model.training():
            effective_loss = balancer.backward(losses, x)

    Args:
        weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
            from the backward method to match the weights keys to assign weight to each of the provided loss.
        balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
            overall gradient, rather than a constant multiplier.
        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
        emay_decay (float): EMA decay for averaging the norms.
        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
            when rescaling the gradients.
        epsilon (float): Epsilon value for numerical stability.
        monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
            coming from each loss, when calling `backward()`.
    """

    def __init__(
        self,
        weights: tp.Dict[str, float],
        balance_grads: bool = True,
        total_norm: float = 1.0,
        ema_decay: float = 0.999,
        per_batch_item: bool = True,
        epsilon: float = 1e-12,
        monitor: bool = False,
    ):
        self.weights = weights
        self.per_batch_item = per_batch_item
        self.total_norm = total_norm or 1.0
        self.averager = flashy.averager(ema_decay or 1.0)
        self.epsilon = epsilon
        self.monitor = monitor
        self.balance_grads = balance_grads
        self._metrics: tp.Dict[str, tp.Any] = {}

    @property
    def metrics(self):
        return self._metrics

    def backward(
        self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor
    ) -> torch.Tensor:
        """Compute the backward and return the effective train loss, e.g. the loss obtained from
        computing the effective weights. If `balance_grads` is True, the effective weights
        are the one that needs to be applied to each gradient to respect the desired relative
        scale of gradients coming from each loss.

        Args:
            losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
            input (torch.Tensor): the input of the losses, typically the output of the model.
                This should be the single point of dependence between the losses
                and the model being trained.
        """
        norms = {}
        grads = {}
        for name, loss in losses.items():
            # Compute partial derivative of the less with respect to the input.
            (grad,) = autograd.grad(loss, [input], retain_graph=True)
            if self.per_batch_item:
                # We do not average the gradient over the batch dimension.
                dims = tuple(range(1, grad.dim()))
                norm = grad.norm(dim=dims, p=2).mean()
            else:
                norm = grad.norm(p=2)
            norms[name] = norm
            grads[name] = grad

        count = 1
        if self.per_batch_item:
            count = len(grad)
        # Average norms across workers. Theoretically we should average the
        # squared norm, then take the sqrt, but it worked fine like that.
        avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
        # We approximate the total norm of the gradient as the sums of the norms.
        # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
        total = sum(avg_norms.values())

        self._metrics = {}
        if self.monitor:
            # Store the ratio of the total gradient represented by each loss.
            for k, v in avg_norms.items():
                self._metrics[f"ratio_{k}"] = v / total

        total_weights = sum([self.weights[k] for k in avg_norms])
        assert total_weights > 0.0
        desired_ratios = {k: w / total_weights for k, w in self.weights.items()}

        out_grad = torch.zeros_like(input)
        effective_loss = torch.tensor(0.0, device=input.device, dtype=input.dtype)
        for name, avg_norm in avg_norms.items():
            if self.balance_grads:
                # g_balanced = g / avg(||g||) * total_norm * desired_ratio
                scale = (
                    desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
                )
            else:
                # We just do regular weighted sum of the gradients.
                scale = self.weights[name]
            out_grad.add_(grads[name], alpha=scale)
            effective_loss += scale * losses[name].detach()
        # Send the computed partial derivative with respect to the output of the model to the model.
        input.backward(out_grad)
        return effective_loss
