"""Context for decomposition of SNLI components."""
import dataclasses
import itertools
import os
import re
from typing import Optional

import numpy as np
import tensorflow as tf
from transformers import PreTrainedTokenizer

from em import datasets as em_datasets
from em.fishers import per_example
from em.tools.nmf import lrm_npeff

from em.projects.pi import qqp_components_context as QCC


###############################################################################

@dataclasses.dataclass
class SnliExample:
    """One SNLI example recovered as plain text.

    Attributes:
        index: Integer index identifying the example.
        premise: Text of the premise sentence.
        hypothesis: Text of the hypothesis sentence.
        label: Integer class label of the example.
    """
    # TODO: Add more info.
    index: int
    premise: str
    hypothesis: str
    label: int


class ExamplesList(list):
    """A `list` of examples with a convenience pretty-printer."""

    def print_all(self):
        """Print the premise and hypothesis of every item, each followed by a blank line."""
        for item in self:
            # The empty final argument plus print's own newline yields the
            # blank separator line between consecutive examples.
            print(
                f'PREMISE: {item.premise}',
                f'HYPOTHESIS: {item.hypothesis}',
                '',
                sep='\n',
            )

###############################################################################


@dataclasses.dataclass
class SnliContext:
    """Context for LRM-NPEFF decomposition over SNLI examples.

    Attributes:
        split: Name of the SNLI split handed to the dataset loader.
        tokenizer: Tokenizer used to build the dataset and to decode
            `input_ids` back into text.
        nmf: The decomposition. `nmf.W` is indexed as `W[example, component]`.
        load_examples: Whether to decode the examples on construction.
        n_examples: Set in `__post_init__` to the number of rows of `nmf.W`.
        examples: Set in `__post_init__` to the list of decoded examples, or
            to None when `load_examples` is False.
    """
    split: str

    tokenizer: PreTrainedTokenizer

    nmf: lrm_npeff.LrmNpeffDecomposition

    # If we are not using the examples, then this is faster.
    load_examples: bool = True

    def __post_init__(self):
        self.n_examples = self.nmf.W.shape[0]
        if self.load_examples:
            self.examples = self._load_snli_examples()
        else:
            self.examples = None

    def _load_dataset(self):
        """Load the SNLI dataset for `self.split` with this context's tokenizer."""
        return em_datasets.load(
            'snli/default',
            split=self.split,
            sequence_length=128,
            tokenizer=self.tokenizer,
        )

    def _load_snli_examples(self):
        """Decode the first `self.n_examples` dataset entries into `SnliExample`s."""
        # Assumes the examples are the first in the split.
        ds = self._load_dataset()
        rows = itertools.islice(ds.as_numpy_iterator(), self.n_examples)
        return [self._make_snli_example(x, y, i) for i, (x, y) in enumerate(rows)]

    def _make_snli_example(self, x, y, index) -> SnliExample:
        """Recover the premise/hypothesis text of one tokenized example.

        Args:
            x: Mapping holding the token ids under the key 'input_ids'.
            y: Label of the example as stored in the dataset.
            index: Index to record on the returned example.

        Returns:
            The decoded `SnliExample`.

        Raises:
            ValueError: If the decoded text does not have the form
                `<cls> premise <sep> hypothesis <sep>`.
        """
        r_cls_token = re.escape(self.tokenizer.cls_token)
        r_sep_token = re.escape(self.tokenizer.sep_token)
        example_regex = rf'^{r_cls_token}(.+){r_sep_token}(.+){r_sep_token}$'

        # Decode, then drop padding so the text ends on the final separator.
        example = self.tokenizer.decode(x['input_ids'])
        example = example.replace(self.tokenizer.pad_token, '')
        example = example.strip()

        match = re.search(example_regex, example)
        if match is None:
            # re.search returns None on failure; fail loudly with context
            # instead of an AttributeError on NoneType.
            raise ValueError(
                f'Could not parse decoded SNLI example at index {index}: {example!r}'
            )
        premise = match.group(1).strip()
        hypothesis = match.group(2).strip()

        # Rotates the label ids (0 -> 1, 1 -> 2, 2 -> 0).
        # NOTE(review): presumably matches the label order the model expects
        # (cf. fix_text_attack_mnli_labeling in create_eval_ctx) -- confirm.
        label = (y + 1) % 3

        return SnliExample(index=index, premise=premise, hypothesis=hypothesis, label=label)

    def get_top_examples(self, component_index: int, n_examples: int) -> ExamplesList:
        """Return the examples with the largest `nmf.W` entries for a component.

        Args:
            component_index: Column of `nmf.W` to rank by.
            n_examples: Maximum number of examples to return.

        Returns:
            An `ExamplesList` ordered by decreasing coefficient.

        Raises:
            RuntimeError: If the context was created with `load_examples=False`.
        """
        if self.examples is None:
            raise RuntimeError(
                'Examples are not loaded; construct SnliContext with '
                'load_examples=True to use get_top_examples.'
            )
        # Negating the column makes the ascending argsort yield descending order.
        top_inds = np.argsort(-self.nmf.W[:, component_index])[:n_examples]
        return ExamplesList([self.examples[i] for i in top_inds])

    def create_eval_ctx(self, model, n_examples: Optional[int] = None) -> QCC.EvaluationContext2:
        """Build an evaluation context for `model` over the leading examples.

        Args:
            model: Model forwarded to `EvaluationContext2.create_from_ds`.
            n_examples: Number of leading dataset examples to use. Defaults
                to `self.n_examples` when None.
        """
        # NOTE: Maybe not the best permanent place for this function, but
        # whatever for the sake of expediency.
        ds = self._load_dataset()
        ds = ds.take(self.n_examples if n_examples is None else n_examples)

        ds = em_datasets.glue.fix_text_attack_mnli_labeling(ds)

        return QCC.EvaluationContext2.create_from_ds(
            ds=ds.cache(),
            model=model,
            special_processing='HF_MNLI',
        )

    def create_eval_ctx_from_pefs_file(self, pefs_filepath: str):
        """Build an evaluation context from a saved per-example Fishers file.

        Args:
            pefs_filepath: Path to the file; a leading `~` is expanded.
        """
        pef = per_example.PerExampleFlatFishers.load(
            os.path.expanduser(pefs_filepath),
            n_examples=None,
            # This leads to the Fishers not being loaded, which ends up being much faster.
            start_fisher_index=0,
            end_fisher_index=0,
        )
        return QCC.EvaluationContext2.create_from_pefs(
            pef=pef,
            tokenizer=self.tokenizer,
            special_processing='HF_MNLI',
        )
