"""Methods for selecting a sequence position to analyze for a given example."""
import abc
import dataclasses
from typing import Dict, Optional

import torch

from npeff_torch.util import torch_utils

###############################################################################
# typedefs
# Result type of a position selector: either a tensor holding one selected
# sequence position per example in the batch, or None when no positions are
# selected at all.
PositionSelection = Optional[torch.Tensor]
###############################################################################


@dataclasses.dataclass
class PositionSelectorInput:
    """Bundle of per-batch data handed to a position selector.

    Only `examples` and `n_non_paddings` are required; `logits` and `labels`
    default to None for selectors that do not need model outputs.
    """

    # Model inputs keyed by name.
    # examples[*].shape = [batch, sequence_length], dtype=int32
    examples: Dict[str, torch.Tensor]

    # Count of non-padding tokens per example. The non-padding tokens are
    # expected to form a prefix of each token sequence.
    # shape = [batch], dtype=int32
    n_non_paddings: torch.Tensor

    # Logits of the model's predictions on the examples.
    # shape = [batch, sequence?, n_classes]
    # (the `?` presumably marks the sequence dimension as optional)
    logits: Optional[torch.Tensor] = None

    # Target value for each prediction.
    # shape = [batch, sequence?], dtype=int32
    labels: Optional[torch.Tensor] = None


class PositionSelectorAbc(abc.ABC):
    """Interface for objects that pick a sequence position for each example in a batch."""

    @classmethod
    def create(cls, **kwargs):
        """Builds an instance of this selector class from keyword arguments."""
        selector = cls(**kwargs)
        return selector

    @abc.abstractmethod
    def select_positions(self, batch_info: PositionSelectorInput) -> PositionSelection:
        """Picks one sequence position per example in the batch.

        Args:
            batch_info: The batch of examples, with any model outputs the
                selector needs.

        Returns:
            Either a tensor with shape = [batch] and dtype=int32 holding the
            selected sequence position of every example, or None when
            downstream code does not need sequence positions (e.g. for
            sequence classification models). A negative entry in the tensor
            tells further processing to skip that example.
        """
        raise NotImplementedError


###############################################################################


class NoopPositionSelector(PositionSelectorAbc):
    """Selector that never selects anything.

    Meant for setups in which sequence positions are not needed.
    """

    def select_positions(self, batch_info: PositionSelectorInput) -> PositionSelection:
        """Ignores `batch_info` and returns None, i.e. no position selection."""
        return None


@dataclasses.dataclass
class UniformRandomPositionSelector(PositionSelectorAbc):
    """Selects positions uniformly at random in some range.

    For every example the position is drawn by `torch_utils.randint32` with
    lower bound `min_position` (0 when unset) and an upper bound equal to the
    example's number of non-padding tokens, capped at `max_position` when
    that is set.
    """

    # Inclusive lower bound. None means 0.
    min_position: Optional[int] = None
    # Exclusive upper bound. None means the bound is only the number of
    # non-padding tokens of each example.
    max_position: Optional[int] = None

    # If set to true, the last position within a sequence will not be selected. This
    # is intended to be used for autoregressive models. Will NOT have an effect on
    # max_position if max_position is less than the sequence length.
    exclude_last_position: bool = True

    def select_positions(self, batch_info: PositionSelectorInput) -> PositionSelection:
        """Draws one random position per example.

        Args:
            batch_info: Only `n_non_paddings` is read.

        Returns:
            A tensor with the same shape and device as
            `batch_info.n_non_paddings` holding the drawn positions.
        """
        # TODO: This function assumes that the n_non_padding for each example will always be
        # greater than min_position.

        low = 0 if self.min_position is None else self.min_position

        n_non_paddings = batch_info.n_non_paddings
        if self.max_position is not None:
            # torch.minimum only accepts tensors, so passing the Python int
            # max_position to it raises a TypeError. torch.clamp takes a
            # scalar bound and keeps the dtype of n_non_paddings.
            high = torch.clamp(n_non_paddings, max=self.max_position)
        else:
            high = n_non_paddings

        # NOTE(review): randint32 is presumably an elementwise draw from
        # [low, high) with dtype int32 - confirm against torch_utils.
        positions = torch_utils.randint32(
            low, high, size=n_non_paddings.shape, device=n_non_paddings.device)

        if self.exclude_last_position:
            # A draw that landed on the last non-padding position is moved one
            # step to the left. Consequently the moved value can fall below
            # min_position, and it becomes -1 (a negative value, i.e. "skip")
            # for an example that has a single non-padding token.
            last_position_indicator = (positions == (n_non_paddings - 1)).type(torch.int32)
            positions -= last_position_indicator

        return positions


###############################################################################


class LastPositionSelector(PositionSelectorAbc):
    """Picks the final non-padding position of every example."""

    def select_positions(self, batch_info: PositionSelectorInput) -> PositionSelection:
        """Returns `n_non_paddings - 1` for each example in the batch."""
        # NOTE: Empty examples are assumed not to occur; an example with zero
        # non-padding tokens would come out as -1.
        token_counts = batch_info.n_non_paddings
        return token_counts - 1


###############################################################################


@dataclasses.dataclass
class SparseFeatureCircuitsPositionSelector(PositionSelectorAbc):
    """Position selector from the sparse feature circuits paper.

    The token filtering is described in Section H.1 of https://arxiv.org/pdf/2403.19647
    Here, we filter the positions given that method and then select one uniformly at
    random.

    A position of an example is a candidate when all of the following hold:
      - it is a non-padding position other than the example's last non-padding position,
      - the argmax of the logits at the position equals the label,
      - the cross entropy of the label is at most `max_cross_entropy_nats`,
      - if `exclude_induction` is set, the bigram starting at the position
        does not occur earlier in the example.

    Examples without any candidate position get the position -1, which marks
    them as to be skipped.
    """

    # The model must make the correct prediction with a cross entropy lower than this to be considered.
    max_cross_entropy_nats: float

    # Whether to exclude positions whose bigram (the token at the position
    # followed by the next token) already occurred earlier in the context.
    exclude_induction: bool = True

    def _compute_correct_predictions_mask(self, batch_info: PositionSelectorInput) -> torch.Tensor:
        """Marks the positions where the argmax prediction equals the label."""
        # ret.shape = [batch, n_tokens], dtype=bool
        predictions = torch.argmax(batch_info.logits, dim=-1)
        return predictions == batch_info.labels

    def _compute_cross_entropy_mask(self, batch_info: PositionSelectorInput) -> torch.Tensor:
        """Marks the positions whose label cross entropy is within the threshold."""
        # ret.shape = [batch, n_tokens], dtype=bool
        log_probs = torch.log_softmax(batch_info.logits, dim=-1)
        labels = batch_info.labels.type(torch.int64)

        # Negative labels (e.g. an ignore index such as -100) are not valid
        # gather indices. Gather with a clamped index and mask them out after.
        has_label = labels >= 0
        gather_index = torch.clamp(labels, min=0)[..., None]

        # Gathering along the class dimension with an index of shape [..., 1]
        # picks log_probs[..., labels[...]], the log-probability of the label.
        label_cross_entropies = -torch.gather(log_probs, -1, gather_index)
        label_cross_entropies = torch.squeeze(label_cross_entropies, dim=-1)
        return has_label & (label_cross_entropies <= self.max_cross_entropy_nats)

    def _compute_non_induction_mask(self, batch_info: PositionSelectorInput) -> torch.Tensor:
        """Marks the positions whose bigram did not occur earlier in the example."""
        # ret.shape = [batch, n_tokens], dtype=bool
        first_token = batch_info.examples['input_ids']
        # second_token[..., i] is the token following position i. The roll
        # wraps around, so the final position is paired with the first token.
        second_token = torch.roll(first_token, shifts=-1, dims=-1)

        # bigram_match_mask[..., i, j] is set when the bigrams starting at
        # positions i and j are identical.
        bigram_match_mask = first_token[..., :, None] == first_token[..., None, :]
        bigram_match_mask &= second_token[..., :, None] == second_token[..., None, :]

        positions = torch.arange(0, first_token.shape[-1], dtype=torch.int32, device=bigram_match_mask.device)
        # preceding_mask[i, j] is set when position j comes before position i.
        preceding_mask = positions[:, None] > positions[None, :]

        induction_mask = torch.any(bigram_match_mask & preceding_mask[None, ...], dim=-1)

        return ~induction_mask

    def select_positions(self, batch_info: PositionSelectorInput) -> PositionSelection:
        """Selects one candidate position per example uniformly at random.

        Args:
            batch_info: Must provide `logits` and `labels` with a sequence
                dimension, and 'attention_mask' (plus 'input_ids' when
                `exclude_induction` is set) in `examples`.

        Returns:
            An int32 tensor of shape [batch] on the device of the logits. An
            entry is -1 when the example has no candidate position.

        Raises:
            ValueError: If `batch_info.logits` or `batch_info.labels` is None.
        """
        if batch_info.logits is None or batch_info.labels is None:
            raise ValueError(
                'SparseFeatureCircuitsPositionSelector requires both logits and labels.')

        # The comparison creates a new tensor, so the in-place updates below
        # do not touch the attention mask of the caller.
        non_padding_mask = batch_info.examples['attention_mask'] != 0

        # Always exclude the last position. This has to be done per example:
        # indexing with [:, n_non_paddings - 1] would instead clear the last
        # position of every example in all rows of the batch.
        sequence_positions = torch.arange(non_padding_mask.shape[-1], device=non_padding_mask.device)
        last_positions = (batch_info.n_non_paddings - 1).to(non_padding_mask.device)
        non_padding_mask &= sequence_positions[None, :] != last_positions[:, None]

        correct_preds_mask = self._compute_correct_predictions_mask(batch_info)
        cross_entropy_mask = self._compute_cross_entropy_mask(batch_info)

        mask = non_padding_mask & correct_preds_mask & cross_entropy_mask

        if self.exclude_induction:
            mask &= self._compute_non_induction_mask(batch_info)

        batch_size = mask.shape[0]
        device = batch_info.logits.device

        # Start from the skip marker so that only examples with at least one
        # candidate position need to be written.
        positions = torch.full([batch_size], -1, dtype=torch.int32, device=device)
        for example_index, example_mask in enumerate(mask):
            # Candidate positions of this example, in increasing order.
            example_valid_positions = torch.nonzero(example_mask, as_tuple=False)[:, 0]
            n_valid = example_valid_positions.shape[0]
            if n_valid == 0:
                continue
            index = torch.randint(0, n_valid, (), device=example_valid_positions.device)
            positions[example_index] = example_valid_positions[index]

        return positions
