"""Container utility class used for convenience"""
import dataclasses
import os
from typing import Callable, Dict, List, Mapping, Optional, Sequence

import numpy as np
import tensorflow as tf
from transformers import PreTrainedTokenizer

from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util import latex_util

from em.projects.anli import anli_misc1 as am

###############################################################################


@dataclasses.dataclass
class HansExample:
    """A single HANS example together with the model's two-way logits."""

    premise: str
    hypothesis: str

    # 0 is entailment, 1 is non-entailment.
    label: int

    # Logits over the two classes, in the same order as `label`
    # (argmax of this array is the predicted label).
    predicted_logits: np.ndarray

    # Identifier carried over from the source example record.
    idx: int

    heuristic: str
    template: str
    subcase: str

    # Optional position of this example within its owning container.
    index: Optional[int] = None

    def prediction(self) -> int:
        """Return the predicted label as a built-in ``int``.

        ``np.argmax`` yields a NumPy integer scalar; it is converted so that
        the result matches the declared return type.
        """
        return int(np.argmax(self.predicted_logits))


###############################################################################


def _mnli_to_hans_logits(logits: np.ndarray):
    """Collapse three-way MNLI logits into two-way HANS logits.

    Assumes the columns of `logits` are ordered
    [contradiction, entailment, neutral].

    Returns an array whose last axis is [entailment, non_entailment], where
    the non-entailment logit is the elementwise maximum of the contradiction
    and neutral logits.
    """
    contradiction, entailment, neutral = logits[:, 0], logits[:, 1], logits[:, 2]
    non_entailment = np.maximum(contradiction, neutral)
    return np.stack((entailment, non_entailment), axis=-1)


#####################################################################

@dataclasses.dataclass
class AnalysisContainer:
    """Bundles per-example Fisher data, NMF decompositions and HANS examples.

    After construction, ``examples`` holds ``HansExample`` objects (built in
    ``__post_init__`` from the raw example dicts passed in), and ``labels``,
    ``predicted_logits`` and ``predictions`` are available as attributes.
    """

    pef: per_example.PerExampleFlatFishers
    nmfs: Sequence[nmf_common.NmfDecomposition]

    tokenizer: PreTrainedTokenizer

    # Raw example dicts on input; replaced by a list of HansExample in
    # __post_init__ (the raw dicts are kept in `_old_examples`).
    examples: Sequence[Dict[str, np.ndarray]]

    @property
    def n_nmfs(self) -> int:
        """Number of NMF decompositions held by this container."""
        return len(self.nmfs)

    @property
    def n_examples(self) -> int:
        """Number of examples, taken from the first axis of the PEF input ids."""
        return self.pef.input_ids.shape[0]

    def __post_init__(self):
        """Derive labels, two-way logits, predictions and typed examples."""
        self.labels = self.pef.labels

        # The PEF stores MNLI logits; collapse them to the HANS label space.
        self.predicted_logits = _mnli_to_hans_logits(self.pef.predicted_logits)
        self.predictions = np.argmax(self.predicted_logits, axis=-1)

        self._old_examples = self.examples
        self.examples = self._make_examples(self._old_examples)

    def _make_examples(self, examples: Sequence[Dict[str, np.ndarray]]) -> List[HansExample]:
        """Convert raw example dicts into ``HansExample`` objects.

        Each dict must provide the keys 'premise', 'hypothesis', 'label',
        'heuristic', 'template', 'subcase' and 'idx'. The i-th dict is paired
        with the i-th label and the i-th row of the predicted logits.

        Raises:
            AssertionError: if the number of examples differs from
                ``n_examples`` or a label disagrees with the PEF labels.
        """
        examples = list(examples)
        assert len(examples) == self.n_examples, (
            f'Got {len(examples)} examples but the PEF has {self.n_examples}.'
        )

        ret = []
        for i, ex in enumerate(examples):
            # Quick sanity check.
            assert self.labels[i] == ex['label'], (
                f'Label mismatch between the PEF and example {i}.'
            )
            # TODO: Maybe some more checking that the examples match correctly with
            # the PEF information?

            out = HansExample(
                premise=tf.compat.as_str(ex['premise']),
                hypothesis=tf.compat.as_str(ex['hypothesis']),
                label=self.labels[i],
                predicted_logits=self.predicted_logits[i],
                heuristic=tf.compat.as_str(ex['heuristic']),
                template=tf.compat.as_str(ex['template']),
                subcase=tf.compat.as_str(ex['subcase']),
                idx=ex['idx'],
                index=i,
            )
            ret.append(out)

        return ret

    ############################################################

    def get_correct_prediction_indicator(self) -> np.ndarray:
        """Elementwise mask of examples whose prediction equals the label."""
        return self.labels == self.predictions

    def get_incorrect_prediction_indicator(self) -> np.ndarray:
        """Elementwise mask of examples whose prediction differs from the label."""
        return self.labels != self.predictions

    def get_entailment_label_indicator(self) -> np.ndarray:
        """Elementwise mask of examples labelled entailment (label 0)."""
        return self.labels == 0

    def get_non_entailment_label_indicator(self) -> np.ndarray:
        """Elementwise mask of examples labelled non-entailment (label 1)."""
        return self.labels == 1

    def get_indicator_by_example_fn(self, fn: Callable[[HansExample], bool]) -> np.ndarray:
        """Apply `fn` to every example and return the results as a bool array."""
        return np.array([fn(e) for e in self.examples], dtype=bool)

    def get_heuristic_indicator(self, heuristic: str) -> np.ndarray:
        """Bool array marking the examples whose heuristic equals `heuristic`."""
        return self.get_indicator_by_example_fn(lambda e: e.heuristic == heuristic)

    def get_lexical_overlap_indicator(self) -> np.ndarray:
        """Bool array marking examples with the 'lexical_overlap' heuristic."""
        return self.get_heuristic_indicator('lexical_overlap')

    def get_subsequence_indicator(self) -> np.ndarray:
        """Bool array marking examples with the 'subsequence' heuristic."""
        return self.get_heuristic_indicator('subsequence')

    def get_constituent_indicator(self) -> np.ndarray:
        """Bool array marking examples with the 'constituent' heuristic."""
        return self.get_heuristic_indicator('constituent')

    ############################################################

    def get_top_examples(self, nmf_index: int, component_index: int, n_examples: int) -> List[HansExample]:
        """Return the `n_examples` examples with the largest W coefficients.

        The coefficients are read from column `component_index` of the `W`
        matrix of the NMF at `nmf_index`, and the top entries are selected
        with ``tf.math.top_k``.
        """
        W = self.nmfs[nmf_index].W
        _, inds = tf.math.top_k(W[:, component_index], k=n_examples)
        return [self.examples[i] for i in inds]

    ############################################################

    def _to_latex_label(self, label: int) -> str:
        """Render a label as a bold, colored LaTeX snippet.

        A truthy label renders as 'non-entailment' in ForestGreen; a falsy
        one renders as 'entailment' in BrickRed, followed by four ' {}'
        groups (presumably to pad it to the width of 'non-entailment' in
        monospaced output).
        """
        if label:
            s = 'non-entailment'
            color = 'ForestGreen'
        else:
            s = 'entailment' + (4 * ' {}')
            color = 'BrickRed'
        return R'{\color{' + color + R'}\textbf{' + s + R'}}'

    def make_example_for_component_latex_string(
        self,
        example: HansExample,
        nmf_index: int,
        component_index: int,
        options: Optional['LatexOptions'] = None,
    ) -> str:
        """Build the LaTeX snippet describing one example for one component.

        The snippet is a ``\\texttt`` block with four lines: label, prediction
        and W coefficient; template and subcase; premise; hypothesis.

        Args:
            example: The example to render; its `index` selects the row of W.
            nmf_index: Which NMF decomposition the coefficient is read from.
            component_index: Which column of W the coefficient is read from.
            options: Currently unused.
        """
        index = example.index

        s_coeff = f'{self.nmfs[nmf_index].W[index, component_index]:.4f}'

        # `prediction` is a method and has to be called: a bound method object
        # is always truthy and would always render as 'non-entailment'.
        line1 = R' {} '.join([
            f'[LABEL] {self._to_latex_label(example.label)}',
            f'[PRED] {self._to_latex_label(example.prediction())}',
            f'[COEFF] {s_coeff}',
        ])

        # TODO: Add heuristic to this info.
        line11 = R' {} '.join([
            f'[TMPLT] {latex_util.rpad_curly(example.template, len("temp68"))}',
            f'[SUBCASE] {latex_util.escape(example.subcase)}',
        ])

        line2 = f'[P] {latex_util.escape(example.premise)}'
        line3 = f'[H] {latex_util.escape(example.hypothesis)}'
        return "\n".join([
            R'\noindent\texttt{' + line1 + R'\vspace{1mm} \\',
            R'{}' + line11 + R'\vspace{1mm} \\',
            R'{}' + line2 + R'\vspace{1mm} \\',
            R'{}' + line3 + R'\vspace{2mm}\\',
            R'}',
        ])

    ############################################################

    COMPONENTS_LATEX_FILE_START = am.PefNmfAnalysisContainer.COMPONENTS_LATEX_FILE_START
    COMPONENTS_LATEX_FILE_END = am.PefNmfAnalysisContainer.COMPONENTS_LATEX_FILE_END

    def make_all_components_latex_string(
        self,
        n_examples: int,
        components_fontsize: str = 'footnotesize',
        nmf_names: Optional[Mapping[int, str]] = None,
    ) -> str:
        """Build the LaTeX body listing top examples for every component.

        Emits one ``\\subsection`` per NMF and one ``\\subsubsection`` per
        component, each containing the top `n_examples` examples wrapped in
        a `components_fontsize` environment.

        Args:
            n_examples: Number of top examples to show per component.
            components_fontsize: Name of the LaTeX environment wrapped around
                each component's examples.
            nmf_names: Optional mapping from NMF index to subsection title;
                when omitted, subsections are titled 'Subset <index>'.
        """
        ret = []

        for subset_index in range(self.n_nmfs):
            if nmf_names is None:
                subsection_name = f'Subset {subset_index}'
            else:
                subsection_name = nmf_names[subset_index]

            ret.append(R'\subsection{' + subsection_name + R'}')

            for component_index in range(self.nmfs[subset_index].W.shape[-1]):
                ret.append(R'\subsubsection{Component ' + str(component_index) + R'}')
                ret.append(R'\begin{' + components_fontsize + R'}')

                for x in self.get_top_examples(subset_index, component_index, n_examples):
                    ret.append(self.make_example_for_component_latex_string(x, subset_index, component_index))
                    ret.append('')

                ret.append(R'\end{' + components_fontsize + R'}')
                ret.append('')

        return '\n'.join(ret)


#####################################################################

@dataclasses.dataclass
class LatexOptions:
    """Placeholder for LaTeX rendering options; it currently has no fields."""


###############################################################################


def load_analysis_container(
    pef_filepath: str,
    nmf_filepath: str,
    n_nmfs: int,
    n_pef_examples: Optional[int] = None,
    **kwargs
):
    """Load the PEF and NMF data from disk and wrap them in an AnalysisContainer.

    Args:
        pef_filepath: Path to the per-example Fishers file; a leading '~' is
            expanded.
        nmf_filepath: Path handed to the lazy NMF list.
        n_nmfs: Number of NMFs handed to the lazy NMF list.
        n_pef_examples: Forwarded as `n_examples` to the PEF loader.
        **kwargs: Remaining AnalysisContainer fields, passed through unchanged.

    Returns:
        The constructed AnalysisContainer.
    """
    resolved_pef_path = os.path.expanduser(pef_filepath)

    # Requesting the empty Fisher range [0, 0) leads to the Fishers not being
    # loaded, which ends up being much faster.
    fishers = per_example.PerExampleFlatFishers.load(
        resolved_pef_path,
        n_examples=n_pef_examples,
        start_fisher_index=0,
        end_fisher_index=0,
    )
    lazy_nmfs = am._LazyNmfList(nmf_filepath, n_nmfs)

    return AnalysisContainer(pef=fishers, nmfs=lazy_nmfs, **kwargs)
