"""Analysis related to Winogrande examples."""
import collections
import dataclasses
from typing import List, Tuple

import numpy as np
from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_common


###############################################################################
R"""
# TODO: Put some of this stuff somewhere more permanent. Maybe the file-level comment above.


Bad components:
    - Roughly single example/context [example_specific]:
        - These have a high coefficient (in the worst cases close to 1) for a single example or for a pair of examples
          with the same context. The rest of the top examples have small coefficients. These components don't really
          represent information common across multiple examples (although the small-coefficient examples show a bit
          of a pattern).
        - Detection: Top example has a high coefficient (above a threshold). The coefficient gets small quickly
          as you go down the top examples.


- Option presence [option_presence]:
    - All of the top examples contain the same option.
    - Maybe this could be negated if all of the predictions are that option (or the other option)?


"""

# Add a component_filter option (make as abc) to the generation of components?

###############################################################################


def _extract_og_sentence_and_options(sentences: List[str]) -> Tuple[str, List[str]]:
    sentences_words = [s.strip().split(' ') for s in sentences]
    shortest_length = min(len(ws) for ws in sentences_words)

    common_prefix_length = 0
    for i in range(shortest_length):
        word = sentences_words[0][i]
        if not all(ws[i] == word for ws in sentences_words):
            break
        common_prefix_length += 1

    common_suffix_length = 0
    for i in range(shortest_length):
        word = sentences_words[0][-1 - i]
        if not all(ws[-1 - i] == word for ws in sentences_words):
            break
        common_suffix_length += 1

    common_prefix = ' '.join(sentences_words[0][:common_prefix_length])
    if common_suffix_length > 0:
        common_suffix = ' '.join(sentences_words[0][-common_suffix_length:])
    else:
        common_suffix = ''
    
    og_sentence = f'{common_prefix} _ {common_suffix}'.strip()

    options = [
        ' '.join(
            ws[common_prefix_length:-common_suffix_length] if common_suffix_length > 0 else ws[common_prefix_length:]
        )
        for ws in sentences_words
    ]

    return og_sentence, options


###############################################################################


@dataclasses.dataclass
class WinograndeExample:
    """A Winogrande item rebuilt from a decoded top example.

    The item is stored as a template sentence containing a `_` blank, the
    candidate fillers for that blank, and indices selecting the gold and the
    predicted filler.
    """

    # Template sentence; `_` marks where the pronoun/option of interest goes.
    sentence: str

    # The candidate strings that can fill the `_` in `sentence`.
    options: Tuple[str, str]

    # Index into `options` of the correct answer (0 or 1).
    label_index: int

    # Index into `options` of the model's prediction (0 or 1).
    prediction_index: int

    @property
    def label(self) -> str:
        """Text of the correct option."""
        gold = self.label_index
        return self.options[gold]

    @property
    def prediction(self) -> str:
        """Text of the predicted option."""
        guess = self.prediction_index
        return self.options[guess]

    @classmethod
    def from_top_example_info(
        cls,
        tokenizer: PreTrainedTokenizer,
        example_info: 'top_examples_common.TopExampleInfo',
    ) -> 'WinograndeExample':
        """Builds an example by decoding each option's filled-in sentence.

        Rows of `example_info.example['input_ids']` and `['attention_mask']`
        are iterated in lockstep. Each row keeps only the tokens whose
        attention mask is non-zero, and those tokens are decoded into a
        sentence. The decoded sentences are then split into a shared template
        and one option per sentence. The prediction is the argmax of
        `example_info.logits`.
        """
        features = example_info.example
        token_rows = features['input_ids']
        mask_rows = features['attention_mask']

        decoded_sentences = []
        for row_ids, row_mask in zip(token_rows, mask_rows):
            # Boolean-mask indexing; presumably rows are numpy arrays or torch
            # tensors (TODO confirm): a plain list would not behave as intended.
            kept_ids = row_ids[row_mask != 0]
            decoded_sentences.append(tokenizer.decode(kept_ids))

        template, choices = _extract_og_sentence_and_options(decoded_sentences)

        return cls(
            sentence=template,
            options=tuple(choices),
            label_index=int(example_info.label),
            prediction_index=int(np.argmax(example_info.logits)),
        )


###############################################################################
###############################################################################


@dataclasses.dataclass
class OptionPresenceFilter(component_filtering.ComponentFilterAbc):
    """Filters out components who only indicate the presence of a particular option.

    A component fails this filter when a single option string appears in a
    large enough fraction of its top examples.
    """

    # A component is filtered out (fails) when the fraction of its top examples
    # containing the single most common option is greater than or equal to
    # this threshold. With the default of 1.0, only components whose top
    # examples ALL contain the same option are filtered out.
    fraction_threshold: float = 1.0

    def does_component_pass(
        self,
        tokenizer: PreTrainedTokenizer,
        component_index: int,
        top_examples: List['top_examples_common.TopExampleInfo'],
    ) -> bool:
        """Returns whether the component survives this filter.

        Args:
            tokenizer: Used to decode each top example back into text.
            component_index: Index of the component being checked; not used by
                this filter.
            top_examples: The component's top examples.

        Returns:
            False if the most common option appears in at least
            `fraction_threshold` of the top examples, True otherwise. A
            component with no top examples passes, since there is nothing to
            count.
        """
        if not top_examples:
            # Avoid max() of an empty sequence and a division by zero.
            return True

        counts = collections.Counter()
        for example_info in top_examples:
            we = WinograndeExample.from_top_example_info(tokenizer, example_info)
            # Count each distinct option at most once per example. If an
            # example's extracted options coincide, this keeps the fraction
            # from exceeding 1.
            counts.update(set(we.options))

        most_common_count = max(counts.values())

        return most_common_count / len(top_examples) < self.fraction_threshold
