"""Context for decomposition of QQP components."""
import dataclasses
import itertools
import os
import re
from typing import Optional

import numpy as np
import tensorflow as tf
from transformers import PreTrainedTokenizer

from em import datasets as em_datasets
from em.fishers import per_example
from em.tools.nmf import lrm_npeff

from em.projects.pi import qqp_components_context as QCC


###############################################################################

@dataclasses.dataclass()
class QqpExample:
    """One decoded QQP question pair.

    Constructor argument order: ``index``, ``sentence1``, ``sentence2``,
    ``label``.
    """
    # TODO: Add more info.

    # Position of the example within the loaded split (the enumerate index
    # assigned by QqpContext._load_examples).
    index: int
    # The two questions of the pair, decoded back to text.
    sentence1: str
    sentence2: str
    # Label of the pair as read from the dataset (presumably 1 = duplicate,
    # 0 = not duplicate — TODO confirm against the dataset definition).
    label: int


@dataclasses.dataclass
class QqpContext:
    """Context for LRM-NPEFF decomposition over QQP examples.

    Row ``i`` of ``nmf.W`` is assumed to correspond to the ``i``-th example of
    ``split`` (the decomposition is assumed to have been computed on the first
    ``nmf.W.shape[0]`` examples of the split, in order).

    Attributes:
        split: Dataset split the decomposition's examples were drawn from.
        tokenizer: Tokenizer used both to encode the dataset and to decode
            examples back into text.
        nmf: The decomposition. ``nmf.W`` has one row per example and one
            column per component.
        load_examples: Whether to decode the examples eagerly in
            ``__post_init__``. Required for ``get_top_examples``.
        ds_name: Dataset name passed to ``em_datasets.load``.
        sequence_length: Sequence length passed to ``em_datasets.load``.
        n_examples: Set in ``__post_init__``; number of rows of ``nmf.W``.
        examples: Set in ``__post_init__``; list of ``QqpExample`` or ``None``
            when ``load_examples`` is False.
    """
    split: str

    tokenizer: PreTrainedTokenizer

    nmf: lrm_npeff.LrmNpeffDecomposition

    # If we are not using the examples, then this is faster.
    load_examples: bool = True

    ds_name: str = 'glue/qqp'

    sequence_length: int = 128

    def __post_init__(self):
        # One example per row of W.
        self.n_examples = self.nmf.W.shape[0]
        self.examples = self._load_examples() if self.load_examples else None

    def _load_dataset(self):
        """Loads the tokenized dataset for ``self.split``.

        Presumably yields one ``(x, y)`` pair per example (unbatched) — this is
        how ``_load_examples`` consumes it.
        """
        return em_datasets.load(
            self.ds_name,
            split=self.split,
            sequence_length=self.sequence_length,
            tokenizer=self.tokenizer,
        )

    def _load_examples(self):
        """Decodes the first ``n_examples`` examples of the split.

        Assumes the examples are the first in the split.

        Raises:
            ValueError: If the split contains fewer than ``n_examples`` examples,
                in which case rows of ``nmf.W`` would have no matching example.
        """
        ds = self._load_dataset()
        first_n = itertools.islice(ds.as_numpy_iterator(), self.n_examples)
        examples = [self._make_example(x, y, i) for i, (x, y) in enumerate(first_n)]
        if len(examples) < self.n_examples:
            raise ValueError(
                f'Split {self.split!r} of {self.ds_name!r} has only {len(examples)} '
                f'examples, but the decomposition has {self.n_examples} rows.'
            )
        return examples

    def _make_example(self, x, y, index) -> QqpExample:
        """Decodes one tokenized example into a ``QqpExample``.

        The decoded text is expected to have the form
        ``<cls> sentence1 <sep> sentence2 <sep>`` (after removing padding).

        Raises:
            ValueError: If the decoded text does not have that form.
        """
        tok = self.tokenizer
        r_cls_token = re.escape(tok.cls_token)
        r_sep_token = re.escape(tok.sep_token)
        example_regex = rf'^{r_cls_token}(.+){r_sep_token}(.+){r_sep_token}$'

        example = tok.decode(x['input_ids'])
        # Some tokenizers have no pad token; str.replace(None, ...) would raise.
        if tok.pad_token:
            example = example.replace(tok.pad_token, '')
        example = example.strip()

        # DOTALL so that sentences containing newlines still match.
        match = re.search(example_regex, example, flags=re.DOTALL)
        if match is None:
            raise ValueError(
                f'Example {index} does not look like a sentence pair: {example!r}'
            )
        sentence1 = match.group(1).strip()
        sentence2 = match.group(2).strip()

        return QqpExample(index=index, sentence1=sentence1, sentence2=sentence2, label=y)

    def get_top_examples(self, component_index: int, n_examples: int):
        """Returns the ``n_examples`` examples with the largest coefficient for a component.

        Examples are ordered by decreasing ``nmf.W[:, component_index]``.

        Raises:
            RuntimeError: If the context was created with ``load_examples=False``.
        """
        if self.examples is None:
            raise RuntimeError(
                'Examples were not loaded; construct the context with load_examples=True.'
            )
        top_inds = np.argsort(-self.nmf.W[:, component_index])[:n_examples]
        return [self.examples[i] for i in top_inds]

    def create_eval_ctx_given_logits(self, logits: np.ndarray):
        """Builds a ``QCC.EvaluationContext2`` over the decomposition's examples.

        Args:
            logits: Logits for the first ``n_examples`` examples of the split;
                its leading dimension must equal ``n_examples``.

        Returns:
            ``QCC.EvaluationContext2`` whose ``all_examples`` is the first
            ``n_examples`` examples as a single batch ``(x, y)`` and whose
            ``og_logits`` is ``logits``.

        Raises:
            ValueError: If ``logits.shape[0] != n_examples``, or the split is
                empty or holds fewer than ``n_examples`` examples.
        """
        if logits.shape[0] != self.n_examples:
            raise ValueError(
                f'Expected logits for {self.n_examples} examples, got shape {logits.shape}.'
            )

        ds = self._load_dataset()

        # Only the first batch is needed: it holds the first n_examples examples.
        try:
            x, y = next(iter(ds.batch(self.n_examples).as_numpy_iterator()))
        except StopIteration:
            raise ValueError(f'Split {self.split!r} of {self.ds_name!r} is empty.') from None

        n_loaded = x['input_ids'].shape[0]
        if n_loaded != self.n_examples:
            raise ValueError(
                f'Split {self.split!r} of {self.ds_name!r} has only {n_loaded} '
                f'examples, but the decomposition has {self.n_examples} rows.'
            )

        return QCC.EvaluationContext2(
            all_examples=(x, y),
            og_logits=logits,
        )
