"""Misc common tools and stuff for experimenting with ANLI."""
import collections
import collections.abc
import concurrent
import concurrent.futures
import dataclasses
import os
import re
import textwrap
import time
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple

import numpy as np
from scipy import special
import tensorflow as tf
from transformers import PreTrainedTokenizer

from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util import latex_util
from em.util.color_util import cu

from . import nli_example

# Type aliases: re-export the example record under a short module-level name
# so it can be used in annotations and constructed directly below.
NliExample = nli_example.NliExample


# One-character abbreviation per class. These tuples are indexed by integer
# label / prediction (PefNmfAnalysisContainer rotates its own copies when
# `shift_labels` is set).
_LABEL_CHAR_ABBREVS = ('e', 'n', 'c')

# Full class names, parallel to `_LABEL_CHAR_ABBREVS`.
_LABEL_NAMES = ('entailment', 'neutral', 'contradiction')

# Same names padded with trailing spaces, for aligned plain-text output.
_PADDED_LABEL_NAMES = (
    'entailment   ',
    'neutral      ',
    'contradiction',
)

# LaTeX flavour of the padded names: the padding is written as ` {}` groups
# rather than bare spaces.
_PADDED_LABEL_NAMES_LATEX = (
    'entailment {} {} {}',
    'neutral {} {} {} {} {} {}',
    'contradiction',
)

# Abbreviation -> name lookups. '-' is the abbreviation given to examples
# without a label; the remaining entries are derived from the tuples above so
# the two representations cannot drift apart.
_CHAR_TO_LABEL_NAME = {
    '-': 'n/a',
    **dict(zip(_LABEL_CHAR_ABBREVS, _LABEL_NAMES)),
}
_CHAR_TO_PADDED_LABEL_NAME = {
    '-': 'n/a          ',
    **dict(zip(_LABEL_CHAR_ABBREVS, _PADDED_LABEL_NAMES)),
}
_CHAR_TO_PADDED_LABEL_NAME_LATEX = {
    '-': 'n/a {} {} {} {} {} {} {} {} {} {}',
    **dict(zip(_LABEL_CHAR_ABBREVS, _PADDED_LABEL_NAMES_LATEX)),
}


###############################################################################


# Indentation prefix (four spaces) for the detail lines printed under each example.
_TAB = ' ' * 4


def _capitalize_first_letter(s: str) -> str:
    """Return `s` with its first character upper-cased and the rest unchanged.

    Unlike `str.capitalize`, the remaining characters are not lower-cased.

    NOTE: Won't capitalize anything if the string has leading whitespace.
    """
    # Slicing (rather than indexing) makes the empty string fall through
    # naturally: ''[:1] and ''[1:] are both ''.
    return s[:1].upper() + s[1:]


###############################################################################


def _binomial_pmf(n, k, p):
    """Return the binomial probability mass P(X = k) for X ~ Binomial(n, p).

    The value is assembled in log space and exponentiated at the end. The
    direct form `binom(n, k) * p**k * (1 - p)**(n - k)` breaks down for large
    `n`: the binomial coefficient overflows to `inf` while the powers
    underflow to 0, and `inf * 0` is `nan`.

    Args:
        n: Number of trials.
        k: Number of successes; may be an array, in which case the result is
            computed element-wise (arguments broadcast against each other).
        p: Per-trial success probability.

    Returns:
        The probability mass, as a NumPy scalar or array.
    """
    # log C(n, k) via log-gamma, which stays finite where C(n, k) itself would not.
    log_coeff = (
        special.gammaln(n + 1)
        - special.gammaln(k + 1)
        - special.gammaln(n - k + 1)
    )
    # xlogy / xlog1py return 0 when the multiplier is 0, so the p == 0 and
    # p == 1 edge cases come out exactly instead of as 0 * log(0) = nan.
    log_pmf = log_coeff + special.xlogy(k, p) + special.xlog1py(n - k, -p)
    return np.exp(log_pmf)


def estimate_selectivity_for_indicator(
    W: np.ndarray,
    indicator: np.ndarray,
    k: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Measure how strongly each component's top-k examples satisfy `indicator`.

    For every component (column of `W`) the `k` examples with the largest
    coefficients are selected and the ones flagged by `indicator` are counted.

    Args:
        W: Coefficient matrix with shape [n_examples, n_components].
        indicator: Per-example flags with shape [n_examples].
        k: Number of top examples inspected per component.

    Returns:
        A `(fractions, p_values)` pair with one entry per component.
        `fractions` is the flagged count divided by `k`. `p_values` is the
        upper-tail probability of seeing at least that many flagged examples
        in `k` draws under a binomial model whose success probability is the
        overall mean of `indicator`.
    """
    # Row j of the transposed argsort lists example indices for component j,
    # largest coefficient first.
    top_k_inds = np.argsort(-W.T)[:, :k]
    hits = indicator[top_k_inds].astype(np.int32).sum(axis=-1)
    fractions = hits.astype(np.float64) / k

    base_rate = indicator.astype(np.float64).mean()
    null_pmf = _binomial_pmf(k, np.arange(k + 1), base_rate)
    # P(X >= hits) under the null distribution, per component.
    p_values = np.array([null_pmf[n_hits:].sum() for n_hits in hits])

    return fractions, p_values


###############################################################################
###############################################################################


def _load_nmf_decomp(filepath: str, subset_index: int):
    """Load the saved NMF decomposition for one subset and unit-normalize its components.

    The per-subset file name is derived from `filepath` by inserting
    `.ssi<subset_index>` in front of the `.h5` extension, e.g.
    `foo.h5` -> `foo.ssi3.h5`. `~` in the path is expanded.

    Args:
        filepath: Base path; must end with `.h5`.
        subset_index: Index of the subset whose decomposition is loaded.

    Returns:
        The loaded decomposition, after
        `normalize_components_to_unit_norm()` has been called on it.
    """
    extension = '.h5'
    assert filepath.endswith(extension)

    stem = filepath[:-len(extension)]
    subset_path = os.path.expanduser(f"{stem}.ssi{subset_index}{extension}")

    decomp = nmf_common.NmfDecomposition.load(subset_path)
    decomp.normalize_components_to_unit_norm()
    # NOTE: `decomp.H` is intentionally left as loaded; the full H
    # (`decomp.get_full_H()`) is deliberately not materialized here.
    return decomp


class _LazyNmfList(collections.abc.Sequence):
    """Read-only sequence of NMF decompositions that are loaded on first access.

    Element `i` is produced by `_load_nmf_decomp(filepath, i)` and then kept
    in an in-memory cache, so later accesses to the same index reuse the
    loaded object.

    The base class is `collections.abc.Sequence`; the `collections.Sequence`
    alias no longer exists on Python 3.10+.

    NOTE: Does not support many of the methods of the built-in Python list object.
    """

    def __init__(self, filepath: str, n_nmfs: int):
        """Create the lazy list.

        Args:
            filepath: Base `.h5` path passed on to `_load_nmf_decomp`.
            n_nmfs: Number of decompositions, i.e. the length of the sequence.
        """
        assert filepath.endswith('.h5')
        self._filepath = filepath
        self._n_nmfs = n_nmfs
        # Maps index -> loaded decomposition.
        self._cache = {}

    def _get_index(self, index: int) -> 'nmf_common.NmfDecomposition':
        """Return the decomposition at `index`, loading and caching it if needed."""
        if index not in self._cache:
            self._cache[index] = self._get_index_no_cache(index)
        return self._cache[index]

    def _get_index_no_cache(self, index: int) -> 'nmf_common.NmfDecomposition':
        """Load the decomposition at `index` from disk, bypassing the cache."""
        return _load_nmf_decomp(self._filepath, index)

    def __len__(self):
        return self._n_nmfs

    def __getitem__(self, key):
        """Return one decomposition (int key) or a list of them (slice key).

        Raises:
            IndexError: If an integer key is out of range.
        """
        # Indexing a range does out of bounds checking and handles negative
        # indices and slices for us. Numpy-style fancy indexing is not
        # currently supported.
        indices = range(self._n_nmfs)[key]

        if isinstance(indices, int):
            return self._get_index(indices)

        return [self._get_index(i) for i in indices]

    def force_load_all(self):
        """Eagerly load every decomposition into the cache.

        Any exception raised while loading is re-raised here.
        """
        # NOTE: Using multi-threading here appears to be only slightly faster than
        # not using multi-threading.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self._get_index, i) for i in range(self._n_nmfs)]
            for future in concurrent.futures.as_completed(futures):
                # `result()` surfaces exceptions from the worker thread.
                future.result()


def load_pef_nmf_analysis_container(
    pef_filepath: str,
    nmf_filepath: str,
    n_nmfs: int,
    n_pef_examples: Optional[int] = None,
    **kwargs
):
    """Build a `PefNmfAnalysisContainer` from saved per-example Fishers and NMF files.

    Progress and the elapsed load time are printed to stdout. The NMF
    decompositions are wrapped in a lazy list and are not read here.

    Args:
        pef_filepath: Path of the saved per-example Fishers (`~` is expanded).
        nmf_filepath: Base `.h5` path of the saved NMF decompositions.
        n_nmfs: Number of NMF decompositions available.
        n_pef_examples: Forwarded to the Fisher loader as `n_examples`.
        **kwargs: Extra keyword arguments forwarded to
            `PefNmfAnalysisContainer` (e.g. `tokenizer`, `shift_labels`).

    Returns:
        The assembled `PefNmfAnalysisContainer`.
    """
    print('Starting to load saved per-example Fishers.')
    t0 = time.time()

    load_options = dict(
        n_examples=n_pef_examples,
        # This leads to the Fishers not being loaded, which ends up being much faster.
        start_fisher_index=0,
        end_fisher_index=0,
    )
    pef = per_example.PerExampleFlatFishers.load(
        os.path.expanduser(pef_filepath), **load_options)

    print('Load saved per-example Fishers time: ', time.time() - t0)

    lazy_nmfs = _LazyNmfList(nmf_filepath, n_nmfs)
    return PefNmfAnalysisContainer(pef=pef, nmfs=lazy_nmfs, **kwargs)


###############################################################################


# Label abbreviation -> callable from `cu` that is applied to the label text
# before it is printed ('-' is the abbreviation of an unlabeled example).
# NOTE(review): presumably these wrap the text in terminal color codes; the
# exact colors are defined in `em.util.color_util` -- confirm there.
_LABEL_ABBREV_TO_COLOR_FN = {
    '-': cu.dlw,
    'e': cu.hlg,
    'n': cu.hlb,
    'c': cu.hlr,
}


# Label abbreviation -> color name passed to LaTeX's `\color{...}` when
# rendering a label.
# Requires `\usepackage[dvipsnames]{xcolor}`
_LABEL_ABBREV_TO_LATEX_COLOR = {
    '-': 'lightgray',
    'n': 'RoyalBlue',
    'e': 'ForestGreen',
    'c': 'BrickRed',
}


@dataclasses.dataclass
class PefNmfAnalysisContainer:
    """Pairs saved per-example Fishers with NMF decompositions for inspection.

    On construction the token ids stored in `pef` are decoded back into
    premise / hypothesis text, and one `NliExample` is built per input. The
    remaining methods select the examples with the largest coefficients for an
    NMF component and render them for the terminal or as LaTeX.

    Attributes set in `__post_init__`:
        predicted_logits: `pef.predicted_logits`.
        predictions: Argmax of `predicted_logits` over the last axis.
        labels: `pef.labels`, mapped through `(label + 1) % 3` when
            `shift_labels` is set.
        examples: List of `NliExample`, one per entry of `pef.input_ids`.
    """

    # Project types are written as string (forward-reference) annotations so
    # that defining the class does not need to evaluate them.
    pef: 'per_example.PerExampleFlatFishers'
    nmfs: 'Sequence[nmf_common.NmfDecomposition]'

    tokenizer: 'PreTrainedTokenizer'

    # When True, labels are shifted by one (mod 3) and the label-name tuples
    # are rotated to match; see `__post_init__`.
    shift_labels: bool

    # Boolean array of length n_examples that is True where the example has
    # NO label. When left as None it defaults to `pef.labels == -1`.
    unlabeled_indicator: Optional[np.ndarray] = None

    @property
    def n_nmfs(self) -> int:
        """Number of NMF decompositions held by this container."""
        return len(self.nmfs)

    def __post_init__(self):
        if self.unlabeled_indicator is None:
            self.unlabeled_indicator = self.pef.labels == -1

        self.predicted_logits = self.pef.predicted_logits
        self.predictions = np.argmax(self.predicted_logits, axis=-1)

        def in_label_order(names):
            # With shifted labels the tuples are rotated left by two, giving
            # the order (contradiction, entailment, neutral).
            names = tuple(names)
            if self.shift_labels:
                return names[2:] + names[:2]
            return names

        if self.shift_labels:
            self.labels = (self.pef.labels + 1) % 3
        else:
            self.labels = self.pef.labels

        self._label_char_abbrevs = in_label_order(_LABEL_CHAR_ABBREVS)
        self._label_names = in_label_order(_LABEL_NAMES)
        self._padded_label_names = in_label_order(_PADDED_LABEL_NAMES)
        self._padded_label_names_latex = in_label_order(_PADDED_LABEL_NAMES_LATEX)

        self.examples = self._make_nli_examples()

    def _make_nli_examples(self):
        """Decode every input in `pef` into an `NliExample`.

        Each decoded string is expected to look like
        `<cls> premise <sep> hypothesis <sep>` once padding tokens are removed.

        Raises:
            ValueError: If a decoded input does not have that layout.
        """
        r_cls_token = re.escape(self.tokenizer.cls_token)
        r_sep_token = re.escape(self.tokenizer.sep_token)
        # Compiled once; the pattern is the same for every example.
        example_regex = re.compile(
            rf'^{r_cls_token}(.+){r_sep_token}(.+){r_sep_token}$')

        examples = []
        for i, input_ids in enumerate(self.pef.input_ids):
            text = self.tokenizer.decode(input_ids)
            text = text.replace(self.tokenizer.pad_token, '').strip()

            match = example_regex.search(text)
            if match is None:
                # Fail with context instead of an AttributeError on None.
                raise ValueError(
                    f'Example {i} does not match the expected '
                    f'"<cls> premise <sep> hypothesis <sep>" layout: {text!r}')
            premise = match.group(1).strip()
            hypothesis = match.group(2).strip()

            predicted_logits = self.predicted_logits[i]
            pred = np.argmax(predicted_logits)

            if self.unlabeled_indicator[i]:
                label_char = '-'
            else:
                label_char = self._label_char_abbrevs[self.labels[i]]

            examples.append(NliExample(
                index=i,
                premise=premise,
                hypothesis=hypothesis,
                label_char=label_char,
                prediction_char=self._label_char_abbrevs[pred],
                predicted_logits=predicted_logits,
            ))

        return examples

    ############################################################

    def get_indicator_by_example_fn(self, fn: "Callable[['NliExample'], bool]") -> np.ndarray:
        """Return a boolean array with `fn(example)` evaluated for every example."""
        # The builtin `bool` is used as dtype: the `np.bool` alias was removed
        # from NumPy.
        return np.array([fn(e) for e in self.examples], dtype=bool)

    # TODO: These won't be accurate when we have unlabelled examples.

    def get_correct_prediction_indicator(self) -> np.ndarray:
        """Return a boolean array that is True where prediction equals label."""
        return self.labels == self.predictions

    def get_incorrect_prediction_indicator(self) -> np.ndarray:
        """Return a boolean array that is True where prediction differs from label."""
        return self.labels != self.predictions

    ############################################################

    def estimate_selectivity_for_indicator(
            self, nmf_index: int, indicator: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
        """Run the module-level `estimate_selectivity_for_indicator` on one decomposition's W."""
        W = self.nmfs[nmf_index].W
        return estimate_selectivity_for_indicator(W, indicator, k=k)

    ############################################################

    def get_top_examples(self, nmf_index: int, component_index: int, n_examples: int) -> "List['NliExample']":
        """Return the `n_examples` examples with the largest coefficient for a component."""
        W = self.nmfs[nmf_index].W
        _, inds = tf.math.top_k(W[:, component_index], k=n_examples)
        return [self.examples[i] for i in inds]

    def get_top_examples_based_on_relative_coefficient(
        self,
        nmf_index: int,
        component_index: int,
        factor: float,
        max_examples: Optional[int] = None,
    ) -> "List['NliExample']":
        """Return the examples whose coefficient is at least `factor` times the largest one.

        Args:
            nmf_index: Which decomposition to use.
            component_index: Which component (column of W) to rank by.
            factor: Fraction of the top coefficient, in [0, 1], that an
                example must reach to be included.
            max_examples: Optional cap on the number of examples returned.
                A cap of 0 (or less) yields an empty list.

        Returns:
            Matching examples, largest coefficient first.
        """
        assert 0 <= factor <= 1

        all_coeffs = self.nmfs[nmf_index].W[:, component_index]
        inds = np.argsort(-all_coeffs)
        if max_examples is not None:
            # Apply the cap up front so that a cap of 0 really returns nothing.
            inds = inds[:max(max_examples, 0)]
        if len(inds) == 0:
            return []

        threshold = all_coeffs[inds[0]] * factor

        examples = []
        for ind in inds:
            # `inds` is sorted by decreasing coefficient, so the first miss
            # ends the search.
            if threshold <= all_coeffs[ind]:
                examples.append(self.examples[ind])
            else:
                break

        return examples

    ############################################################

    def print_example_for_component(self, example: 'NliExample', nmf_index: int, component_index: int):
        """Print one example with its label, prediction and component coefficient."""
        index = example.index

        premise = _capitalize_first_letter(example.premise)
        hypothesis = _capitalize_first_letter(example.hypothesis)

        # Names are looked up by abbreviation (as the LaTeX renderer does) so
        # that an unlabeled example ('-') is shown as 'n/a' rather than as
        # whatever class its placeholder label happens to index.
        s_label = _CHAR_TO_PADDED_LABEL_NAME[example.label_char]
        s_label = _LABEL_ABBREV_TO_COLOR_FN[example.label_char](s_label)

        s_pred = _CHAR_TO_PADDED_LABEL_NAME[example.prediction_char]
        s_pred = _LABEL_ABBREV_TO_COLOR_FN[example.prediction_char](s_pred)

        s_coeff = cu.hy(f'{self.nmfs[nmf_index].W[index, component_index]:.4f}')

        print(f'Example {index}:')

        print(_TAB + ('   '.join([
            f'{cu.dlw("[LABEL]")} {s_label}',
            f'{cu.dlw("[PRED]")} {s_pred}',
            f'{cu.dlw("[COEFF]")} {s_coeff}',
        ])))

        print(f'{_TAB}{cu.dlb("[PREMISE]")} {premise}')
        print(f'{_TAB}{cu.dlc("[HYPOTHESIS]")} {hypothesis}')

    def print_top_examples(self, nmf_index: int, component_index: int, n_examples: int):
        """Print the top `n_examples` examples of a component, separated by blank lines."""
        for x in self.get_top_examples(nmf_index, component_index, n_examples):
            self.print_example_for_component(x, nmf_index, component_index)
            print('')

    ###################################

    def _to_latex_label(self, label_char: str) -> str:
        """Return the colored, bold LaTeX markup for a label abbreviation."""
        # The bold requires `\usepackage{bold-extra}`
        s = _CHAR_TO_PADDED_LABEL_NAME_LATEX[label_char]
        color = _LABEL_ABBREV_TO_LATEX_COLOR[label_char]
        return R'{\color{' + color + R'}\textbf{' + s + R'}}'

    def make_example_for_component_latex_string(self, example: 'NliExample', nmf_index: int, component_index: int) -> str:
        """Return the LaTeX snippet showing one example for one component.

        The snippet is a `\\texttt` block with three lines: label / prediction
        / coefficient, then the premise, then the hypothesis.
        """
        # TODO: Replace %, $ with backslashed versions, “, ’ to ", ', – to -
        index = example.index

        premise = _capitalize_first_letter(example.premise)
        hypothesis = _capitalize_first_letter(example.hypothesis)

        s_coeff = f'{self.nmfs[nmf_index].W[index, component_index]:.4f}'

        line1 = R' {} '.join([
            f'[LABEL] {self._to_latex_label(example.label_char)}',
            f'[PRED] {self._to_latex_label(example.prediction_char)}',
            f'[COEFF] {s_coeff}',
        ])

        line2 = f'[P] {latex_util.escape(premise)}'
        line3 = f'[H] {latex_util.escape(hypothesis)}'
        return "\n".join([
            R'\noindent\texttt{' + line1 + R'\vspace{1mm} \\',
            R'{}' + line2 + R'\vspace{1mm} \\',
            R'{}' + line3 + R'\vspace{2mm}\\',
            R'}',
        ])

    def print_example_for_component_latex(self, example: 'NliExample', nmf_index: int, component_index: int):
        """Print the LaTeX snippet for one example."""
        print(self.make_example_for_component_latex_string(example, nmf_index, component_index))

    def print_top_examples_latex(self, nmf_index: int, component_index: int, n_examples: int):
        """Print LaTeX snippets for the top `n_examples` examples of a component."""
        for x in self.get_top_examples(nmf_index, component_index, n_examples):
            self.print_example_for_component_latex(x, nmf_index, component_index)
            print('')

    ##################################################

    # The literal starts with a newline so that every line carries the same
    # indentation and `textwrap.dedent` can actually strip it; the leading
    # newline is removed afterwards.
    COMPONENTS_LATEX_FILE_START = textwrap.dedent(R"""
    % Please use XeLaTex to handle unicode properly.
    \documentclass[11pt]{article}

    \usepackage[margin=1in]{geometry} 
    \usepackage[dvipsnames]{xcolor}
    \usepackage{bold-extra}

    \begin{document}

    \section{Component Top Examples}
    \textit{All indices used here are 0-based unless stated otherwise. Note that section indices will be 1-based though.}

    """).lstrip('\n')

    COMPONENTS_LATEX_FILE_END = "\n\\end{document}"

    def _subsection_latex(self, subset_index: int, nmf_names) -> str:
        """Return the `\\subsection` line for one NMF subset."""
        if nmf_names is None:
            subsection_name = f'Subset {subset_index}'
        else:
            subsection_name = nmf_names[subset_index]
        return R'\subsection{' + subsection_name + R'}'

    def _component_latex_lines(self, subset_index: int, component_index: int, examples, components_fontsize: str) -> List[str]:
        """Return the LaTeX lines of one component's subsubsection for the given examples."""
        lines = [
            R'\subsubsection{Component ' + str(component_index) + R'}',
            R'\begin{' + components_fontsize + R'}',
        ]
        for x in examples:
            lines.append(self.make_example_for_component_latex_string(x, subset_index, component_index))
            lines.append('')
        lines.append(R'\end{' + components_fontsize + R'}')
        lines.append('')
        return lines

    def make_all_components_latex_string(
        self,
        n_examples: int,
        components_fontsize: str = 'footnotesize',
        nmf_names: Optional[Mapping[int, str]] = None,
    ) -> str:
        """Return LaTeX listing the top examples of every component of every NMF.

        Args:
            n_examples: Number of top examples shown per component.
            components_fontsize: Name of the LaTeX font-size environment
                wrapped around each component's examples.
            nmf_names: Optional subsection titles indexed by NMF index;
                defaults to `Subset <index>`.
        """
        ret = []

        for subset_index in range(self.n_nmfs):
            ret.append(self._subsection_latex(subset_index, nmf_names))

            for component_index in range(self.nmfs[subset_index].W.shape[-1]):
                top = self.get_top_examples(subset_index, component_index, n_examples)
                ret.extend(self._component_latex_lines(
                    subset_index, component_index, top, components_fontsize))

        return '\n'.join(ret)

    def make_latex_string_for_some_components(
        self,
        component_indices_per_subset: Sequence[Sequence[int]],
        n_examples: int,
        components_fontsize: str = 'footnotesize',
        nmf_names: Optional[Mapping[int, str]] = None,
    ) -> str:
        """Return LaTeX listing the top examples of selected components.

        Args:
            component_indices_per_subset: For each NMF index, the component
                indices to include (in the order they should appear).
            n_examples: Number of top examples shown per component.
            components_fontsize: Name of the LaTeX font-size environment
                wrapped around each component's examples.
            nmf_names: Optional subsection titles indexed by NMF index;
                defaults to `Subset <index>`.
        """
        ret = []

        for subset_index in range(self.n_nmfs):
            ret.append(self._subsection_latex(subset_index, nmf_names))

            for component_index in component_indices_per_subset[subset_index]:
                top = self.get_top_examples(subset_index, component_index, n_examples)
                ret.extend(self._component_latex_lines(
                    subset_index, component_index, top, components_fontsize))

        return '\n'.join(ret)

    def make_latex_string_for_some_components_and_examples(
        self,
        example_indices_per_component_per_subset: Sequence[Dict[int, Sequence[int]]],
        components_fontsize: str = 'footnotesize',
        nmf_names: Optional[Mapping[int, str]] = None,
    ) -> str:
        """Return LaTeX listing explicitly chosen examples for selected components.

        Args:
            example_indices_per_component_per_subset: For each NMF index, a
                dict mapping component index to the example indices to show.
                Components are emitted in sorted order of their index.
            components_fontsize: Name of the LaTeX font-size environment
                wrapped around each component's examples.
            nmf_names: Optional subsection titles indexed by NMF index;
                defaults to `Subset <index>`.
        """
        ret = []

        for subset_index in range(self.n_nmfs):
            ret.append(self._subsection_latex(subset_index, nmf_names))

            components = example_indices_per_component_per_subset[subset_index]

            for component_index in sorted(components):
                chosen = [self.examples[i] for i in components[component_index]]
                ret.extend(self._component_latex_lines(
                    subset_index, component_index, chosen, components_fontsize))

        return '\n'.join(ret)

###############################################################################


def make_per_sub_block_subset_names(n_nmfs):
    """Return display names for NMFs laid out as two sub-blocks per layer.

    The names alternate between the attention and feedforward sub-block of
    each layer: index `2 * i` is layer `i`'s attention sub-block and index
    `2 * i + 1` its feedforward sub-block.

    Args:
        n_nmfs: Total number of NMFs; must be even.

    Returns:
        A list of `n_nmfs` names.
    """
    assert n_nmfs % 2 == 0
    names = []
    for layer in range(n_nmfs // 2):
        names.extend((
            f'Layer {layer} Attention Sub-Block',
            f'Layer {layer} Feedforward Sub-Block',
        ))
    return names
