"""Misc elementary antiderivative stuff."""
import collections
import dataclasses
import textwrap
from typing import List, Mapping, Optional, Sequence

import numpy as np
from scipy import special

import sympy as sp
from sympy.parsing.sympy_parser import parse_expr

import tensorflow as tf
from transformers import PreTrainedTokenizer

from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util import latex_util
from em.util.color_util import cu


# Text shown for each integer class label in the generated LaTeX.
# NOTE(review): the empty `{}` group after "no" presumably preserves the
# trailing space so that "no " lines up with the three-character "yes" in
# monospace output (hence "padded") -- confirm.
_LABEL_TO_PADDED_LABEL_NAME_LATEX = {
    0: 'no {}',
    1: 'yes',
}


# Color name used to typeset each integer class label in the generated LaTeX.
# Requires `\usepackage[dvipsnames]{xcolor}`
_LABEL_TO_LATEX_COLOR = {
    0: 'BrickRed',
    1: 'ForestGreen',
}


@dataclasses.dataclass
class EadAnalysisContainer:
    """Bundle per-example model outputs with NMF decompositions for inspection.

    The container reads `labels`, `predicted_logits` and `input_ids` from
    `pef`, and the `W` / `H` factors from every entry of `nmfs`. `W` is indexed
    as `W[example_index, component_index]` throughout this class.

    Attributes:
        pef: Per-example data (labels, predicted logits, input ids).
        nmfs: One NMF decomposition per "subset".
        tokenizer: Tokenizer used to decode `pef.input_ids` back into text.
    """

    # Project types are written as forward references so that they are not
    # evaluated when the class body is executed.
    pef: 'per_example.PerExampleFlatFishers'
    nmfs: Sequence['nmf_common.NmfDecomposition']

    tokenizer: 'PreTrainedTokenizer'

    # Lightweight stand-in exposing only an `index` attribute; see
    # `get_top_examples_based_on_relative_coefficient`.
    _FakeExample = collections.namedtuple('FakeExample', ['index'])

    def __post_init__(self):
        # Symbol handed to the SymPy parser when rendering examples as LaTeX.
        self.x = sp.Symbol('x', real=True, nonzero=True)

        # A label of -1 marks an unlabeled example.
        self.unlabeled_indicator = self.pef.labels == -1

        self.labels = self.pef.labels
        self.predictions = np.argmax(self.pef.predicted_logits, axis=-1)

        # Cache: example index -> rendered LaTeX string.
        self._example_index_to_latex_string = {}

    @property
    def n_nmfs(self) -> int:
        """Number of NMF decompositions held by the container."""
        return len(self.nmfs)

    ############################################################

    def get_correct_prediction_indicator(self) -> np.ndarray:
        """Return a boolean mask that is True where label == prediction.

        Unlabeled examples (label -1) are never marked correct, because
        predictions are argmax indices and therefore non-negative.
        """
        return self.labels == self.predictions

    def get_incorrect_prediction_indicator(self) -> np.ndarray:
        """Return a boolean mask that is True where label != prediction.

        Note that unlabeled examples (label -1) are always marked True here.
        """
        return self.labels != self.predictions

    ############################################################

    def get_top_example_indices(self, nmf_index: int, component_index: int, n_examples: int) -> np.ndarray:
        """Return indices of the examples with the largest coefficients.

        Args:
            nmf_index: Which decomposition in `nmfs` to use.
            component_index: Column of `W` to rank examples by.
            n_examples: Maximum number of indices to return. Values larger than
                the number of rows of `W` are clamped.

        Returns:
            An int32 array of example indices, largest coefficient first.
        """
        W = self.nmfs[nmf_index].W
        # `tf.math.top_k` raises when k exceeds the number of available rows.
        k = min(n_examples, W.shape[0])
        _, inds = tf.math.top_k(W[:, component_index], k=k)
        return inds.numpy().astype(np.int32)

    def get_top_examples_indices_based_on_relative_coefficient(
        self,
        nmf_index: int,
        component_index: int,
        factor: float,
        max_examples: Optional[int] = None,
    ) -> np.ndarray:
        """Return examples whose coefficient is within `factor` of the top one.

        Examples are ranked by descending coefficient and the leading run with
        `coefficient >= factor * top_coefficient` is kept.

        Args:
            nmf_index: Which decomposition in `nmfs` to use.
            component_index: Column of `W` to rank examples by.
            factor: Relative threshold in [0, 1].
            max_examples: Optional cap on the number of returned indices. A
                value of zero (or less) yields an empty result.

        Returns:
            An int32 array of example indices, largest coefficient first. Empty
            when the component has no examples.
        """
        assert 0 <= factor <= 1, 'factor must lie in [0, 1]'

        all_coeffs = np.asarray(self.nmfs[nmf_index].W[:, component_index])
        if all_coeffs.size == 0 or (max_examples is not None and max_examples <= 0):
            return np.array([], dtype=np.int32)

        order = np.argsort(-all_coeffs)
        sorted_coeffs = all_coeffs[order]

        # Keep only the leading run that passes the threshold; the first
        # failing entry ends the selection.
        passes = sorted_coeffs[0] * factor <= sorted_coeffs
        n_selected = len(passes) if passes.all() else int(np.argmin(passes))

        if max_examples is not None:
            n_selected = min(n_selected, max_examples)

        return order[:n_selected].astype(np.int32)

    def get_top_examples_based_on_relative_coefficient(
        self,
        nmf_index: int,
        component_index: int,
        factor: float,
        max_examples: Optional[int] = None,
    ):
        """Like the `..._indices_...` variant, but wraps each index in an object.

        Returns:
            A list of namedtuples, each with a single `index` attribute.
        """
        # NOTE: This is just a quick hack.
        inds = self.get_top_examples_indices_based_on_relative_coefficient(
            nmf_index, component_index, factor, max_examples)
        return [self._FakeExample(index=i) for i in inds]

    ############################################################

    def get_example_as_string(self, example_index: int) -> str:
        """Decode an example to text with special tokens and spaces removed."""
        tokenizer = self.tokenizer
        example = tokenizer.decode(self.pef.input_ids[example_index])
        for special_token in (tokenizer.pad_token, tokenizer.cls_token, tokenizer.sep_token):
            # A tokenizer may not define every special token; skip unset ones
            # instead of passing None to `str.replace`.
            if special_token:
                example = example.replace(special_token, '')
        return example.replace(' ', '')

    def get_example_as_latex_string(self, example_index: int) -> str:
        """Return the example rendered as LaTeX via SymPy (cached per index).

        If the decoded text cannot be parsed or rendered, a single space is
        returned (and cached) instead.
        """
        cache = self._example_index_to_latex_string
        if example_index not in cache:
            example = self.get_example_as_string(example_index)
            # NOTE(security): `parse_expr` evaluates the string with `eval`;
            # only use this on trusted data.
            try:
                expr = parse_expr(example, evaluate=True, local_dict={self.x.name: self.x})
                cache[example_index] = sp.latex(expr)
            except Exception:
                # Deliberate best effort: unparsable examples render as blank.
                cache[example_index] = ' '
        return cache[example_index]

    ############################################################

    def print_example_for_component(self, nmf_index: int, component_index: int, example_index: int):
        """Print coefficient, label, prediction and decoded text of an example."""
        coeff = self.nmfs[nmf_index].W[example_index, component_index]

        label = self.pef.labels[example_index]
        pred = np.argmax(self.pef.predicted_logits[example_index])

        example = self.get_example_as_string(example_index)
        print(f'{coeff:.4f}', label, pred, example)

    ############################################################

    def _to_latex_label(self, label: int) -> str:
        """Return a colored, bold LaTeX snippet for a class label.

        Labels without an entry in the module-level tables (for example the
        unlabeled marker -1) are rendered as a gray "n/a" instead of raising.
        """
        # The bold requires `\usepackage{bold-extra}`
        s = _LABEL_TO_PADDED_LABEL_NAME_LATEX.get(label, 'n/a')
        # `Gray` is provided by the same dvipsnames palette as the other colors.
        color = _LABEL_TO_LATEX_COLOR.get(label, 'Gray')
        return R'{\color{' + color + R'}\textbf{' + s + R'}}'

    def make_example_for_component_latex_string(self, nmf_index: int, component_index: int, example_index: int):
        """Return the LaTeX snippet describing one example of one component.

        The snippet shows the label, the predicted class, the example's
        coefficient for the component, the softmax class probabilities and the
        example itself as a displayed formula.
        """
        label = self.pef.labels[example_index]
        predicted_logits = self.pef.predicted_logits[example_index]
        pyx = special.softmax(predicted_logits)
        pred = np.argmax(predicted_logits)

        s_coeff = f'{self.nmfs[nmf_index].W[example_index, component_index]:.4f}'

        # The `{}` between the fields keeps the surrounding spaces from being
        # collapsed by LaTeX.
        line1 = R' {} '.join([
            f'[LABEL] {self._to_latex_label(label)}',
            f'[PRED] {self._to_latex_label(pred)}',
            f'[COEFF] {s_coeff}',
        ])
        # One probability per class, however many classes there are.
        line2 = R'P(Y|X): \{' + ', '.join(f'{p:.3f}' for p in pyx) + R'\}'

        line3 = f'$${self.get_example_as_latex_string(example_index)}$$'

        return "\n".join([
            R'\noindent\texttt{' + line1 + R' \\',
            line2 + R' \\',
            R'{}' + line3 + R'\vspace{2mm} \\',
            R'}',
        ])

    ############################################################

    def selection_parameters_to_latex_string(self, selection_parameters) -> str:
        """Return a LaTeX paragraph listing the given selection parameters.

        Reads `coeff_factor`, `frac_threshold`, `p_value_threshold` and
        `max_examples` from `selection_parameters`; `max_examples` must be None.
        """
        p = selection_parameters
        assert p.max_examples is None

        cf = f'{p.coeff_factor:.3f}'
        ft = f'{p.frac_threshold:.3f}'
        pvt = f'{p.p_value_threshold:.4f}'

        return '\n\n'.join([
            R"\paragraph{Selection Parameters}",
            f'Coefficient factor: {cf}',
            f'Fraction threshold: {ft}',
            f'P-value threshold: {pvt}',
        ])

    ############################################################

    _LATEX_FILE_HEADER = textwrap.dedent(R"""
    % Please use XeLaTex to handle unicode properly.
    \documentclass[11pt]{article}

    \usepackage[margin=1in]{geometry} 
    \usepackage[dvipsnames]{xcolor}
    \usepackage{bold-extra}
    \usepackage{amsmath}
    \usepackage{amsfonts}
    
    \begin{document}

    """)

    _COMPONENTS_SECTION_HEADER = textwrap.dedent(R"""

    \section{Component Top Examples}
    \textit{All indices used here are 0-based unless stated otherwise. Note that section indices will be 1-based though.}

    """)

    _LATEX_FILE_END = "\n\\end{document}"

    def wrap_latex_contents(self, contents: str, preface: str = '') -> str:
        """Wrap `contents` into a complete LaTeX document.

        Args:
            contents: Body placed under the "Component Top Examples" section.
            preface: Optional text placed before that section.
        """
        return '\n'.join([
            self._LATEX_FILE_HEADER,
            preface,
            self._COMPONENTS_SECTION_HEADER,
            contents,
            self._LATEX_FILE_END
        ])

    def _subsection_name(self, subset_index: int, nmf_names: Optional[Mapping[int, str]]) -> str:
        """Return the subsection title for a subset (a default if unnamed)."""
        if nmf_names is None:
            return f'Subset {subset_index}'
        return nmf_names[subset_index]

    def _component_latex_lines(
        self,
        subset_index: int,
        component_index: int,
        example_indices,
        components_fontsize: str,
    ) -> List[str]:
        """Return the LaTeX lines of one component's subsubsection."""
        lines = [
            R'\subsubsection{Component ' + str(component_index) + R'}',
            R'\begin{' + components_fontsize + R'}',
        ]
        for example_index in example_indices:
            lines.append(self.make_example_for_component_latex_string(
                subset_index, component_index, example_index))
            lines.append('')
        lines.append(R'\end{' + components_fontsize + R'}')
        lines.append('')
        return lines

    def make_all_components_latex_string(
        self,
        n_examples: int,
        components_fontsize: str = 'footnotesize',
        nmf_names: Optional[Mapping[int, str]] = None,
    ) -> str:
        """Return LaTeX listing the top examples of every component.

        Args:
            n_examples: Number of top examples shown per component.
            components_fontsize: Name of the LaTeX font-size environment that
                wraps each component's examples.
            nmf_names: Optional subsection titles keyed by subset index.
        """
        return self.make_filtered_components_latex_string(
            lambda container, subset_index, component_index: True,
            n_examples,
            components_fontsize=components_fontsize,
            nmf_names=nmf_names,
        )

    def make_filtered_components_latex_string(
        self,
        filter_fn,
        n_examples: int,
        components_fontsize: str = 'footnotesize',
        nmf_names: Optional[Mapping[int, str]] = None,
    ) -> str:
        """Return LaTeX listing the top examples of the selected components.

        Args:
            filter_fn: Called as `filter_fn(self, subset_index, component_index)`;
                components for which it returns a falsy value are skipped.
            n_examples: Number of top examples shown per component.
            components_fontsize: Name of the LaTeX font-size environment that
                wraps each component's examples.
            nmf_names: Optional subsection titles keyed by subset index.
        """
        ret = []
        for subset_index in range(self.n_nmfs):
            ret.append(R'\subsection{' + self._subsection_name(subset_index, nmf_names) + R'}')

            for component_index in range(self.nmfs[subset_index].W.shape[-1]):
                if not filter_fn(self, subset_index, component_index):
                    continue
                top_indices = self.get_top_example_indices(subset_index, component_index, n_examples)
                ret.extend(self._component_latex_lines(
                    subset_index, component_index, top_indices, components_fontsize))

        return '\n'.join(ret)

    def make_latex_string_for_some_components_and_examples(
        self,
        example_indices_per_component_per_subset: Sequence[Mapping[int, Sequence[int]]],
        components_fontsize: str = 'footnotesize',
        nmf_names: Optional[Mapping[int, str]] = None,
    ) -> str:
        """Return LaTeX for explicitly chosen components and examples.

        Args:
            example_indices_per_component_per_subset: For every subset, a
                mapping from component index to the example indices to show.
                Components are emitted in ascending index order.
            components_fontsize: Name of the LaTeX font-size environment that
                wraps each component's examples.
            nmf_names: Optional subsection titles keyed by subset index.
        """
        ret = []
        for subset_index in range(self.n_nmfs):
            ret.append(R'\subsection{' + self._subsection_name(subset_index, nmf_names) + R'}')

            components = example_indices_per_component_per_subset[subset_index]

            for component_index in sorted(components):
                ret.extend(self._component_latex_lines(
                    subset_index, component_index, components[component_index], components_fontsize))

        return '\n'.join(ret)

    ###########################################################################

    def compute_mass_fraction_by_subset(self, example_indices=None):
        """Return each subset's share of the total L1 mass.

        For every decomposition, the coefficients of the selected examples are
        weighted by the L1 norm of the corresponding rows of `H` and summed;
        the per-subset sums are then normalized to add up to one.

        Args:
            example_indices: Examples to include. Defaults to all examples in
                `pef.input_ids`.

        Returns:
            A float64 array with one fraction per decomposition. If the total
            mass is zero the division yields NaN entries.
        """
        if example_indices is None:
            example_indices = np.arange(self.pef.input_ids.shape[0], dtype=np.int32)
        else:
            example_indices = np.array(example_indices, dtype=np.int32)

        masses_by_subset = []
        for nmf in self.nmfs:
            coeffs = nmf.W[example_indices]
            # We are going with L1 fraction rather than something L2-based.
            comp_masses = np.abs(nmf.H).sum(axis=-1)
            masses_by_subset.append((coeffs * comp_masses).sum())

        ret = np.array(masses_by_subset, dtype=np.float64)
        ret /= np.sum(ret)

        return ret
