import torch
from torch import Tensor, nn

from torchjd.criterion._utils import _move_dim_to_front
from torchjd.criterion.base import Criterion


class AuxiliaryOutputsCriterion(Criterion):
    """
    :class:`~torchjd.criterion.base.Criterion` that evaluates one loss function on several
    outputs against a single, shared target.

    .. hint::
        Some architectures produce auxiliary outputs in addition to their final output (for
        example GoogLeNet, introduced in
        `Going deeper with convolutions <https://arxiv.org/pdf/1409.4842v1.pdf>`_). In that
        setting, the same loss function is computed on every output (the auxiliary ones and the
        final one), and all of these losses are computed with respect to the same target.

    :param loss_function: The loss function to apply to each output.
    :param dim: The dimension of ``outputs`` along which the different outputs are laid out.

    .. admonition::
        Example

        Compute the mean squared error of every output in ``outputs`` with respect to a single
        ``target``, and return the results as a vector.

        >>> import torch
        >>> from torch.nn import MSELoss
        >>> from torchjd.criterion import AuxiliaryOutputsCriterion
        >>>
        >>> auxiliary_output = torch.tensor([0.0, 2.0, 5.0])
        >>> final_output = torch.tensor([1.0, 3.0, 3.0])
        >>> outputs = torch.stack([auxiliary_output, final_output])
        >>> target = torch.tensor([1.0, 2.0, 3.0])
        >>>
        >>> criterion = AuxiliaryOutputsCriterion(loss_function=MSELoss(), dim=0)
        >>> loss_vector = criterion(outputs, target)
        >>> loss_vector
        tensor([1.6667, 0.3333])
    """

    def __init__(self, loss_function: nn.Module, dim: int):
        super().__init__()

        self.dim = dim
        self.loss_function = loss_function

    def forward(self, outputs: Tensor, target: Tensor) -> Tensor:
        """
        Separates ``outputs`` along dimension ``dim`` into its individual output values (for
        instance some auxiliary outputs and a main output), evaluates the ``loss_function`` of
        each of them against the shared ``target``, and stacks the resulting losses into a single
        :class:`~torch.Tensor` (a 1-D tensor when the ``loss_function`` returns scalar losses).

        :param outputs: The model's outputs.
        :param target: The value that we expect; it is reused unchanged for every output.

        .. note::
            Prefer calling the Criterion instance itself (as in the usage example) over calling
            this method directly: calling the instance also runs potential hooks in addition to
            ``forward``.
        """

        # One loss per slice of ``outputs``, kept in the order in which the slices are yielded.
        per_output_losses = [
            self.loss_function(single_output, target)
            for single_output in _move_dim_to_front(outputs, self.dim)
        ]
        return torch.stack(per_output_losses)

    def __str__(self) -> str:
        return f"AuxiliaryOutputs (dim={self.dim}) {self.loss_function!s}"
