"""Common stuff for filtering of components."""
import abc
import collections
import dataclasses
from typing import List

from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import top_examples_common


###############################################################################


class ComponentFilterAbc(abc.ABC):
    """Interface implemented by every component filter in this module.

    A filter inspects the top examples of a single component and answers whether
    that component passes the filter.
    """

    @abc.abstractmethod
    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        # Callers must supply at least one example.
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        """Return True when the component at `component_index` passes this filter."""
        raise NotImplementedError()


###############################################################################


@dataclasses.dataclass
class ExampleSpecificFilter(ComponentFilterAbc):
    """Rejects components that are dominated by one (or a few) of their top examples.

    Such components typically give their first top example a very large coefficient
    while the remaining top examples get small ones; occasionally a few similar
    examples at the head of the list all share a large coefficient.

    A component is rejected when its first top example's coefficient is not below
    `top_coefficient_threshold` AND the coefficient of the example at position
    `tail_example_index` is at most `tail_coefficient_threshold` (an example missing
    because the list is too short counts as 0.0). Every other component passes.
    """

    # The first top example's coefficient must reach this value for the component to
    # be considered example-specific.
    top_coefficient_threshold: float

    # The tail example's coefficient must not exceed this value for the component to
    # be considered example-specific.
    tail_coefficient_threshold: float
    # Position within `top_examples` of the tail example that gets compared.
    tail_example_index: int

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        head_coefficient = top_examples[0].coefficient
        if head_coefficient < self.top_coefficient_threshold:
            # The leading example is not strong enough to flag this component.
            return True

        idx = self.tail_example_index
        # When the list is too short to contain the tail example, treat its
        # coefficient as zero.
        tail_coefficient = top_examples[idx].coefficient if idx < len(top_examples) else 0.0
        return tail_coefficient > self.tail_coefficient_threshold


###############################################################################


@dataclasses.dataclass
class SamePredictionFilter(ComponentFilterAbc):
    """Matches components where most/all of the top examples have the same prediction by the model.

    The share is computed as (count of the single most common prediction) divided by
    the total number of top examples.
    """

    # Only matches components where at least this fraction of the top examples have the same prediction.
    fraction_threshold: float

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        """Return True iff the most common prediction covers >= `fraction_threshold` of examples.

        Raises:
            ValueError: if `top_examples` is empty.
        """
        if not top_examples:
            raise ValueError(f'No top examples were provided for component {component_index}.')

        # Predictions are used as Counter keys, so they must be hashable.
        counts = collections.Counter(example.get_prediction() for example in top_examples)
        most_common_count = max(counts.values())
        return most_common_count / len(top_examples) >= self.fraction_threshold


@dataclasses.dataclass
class SameLabelFilter(ComponentFilterAbc):
    """Matches components where most/all of the top examples have the same ground truth label.

    Labels are normalized with `int()` before counting. The share is computed as
    (count of the single most common label) divided by the total number of top examples.
    """

    # Only matches components where at least this fraction of the top examples have the same label.
    fraction_threshold: float

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        """Return True iff the most common label covers >= `fraction_threshold` of examples.

        Raises:
            ValueError: if `top_examples` is empty or any example has no label.
        """
        if not top_examples:
            raise ValueError(f'No top examples were provided for component {component_index}.')

        counts = collections.Counter()
        for example in top_examples:
            if example.label is None:
                raise ValueError(
                    f'SameLabelFilter requires ground-truth labels, but a top example '
                    f'of component {component_index} has label None.'
                )
            counts[int(example.label)] += 1

        most_common_count = max(counts.values())
        return most_common_count / len(top_examples) >= self.fraction_threshold

        
###############################################################################


@dataclasses.dataclass
class SpecificPredictionFilter(ComponentFilterAbc):
    """Matches components where the model predicts `prediction` on most/all of the top examples."""

    # The model prediction that the top examples are checked against.
    prediction: int

    # Minimum share of top examples whose prediction equals `prediction`.
    fraction_threshold: float

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        num_matching = sum(
            1 for example in top_examples if example.get_prediction() == self.prediction
        )
        share = num_matching / len(top_examples)
        return share >= self.fraction_threshold


@dataclasses.dataclass
class SpecificLabelFilter(ComponentFilterAbc):
    """Matches components where most/all of the top examples have a particular ground truth label.

    Labels are normalized with `int()` before comparison, matching how the other
    label-based filters in this module treat labels.
    """

    # The ground truth label of interest.
    label: int

    # Only matches components where at least this fraction of the top examples have the particular label.
    fraction_threshold: float

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        """Return True iff at least `fraction_threshold` of the examples have label `self.label`.

        Raises:
            ValueError: if any top example has no label.
        """
        count = 0
        for example in top_examples:
            if example.label is None:
                raise ValueError(
                    f'SpecificLabelFilter requires ground-truth labels, but a top example '
                    f'of component {component_index} has label None.'
                )
            if int(example.label) == self.label:
                count += 1

        return count / len(top_examples) >= self.fraction_threshold


###############################################################################


@dataclasses.dataclass
class CorrectPredictionFilter(ComponentFilterAbc):
    """Matches components whose top examples are mostly (or entirely) predicted correctly."""

    # Minimum share of top examples for which `prediction_is_correct()` is truthy.
    fraction_threshold: float

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        num_correct = sum(1 for ex in top_examples if ex.prediction_is_correct())
        share_correct = num_correct / len(top_examples)
        return share_correct >= self.fraction_threshold


@dataclasses.dataclass
class WrongPredictionFilter(ComponentFilterAbc):
    """Matches components whose top examples are mostly (or entirely) predicted incorrectly."""

    # Minimum share of top examples for which `prediction_is_correct()` is falsy.
    fraction_threshold: float

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        wrong_examples = [ex for ex in top_examples if not ex.prediction_is_correct()]
        share_wrong = len(wrong_examples) / len(top_examples)
        return share_wrong >= self.fraction_threshold


###############################################################################


@dataclasses.dataclass
class FiltersLogicalAnd(ComponentFilterAbc):
    """Conjunction of filters: a component passes only when every sub-filter passes.

    Sub-filters run in list order, and evaluation stops at the first one that fails.
    With an empty `filters` list, every component passes.
    """

    filters: List[ComponentFilterAbc]

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        for sub_filter in self.filters:
            if not sub_filter.does_component_pass(tokenizer, component_index, top_examples):
                return False
        return True


###############################################################################


@dataclasses.dataclass
class SameLabelSetFilter(ComponentFilterAbc):
    """Matches components whose top examples' ground truth labels concentrate on a few labels.

    The `set_size` most frequent labels among the top examples are grouped together.
    The component passes when the top examples carrying one of those labels make up
    at least `fraction_threshold` of all top examples. Labels are normalized with
    `int()` before counting.
    """

    # Number of most frequent labels to group together; must be non-negative.
    set_size: int

    # Only matches components where at least this fraction of the top examples have a label
    # among the `set_size` most common labels of those examples.
    fraction_threshold: float

    def __post_init__(self):
        # A negative size has no meaningful interpretation as "the N most common labels".
        if self.set_size < 0:
            raise ValueError(f'set_size must be non-negative, got {self.set_size}.')

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        """Return True iff the `set_size` most common labels cover >= `fraction_threshold` of examples.

        Raises:
            ValueError: if any top example has no label.
        """
        counts = collections.Counter()
        for example in top_examples:
            if example.label is None:
                raise ValueError(
                    f'SameLabelSetFilter requires ground-truth labels, but a top example '
                    f'of component {component_index} has label None.'
                )
            counts[int(example.label)] += 1

        # Number of examples whose label is one of the `set_size` most frequent labels.
        # Ties between equally frequent labels do not affect this sum.
        count = sum(n for _, n in counts.most_common(self.set_size))
        return count / len(top_examples) >= self.fraction_threshold


###############################################################################


@dataclasses.dataclass
class LabelNotPredictionFilter(ComponentFilterAbc):
    """Matches components whose top examples share a label but not a model prediction.

    A component passes when `SameLabelFilter(label_fraction_threshold)` passes and
    `SamePredictionFilter(prediction_fraction_threshold)` does not.
    """

    # Threshold forwarded to the SamePredictionFilter sub-filter.
    prediction_fraction_threshold: float
    # Threshold forwarded to the SameLabelFilter sub-filter.
    label_fraction_threshold: float

    # The sub-filters are rebuilt from the current field values on every access, so
    # changing a threshold after construction takes effect (caching them at init time
    # would leave them stale on this mutable dataclass).
    @property
    def _prediction_filter(self) -> 'SamePredictionFilter':
        return SamePredictionFilter(fraction_threshold=self.prediction_fraction_threshold)

    @property
    def _label_filter(self) -> 'SameLabelFilter':
        return SameLabelFilter(fraction_threshold=self.label_fraction_threshold)

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        # Both sub-filters are always evaluated (no short-circuit), so a missing label
        # raises from the label filter even when the prediction filter already passed.
        pred_passes = self._prediction_filter.does_component_pass(tokenizer, component_index, top_examples)
        label_passes = self._label_filter.does_component_pass(tokenizer, component_index, top_examples)
        return label_passes and not pred_passes
