"""Analysis related to SST2 examples."""
import dataclasses
import re
from typing import Any, Dict

import numpy as np
from transformers import PreTrainedTokenizer

from npeff_torch.examination.top_examples import top_examples_common


###############################################################################


@dataclasses.dataclass
class Sst2Example:
    """An SST2 example recovered from a tokenized prompt.

    Instances are normally built with `from_top_example_info`, which
    decodes the prompt and extracts the review text from it.
    """

    # Review text captured from the decoded prompt, or the literal
    # string 'ERROR' when the prompt does not match the template.
    sentence: str
    # Label of the example, converted to int.
    label: int
    # Index of the largest logit.
    prediction: int

    #######################################################

    def to_example_dict(self) -> Dict[str, Any]:
        """Return the sentence and label as a plain dict.

        The prediction is not included in the returned dict.
        """
        return {
            'sentence': self.sentence,
            'label': self.label,
        }

    #######################################################

    @classmethod
    def from_top_example_info(
        cls,
        tokenizer: 'PreTrainedTokenizer',
        example_info: 'top_examples_common.TopExampleInfo',
    ) -> 'Sst2Example':
        """Build an example from a top-example record.

        The tokens whose attention mask entry is nonzero are decoded with
        `tokenizer`, and the review is extracted from a prompt of the form
        'Review: <sentence>\\nSentiment:'.

        Args:
            tokenizer: Object whose `decode` method turns the selected
                token ids back into text.
            example_info: Record providing `example['input_ids']`,
                `example['attention_mask']`, `label` and `logits`.

        Returns:
            A new instance. `sentence` is 'ERROR' when the decoded text
            does not match the expected prompt template.
        """
        input_ids = example_info.example['input_ids']
        attention_mask = example_info.example['attention_mask']

        # Plain sequences do not support boolean-mask indexing: with lists,
        # `attention_mask != 0` is the scalar True and `input_ids[True]`
        # silently picks element 1. Array-like inputs are left untouched.
        if isinstance(input_ids, (list, tuple)):
            input_ids = np.asarray(input_ids)
        if isinstance(attention_mask, (list, tuple)):
            attention_mask = np.asarray(attention_mask)

        context = tokenizer.decode(input_ids[attention_mask != 0])

        # DOTALL lets the capture span line breaks, so a review that itself
        # contains newlines is still extracted instead of becoming 'ERROR'.
        match = re.search(
            r'^Review: (.+)\nSentiment:$',
            context,
            flags=re.DOTALL)

        sentence = match.group(1) if match else 'ERROR'

        return cls(
            sentence=sentence,
            label=int(example_info.label),
            prediction=int(np.argmax(example_info.logits)),
        )
