"""Methods for selecting a subset of classes to compute gradients for."""
import abc
import dataclasses
from typing import Dict, Optional, Union

import torch

###############################################################################


@dataclasses.dataclass
class ClassSubsetSelectorInput:
    """Per-example bundle of data handed to a class subset selector.

    Attributes:
        example: String-keyed tensors for the example; every value has
            shape = [sequence_length] and dtype=int32.
        logits: Tensor with shape = [n_classes].
        label: Optional label, which some class subset selectors require.
            When given as a Tensor, label.shape = [] and dtype=int32.
    """

    example: Dict[str, torch.Tensor]
    logits: torch.Tensor
    label: Union[int, torch.Tensor, None] = None


class ClassSubsetSelectorAbc(abc.ABC):
    """Interface for strategies that pick which classes get gradients computed."""

    @classmethod
    def create(cls, **kwargs):
        """Builds an instance of `cls`, forwarding all keyword arguments to its constructor."""
        instance = cls(**kwargs)
        return instance

    @abc.abstractmethod
    def select_classes(self, example_info: ClassSubsetSelectorInput) -> torch.Tensor:
        """Chooses, for one example, the classes whose gradients should be computed.

        Args:
            example_info: The example, its logits and (optionally) its label.

        Returns:
            The selected class indices as a tensor with
            shape = [n_selected_classes] and dtype=int32.
        """
        raise NotImplementedError()


###############################################################################


@dataclasses.dataclass
class ExhaustiveClassSubsetSelector(ClassSubsetSelectorAbc):
    """Selector that picks every class."""

    def select_classes(self, example_info: ClassSubsetSelectorInput) -> torch.Tensor:
        """Returns all class indices, in increasing order.

        The result is an int32 tensor placed on the same device as
        `example_info.logits`; its length is the size of the logits' last
        dimension.
        """
        logits = example_info.logits
        n_classes = logits.shape[-1]
        return torch.arange(n_classes, dtype=torch.int32, device=logits.device)


@dataclasses.dataclass
class LabelledClassSubsetSelector(ClassSubsetSelectorAbc):
    """Selector that picks only the example's labelled class."""

    def select_classes(self, example_info: ClassSubsetSelectorInput) -> torch.Tensor:
        """Returns a one-element tensor holding the example's label.

        Args:
            example_info: Must carry a label (an int or a tensor); its logits
                are only used to decide which device the result lives on.

        Returns:
            An int32 tensor with shape = [1] on the logits' device.

        Raises:
            TypeError: If `example_info.label` is None.
        """
        label = example_info.label
        if label is None:
            # The label is optional on the input type but mandatory here. Fail
            # with an explicit message instead of an opaque error from the
            # tensor assignment below (TypeError keeps the exception class
            # that assignment raises for None).
            raise TypeError(
                'LabelledClassSubsetSelector requires example_info.label to be set, '
                'but it is None.'
            )

        # Write into a freshly allocated tensor so the result never aliases a
        # label that was passed in as a tensor.
        ret = torch.zeros([1], dtype=torch.int32, device=example_info.logits.device)
        ret[0] = label
        return ret


###############################################################################


@dataclasses.dataclass
class TopClassesSubsetSelector(ClassSubsetSelectorAbc):
    """Selector that picks the most probable classes under softmax(logits).

    Classes are ranked by descending probability and a prefix of that ranking
    is returned. The prefix length is determined as follows:

    1. Start from `max_classes` (if set).
    2. If `min_prob` is set, cap the length at the number of classes whose
       probability is strictly greater than `min_prob`.
    3. If `min_classes` is set, raise the length to at least `min_classes`
       (this takes precedence over `min_prob`).

    At least one of `min_prob` or `max_classes` must be provided.
    """

    # Exclusive.
    min_prob: Optional[float] = None

    # Both, if provided, are inclusive.
    min_classes: Optional[int] = None
    max_classes: Optional[int] = None

    def __post_init__(self):
        """Validates the configuration and initialises the bound-tensor caches.

        Raises:
            ValueError: If neither `min_prob` nor `max_classes` is set, or if
                `min_classes` exceeds `max_classes`.
        """
        if self.min_prob is None and self.max_classes is None:
            raise ValueError('At least one of min_prob or max_classes must be set.')
        if self.min_classes is not None and self.max_classes is not None:
            if self.min_classes > self.max_classes:
                raise ValueError('The min_classes must be less than or equal to max_classes.')

        # Lazily built 0-d int32 tensors holding min_classes / max_classes on
        # the device most recently requested (see _bound_tensor).
        self._t_min_classes = None
        self._t_max_classes = None

    @staticmethod
    def _bound_tensor(cached, value, device):
        """Returns `value` as a 0-d int32 tensor on `device`, reusing `cached` when possible.

        The cached tensor is only reused if it already lives on `device`;
        otherwise a new tensor is created. Reusing a tensor from another
        device would make the element-wise min/max in `_select_classes` fail
        with a device mismatch.
        """
        if cached is not None and cached.device == device:
            return cached
        return torch.tensor(value, dtype=torch.int32, device=device)

    def _get_t_min_classes(self, device):
        """Returns `min_classes` as a cached 0-d int32 tensor on `device`."""
        self._t_min_classes = self._bound_tensor(self._t_min_classes, self.min_classes, device)
        return self._t_min_classes

    def _get_t_max_classes(self, device):
        """Returns `max_classes` as a cached 0-d int32 tensor on `device`."""
        self._t_max_classes = self._bound_tensor(self._t_max_classes, self.max_classes, device)
        return self._t_max_classes

    def _select_classes(self, example_info: ClassSubsetSelectorInput) -> torch.Tensor:
        """Computes the selection; see `select_classes` for the contract.

        Raises:
            ValueError: If `min_classes` exceeds the size of the logits' last
                dimension.
        """
        if self.min_classes is not None and self.min_classes > example_info.logits.shape[-1]:
            raise ValueError('The min_classes must be less than or equal to the number of classes.')

        probs = torch.softmax(example_info.logits, dim=-1)
        device = probs.device

        # Rank classes from most to least probable.
        s_probs, s_class_indices = torch.sort(probs, dim=-1, descending=True)

        n_selected = None if self.max_classes is None else self._get_t_max_classes(device)

        if self.min_prob is not None:
            # Number of classes whose probability is strictly above min_prob.
            ub_min_prob = torch.sum((s_probs > self.min_prob).type(torch.int32))
            if n_selected is None:
                n_selected = ub_min_prob
            else:
                n_selected = torch.minimum(n_selected, ub_min_prob)

        if self.min_classes is not None:
            # __post_init__ guarantees min_prob or max_classes is set, so
            # n_selected is a tensor by this point.
            n_selected = torch.maximum(n_selected, self._get_t_min_classes(device))

        return s_class_indices[:n_selected].type(torch.int32)

    def select_classes(self, example_info: ClassSubsetSelectorInput) -> torch.Tensor:
        """Selects the top classes for one example.

        Runs with gradient tracking disabled.

        Returns:
            An int32 tensor of class indices in descending order of class
            probability.

        Raises:
            ValueError: If `min_classes` exceeds the number of classes in the
                logits.
        """
        # Returns class indices in descending order of class probability.
        with torch.no_grad():
            return self._select_classes(example_info)
