"""Module containing classes and functions for handling image classification tasks."""

from __future__ import annotations

from typing import Tuple

import torch
import torch.nn.functional as F  # noqa: N812
from .abstract_task import AbstractTask
from torch import nn

# Type alias for a batch as used by ClassificationTask: an (images, labels) pair of tensors.
BATCH_DTYPE = Tuple[torch.Tensor, torch.Tensor]


class ImageClassificationModelOutput:
    """Class for handling the output of an image classification model.

    Provides the per-example output function (a logit margin, see ``get_output``) and the
    per-example factor relating the cross-entropy loss to that output (see
    ``get_out_to_loss_grad``).
    """

    # Softmax over the last (class) axis of the logits; shared by all instances.
    softmax: nn.Module = torch.nn.Softmax(-1)
    # Temperature that divides the logits before the softmax in `get_out_to_loss_grad`.
    loss_temperature: float = 1.0

    @staticmethod
    def get_output(
        model: nn.Module,
        weights: dict[str, torch.Tensor],
        buffers: dict[str, torch.Tensor],
        image: torch.Tensor,
        label: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate the output margin for a single example.

        The margin is ``logit[label] - logsumexp(logits of every other class)``.

        Args:
        ----
            model (nn.Module): The neural network model.
            weights (dict[str, torch.Tensor]): Dictionary containing model weights.
            buffers (dict[str, torch.Tensor]): Dictionary containing model buffers.
            image (torch.Tensor): A single, unbatched input image; a batch dimension of
                size 1 is added before the forward pass.
            label (torch.Tensor): The ground-truth class index for ``image`` (presumably a
                0-d integer tensor).

        Returns:
        -------
            torch.Tensor: The margin summed over the size-1 batch, i.e. a scalar tensor.

        """
        logits = torch.func.functional_call(
            model,
            (weights, buffers),
            image.unsqueeze(0),
        )
        # Build the row index directly on the logits' device.
        bindex = torch.arange(logits.shape[0], device=logits.device)
        logits_correct = logits[bindex, label.unsqueeze(0)]

        # Mask the correct class with -inf so logsumexp only covers the competing classes.
        cloned_logits = logits.clone()
        cloned_logits[bindex, label.unsqueeze(0)] = torch.tensor(
            -torch.inf,
            device=logits.device,
            dtype=logits.dtype,
        )

        margins = logits_correct - cloned_logits.logsumexp(dim=-1)
        return margins.sum()

    def get_out_to_loss_grad(
        self,
        model: nn.Module,
        weights: dict[str, torch.Tensor],
        buffers: dict[str, torch.Tensor],
        batch: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Calculate the derivative factor of the loss with respect to the output margin.

        With ``p`` the softmax probability of the true class, the margin returned by
        ``get_output`` equals ``log(p / (1 - p))``. The cross-entropy loss is ``-log p``.
        When ``loss_temperature == 1``, the derivative of the loss with respect to the
        margin is therefore ``-(1 - p)``. This method returns ``1 - p``, with the sign
        dropped.

        Args:
        ----
            model (nn.Module): The neural network model.
            weights (dict[str, torch.Tensor]): Dictionary containing model weights.
            buffers (dict[str, torch.Tensor]): Dictionary containing model buffers.
            batch (tuple[torch.Tensor, torch.Tensor]): A tuple containing a batch of images
                and their integer class labels.

        Returns:
        -------
            torch.Tensor: ``1 - p`` for each example's true label, with a trailing singleton
            dimension (``(batch, 1)`` for 2-D logits), detached from the autograd graph.

        """
        images, labels = batch
        logits = torch.func.functional_call(model, (weights, buffers), images)
        # Build the row index on the logits' device, consistent with `get_output`.
        rows = torch.arange(logits.size(0), device=logits.device)
        ps = self.softmax(logits / self.loss_temperature)[rows, labels]
        # Detach so this factor is treated as a constant by any caller-side autograd.
        return (1 - ps).clone().detach().unsqueeze(-1)


class ClassificationTask(AbstractTask):
    """A task for image classification using a neural network model.

    This class provides methods to calculate the training loss, the measurement (a
    hinge-style logit margin), and other utilities for image classification tasks.
    """

    def __init__(
        self,
        influence_modules: list[str],
        representation_module: str,
        device: torch.device | str = "cpu",
        generator: torch.Generator | None = None,
    ) -> None:
        """Initialize the ClassificationTask.

        Args:
        ----
            influence_modules (list[str]): Names of the modules that influence the model's
                output; returned by ``influence_modules()``.
            representation_module (str): Name of the module used for representation;
                returned by ``representation_module()``.
            device (torch.device | str): The device to run the model on. Defaults to "cpu".
            generator (Optional[torch.Generator]): A random number generator used when
                sampling labels. Defaults to None.

        """
        super().__init__(device=device, generator=generator)

        self._influence_modules = influence_modules
        self._representation_module = representation_module

    def get_train_loss(
        self,
        model: nn.Module,
        batch: BATCH_DTYPE,
        parameter_and_buffer_dicts: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] | None = None,
        sample: bool = False,
        reduction: str = "sum",
    ) -> torch.Tensor:
        """Calculate the cross-entropy training loss for the given model and batch.

        Args:
        ----
            model (nn.Module): The neural network model.
            batch (BATCH_DTYPE): A tuple containing images and labels. When
                ``parameter_and_buffer_dicts`` is given, this is a single unbatched example
                and a batch dimension of size 1 is added.
            parameter_and_buffer_dicts (tuple[dict, dict] | None, optional): The
                ``(params, buffers)`` dictionaries passed to ``torch.func.functional_call``.
                If None, ``model`` is called directly. Defaults to None.
            sample (bool, optional): If True, use labels sampled from the model's own
                softmax distribution instead of ``batch``'s labels. Defaults to False.
            reduction (str, optional): Reduction passed to ``F.cross_entropy``. Defaults to "sum".

        Returns:
        -------
            torch.Tensor: The calculated training loss.

        """
        images, labels = batch
        if parameter_and_buffer_dicts is None:
            images, labels = images.to(self.device), labels.to(self.device)
            outputs = model(images)
        else:
            # Functional call on a single example: add the batch dimension the model expects.
            images = images.unsqueeze(0).to(self.device)
            labels = labels.unsqueeze(0).to(self.device)
            params, buffers = parameter_and_buffer_dicts
            outputs = torch.func.functional_call(model, (params, buffers), (images,))

        if not sample:
            # `labels` is already on `self.device` in both branches above.
            return F.cross_entropy(outputs, labels, reduction=reduction)

        # Sample targets from the model's predictive distribution; the sampling step itself
        # must not contribute to the gradient.
        with torch.no_grad():
            probs = F.softmax(outputs, dim=-1)
            sampled_labels = torch.multinomial(
                probs,
                num_samples=1,
                generator=self.generator,
            ).flatten()
        return F.cross_entropy(outputs, sampled_labels, reduction=reduction)

    def get_measurement(
        self,
        model: nn.Module,
        batch: BATCH_DTYPE,
        parameter_and_buffer_dicts: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] | None = None,
        sample: bool = False,
        reduction: str = "sum",
    ) -> torch.Tensor:
        """Calculate the measurement (a hinge-style logit margin) for a single example.

        The measurement is the target-class logit minus the largest logit among all
        other classes.

        Args:
        ----
            model (nn.Module): The neural network model.
            batch (BATCH_DTYPE): A single unbatched (image, label) pair; the label must be
                a 0-d class index (not one-hot).
            parameter_and_buffer_dicts (tuple[dict, dict]): The ``(params, buffers)``
                dictionaries passed to ``torch.func.functional_call``. Must not be None.
            sample (bool, optional): Must be False. Defaults to False.
            reduction (str, optional): Must be "sum". Defaults to "sum".

        Returns:
        -------
            torch.Tensor: The margin as a scalar tensor.

        Raises:
        ------
            AssertionError: If any of the above preconditions is violated (checked with
                ``assert``, so these checks are skipped under ``python -O``).

        """
        assert not sample, "This method is never called with sample=True"
        assert parameter_and_buffer_dicts is not None, "This method is never called with parameter_and_buffer_dicts=None"
        assert reduction == "sum", "This method is only called with reduction='sum'"

        images, labels = batch
        assert labels.ndim == 0, "Labels should be integers, not one-hot encoded"

        images = images.unsqueeze(0).to(self.device)
        labels = labels.unsqueeze(0).to(self.device)
        params, buffers = parameter_and_buffer_dicts
        logits = torch.func.functional_call(model, (params, buffers), (images,))

        assert logits.ndim == 2 and labels.ndim == 1
        # One target column per row, shape (batch, 1). The index for both gather and
        # scatter must be per-row, hence unsqueeze(-1) rather than unsqueeze(0).
        target_index = labels.unsqueeze(-1)
        logits_target = logits.gather(dim=1, index=target_index).squeeze(-1)
        # Replace the target logit with -inf so the max below runs over the other classes.
        logits_except_target = torch.scatter(logits, dim=1, index=target_index, value=float("-inf"))
        hinge_score = logits_target - torch.max(logits_except_target, dim=-1).values
        return hinge_score.sum()  # NB: only have one sample anyway

    def get_batch_size(self, batch: BATCH_DTYPE) -> int:
        """Return the batch size given a batch of images and labels."""
        images, _ = batch
        return images.shape[0]

    def influence_modules(self) -> list[str]:
        """Return a list of module names that influence the model's output."""
        return self._influence_modules

    def representation_module(self) -> str:
        """Return the name of the module used for representation."""
        return self._representation_module

    def get_model_output(self) -> ImageClassificationModelOutput | None:
        """Return a new ImageClassificationModelOutput helper for image classification."""
        return ImageClassificationModelOutput()
