import torch
from .common import DiversityLoss, cosine_similarity, subset_batch


class InputGradientDiversityLoss(DiversityLoss):
    """Diversify salience maps.

    A model's salience map is the gradient, with respect to the input images,
    of a per-sample score obtained by reducing its logits over the class axis.
    The loss is ``cosine_similarity`` of these gradients across models.

    Args:
        mode: How the logits of shape (N, B, L) are reduced to one score per
            (model, sample) before differentiating:
            ``"max"`` -- largest logit,
            ``"min"`` -- smallest logit,
            ``"gt"``  -- logit of the ground-truth class given by ``labels``,
            ``"sum"`` -- sum of all logits.
            Any other value raises ``ValueError`` when the loss is computed.
    """

    def __init__(self, mode):
        super().__init__()
        self.mode = mode

    def get_salience_maps(self, images, labels, logits,
                          batch_start=None, batch_stop=None):
        """Compute input-gradient salience maps for every model.

        Args:
            images: Input images with shape (N, B, *). They must be part of the
                autograd graph that produced ``logits``, since the gradient is
                taken with respect to them.
            labels: Ground truth labels with shape (B). Only read in "gt" mode.
            logits: Model outputs with shape (N, B, L).
            batch_start: Start of the batch subset, forwarded to ``subset_batch``.
            batch_stop: End of the batch subset, forwarded to ``subset_batch``.

        Returns:
            The gradient of the reduced logits with respect to ``images``,
            restricted by ``subset_batch`` along dim 1.

        Raises:
            ValueError: If ``logits`` (after subsetting) is not 3-D, or if
                ``self.mode`` is not one of "max", "min", "gt", "sum".
        """
        logits = subset_batch(logits, 1, batch_start, batch_stop)
        labels = subset_batch(labels, 0, batch_start, batch_stop)
        if logits.ndim != 3:
            raise ValueError(
                f"Expected logits with shape (N, B, L), got {tuple(logits.shape)}."
            )

        if self.mode == "max":
            output = logits.max(dim=2)[0]  # (N, B).
        elif self.mode == "min":
            output = logits.min(dim=2)[0]  # (N, B).
        elif self.mode == "gt":
            # labels[None, :, None] has shape (1, B, 1); take_along_dim broadcasts
            # it over the model axis and picks each sample's ground-truth logit.
            output = logits.take_along_dim(labels[None, :, None], 2).squeeze(2)  # (N, B).
        elif self.mode == "sum":
            output = logits.sum(2)  # (N, B).
        else:
            raise ValueError(f"Unknown diversity loss mode {self.mode}.")
        assert output.ndim == 2

        # create_graph=True builds a graph for the gradient itself, so a loss
        # computed from it can be backpropagated (double backprop).
        # retain_graph=True keeps the forward graph alive, presumably so the
        # caller can still backprop other losses through `logits` — confirm.
        image_grad = torch.autograd.grad(
            output.sum(), images, retain_graph=True, create_graph=True
        )[0]  # (N, B, *).
        # `images` was not subset, so its gradient spans the full batch; keep
        # only the samples whose logits were selected above.
        return subset_batch(image_grad, 1, batch_start, batch_stop)

    def __call__(self, images, labels, logits, features, feature_maps,
                 batch_start=None, batch_stop=None):
        """Compute the loss.

        Args:
            images: Input images with shape (N, B, *).
            labels: Ground truth labels with shape (B).
            logits: Model outputs with shape (N, B, L).
            features (unused): List of model embeddings with shapes (B, D).
            feature_maps (unused): List of model activation maps with shapes (B, C, H, W).
            batch_start: Compute loss for a subset of the batch.
            batch_stop: Compute loss for a subset of the batch.

        Returns:
            ``cosine_similarity`` of the salience maps from ``get_salience_maps``.
        """
        image_grad = self.get_salience_maps(images, labels, logits,
                                            batch_start=batch_start, batch_stop=batch_stop)
        return cosine_similarity(image_grad)
