"""Generates a .tex file containing top-example information for NPEFF components."""
import json
import os
import pydoc
from typing import List, Optional, Sequence

from absl import app
from absl import flags

import h5py
import numpy as np

from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_from_coeffs
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.util import hdf5_utils
from npeff_torch.util import tokenizer_utils


###############################################################################
FLAGS = flags.FLAGS


# Path of the .tex file to write. A leading `~` is expanded before the file is opened.
flags.DEFINE_string('output_filepath', None, '')


# PEF inputs. These are read together via pef_format_common.PefExtraInfos.read_from_files,
# and the resulting per-example infos (examples, labels, logits, ...) are paired with the
# NPEFF coefficients.
flags.DEFINE_list('pef_filepaths', None, 'The PEF files used to compute the NPEFF coefficients.')
# Each entry is parsed with int(); any negative entry is turned into None ("use all examples").
flags.DEFINE_list('n_examples_per_pef', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


# HDF5 file holding the NPEFF decomposition. Its `data/W` dataset is used as the coefficient
# matrix; `data/G` is also read when --normalize_components_to_unit_norm is set.
flags.DEFINE_string('npeff_filepath', None, '')

# Identifier passed unchanged to tokenizer_utils.from_pretrained.
flags.DEFINE_string('tokenizer', None, '')

# pydoc.locate-compatible path of the LaTeX generator class. The located class is expected to
# expose a `create(...)` factory whose result has a `generate_components_latex` method.
flags.DEFINE_string('latex_generator_cls_path', None, '')

# Optional JSON object whose entries are forwarded as extra keyword arguments to the
# generator's `create(...)` factory. Only JSON-encodable values can be passed this way.
#
# Some generator types require certain arguments to be supplied through this flag.
flags.DEFINE_string('latex_generator_kwargs', None, '')


# Number of top examples to render per component; must be a positive integer.
flags.DEFINE_integer('n_top_examples', None, '')
# Each entry is parsed with int(); the order given here is the order passed to the generator.
flags.DEFINE_list('component_indices', None, 'Leave unset to compute top examples for all components.')

# Passed unchanged to the generator (presumably a LaTeX font-size command name such as
# `footnotesize` -- confirm against the generator implementation).
flags.DEFINE_string('components_fontsize', 'footnotesize', '')

# When set, the coefficients are rescaled by the sums of squares of G along its last axis
# (see _read_coeffs).
flags.DEFINE_bool('normalize_components_to_unit_norm', False, 'Normalize the G such that the rank-1 basis PSD matrices have unit frobenius norm.')


# Component filtering.
flags.DEFINE_list('component_filter_cls_paths', [],
                  'List of pydoc.locate-compatible class paths used to create component filters. The '
                  'classes should be subclasses of component_filtering.ComponentFilterAbc. If more than one '
                  'class path is provided, a filter will be created for each entry. The actual filter used '
                  'will be the logical and of these filters.')
# A JSON list; each element is expanded as **kwargs into the matching filter class's constructor.
flags.DEFINE_string('component_filter_kwargs', '', 
                    'JSON list of JSON kwargs for the component filters. Must be provided if --component_filter_cls_paths '
                    'is provided. Is a parallel list to --component_filter_cls_paths, so must have the same number of entries.')

###############################################################################


def _read_flag_kwargs(flag_value: Optional[str]):
    if flag_value:
        return json.loads(flag_value)
    else:
        return {}


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    if not flag_value:
        return None
    ret = []
    for n_examples in flag_value:
        n_examples = int(n_examples)
        if n_examples < 0:
            ret.append(None)
        else:
            ret.append(n_examples)
    return ret


def _read_coeffs(filepath: str, eps=1e-12) -> np.ndarray:
    """Loads the NPEFF coefficient matrix from an HDF5 file.

    The `data/W` dataset is always read. When --normalize_components_to_unit_norm
    is set, `data/G` is read as well and W is scaled in place by the sums of
    squares of G along its last axis (transposed, plus `eps`). This appears to
    compensate W for rescaling each rank-1 component g g^T to unit Frobenius norm
    (that norm equals sum(g**2)), as described by the flag's help text.

    Args:
        filepath: Path to the NPEFF HDF5 file; a leading `~` is expanded.
        eps: Small constant added to every scale factor.

    Returns:
        The (possibly rescaled) coefficient array.
    """
    with h5py.File(os.path.expanduser(filepath), "r") as f:
        W = hdf5_utils.load_h5_ds(f['data/W'])
        normalize = FLAGS.normalize_components_to_unit_norm
        G = hdf5_utils.load_h5_ds(f['data/G']) if normalize else None

    if not normalize:
        return W

    # keepdims + transpose turn the per-component sums into a row vector that broadcasts
    # over W (presumably W is (n_examples, n_components) and G is
    # (n_components, n_features) -- confirm against the NPEFF writer).
    sq_norms = np.sum(G**2, axis=-1, keepdims=True)
    W *= sq_norms.T + eps
    return W


def _get_component_indices(n_components: int) -> Sequence[int]:
    """Returns the indices of the components to render.

    Args:
        n_components: Total number of components in the decomposition; only used
            when --component_indices is unset or empty.

    Returns:
        The --component_indices entries parsed as ints (in the given order) when
        that flag is non-empty, otherwise `range(n_components)`.
    """
    requested = FLAGS.component_indices
    if not requested:
        return range(n_components)
    return list(map(int, requested))


def _get_component_filter():
    """Builds the component filter described by the component-filter flags.

    Each entry of --component_filter_cls_paths is resolved with pydoc.locate and
    instantiated with the matching element (a dict of keyword arguments) of the
    JSON list in --component_filter_kwargs. The created filters are combined with
    component_filtering.FiltersLogicalAnd.

    Returns:
        A FiltersLogicalAnd over the created filters, or None when no filter
        class paths were given.

    Raises:
        ValueError: If --component_filter_kwargs is missing, is not a JSON list,
            does not have one entry per class path, or a class path cannot be
            located.
    """
    cls_paths = FLAGS.component_filter_cls_paths
    if not cls_paths:
        return None

    if not FLAGS.component_filter_kwargs:
        raise ValueError(
            '--component_filter_kwargs must be provided when --component_filter_cls_paths is set.')
    component_filter_kwargs = json.loads(FLAGS.component_filter_kwargs)

    # Explicit check rather than `assert`: asserts are stripped under `python -O`, and
    # zip() below would then silently drop the unmatched entries.
    if not isinstance(component_filter_kwargs, list) or len(component_filter_kwargs) != len(cls_paths):
        raise ValueError(
            '--component_filter_kwargs must be a JSON list with one kwargs entry per '
            f'--component_filter_cls_paths entry ({len(cls_paths)} expected).')

    filters = []
    for cls_path, kw in zip(cls_paths, component_filter_kwargs):
        ComponentFilter = pydoc.locate(cls_path)
        # pydoc.locate returns None instead of raising when the path cannot be resolved.
        if ComponentFilter is None:
            raise ValueError(f'Could not locate component filter class {cls_path!r}.')
        filters.append(ComponentFilter(**kw))

    return component_filtering.FiltersLogicalAnd(filters=filters)

###############################################################################


def main(_):
    """Writes the LaTeX top-examples report for the selected NPEFF components.

    Validates the flags, then loads the tokenizer, the optional component
    filter, the PEF extra infos, and the NPEFF coefficients. It checks that the
    PEF infos and the coefficients describe the same number of examples, builds
    the top-examples reader and the LaTeX generator, and writes the generated
    LaTeX to --output_filepath (UTF-8).

    Raises:
        app.UsageError: If --n_top_examples is not a positive integer, or if
            --latex_generator_cls_path is unset or cannot be located.
        ValueError: If the PEF files lack examples or an example count, or if
            the coefficient example count differs from the PEF example count.
    """
    # Explicit checks instead of `assert`: asserts vanish under `python -O`, and
    # `None > 0` would otherwise raise an unhelpful TypeError.
    if FLAGS.n_top_examples is None or FLAGS.n_top_examples <= 0:
        raise app.UsageError(
            f'--n_top_examples must be a positive integer, got {FLAGS.n_top_examples!r}.')

    generator_cls_path = FLAGS.latex_generator_cls_path
    # pydoc.locate returns None for unresolvable paths, so surface that here instead of
    # failing later with an AttributeError on `.create`.
    LatexGenerator = pydoc.locate(generator_cls_path) if generator_cls_path else None
    if LatexGenerator is None:
        raise app.UsageError(
            f'Could not locate --latex_generator_cls_path {generator_cls_path!r}.')

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer)

    component_filter = _get_component_filter()

    n_examples_per_pef = _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef)
    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(FLAGS.pef_filepaths, n_examples_per_pef)
    pei_n_examples = pef_extra_infos.n_examples
    if pef_extra_infos.examples is None:
        raise ValueError('The PEF extra infos contain no examples; they are required to render top examples.')
    if pei_n_examples is None:
        raise ValueError('The PEF extra infos do not report a number of examples.')

    coefficients = _read_coeffs(FLAGS.npeff_filepath)
    coeff_n_examples, n_components = coefficients.shape

    # The coefficients and PEF infos are handed to the reader together (presumably paired
    # by row index), so the counts must match exactly. --n_examples_per_pef can be used to
    # limit how many examples are read from each PEF.
    if coeff_n_examples != pei_n_examples:
        raise ValueError(
            f'The NPEFF coefficients cover {coeff_n_examples} examples, but the PEF files '
            f'provide {pei_n_examples}.')

    component_indices = _get_component_indices(n_components)

    top_examples_reader = top_examples_from_coeffs.TopExamplesReaderFromCoeffs.create(
        coefficients=coefficients,
        examples=pef_extra_infos.examples,
        labels=pef_extra_infos.labels,
        logits=pef_extra_infos.logits,
        top_log_probs_class_indices=pef_extra_infos.top_log_probs_class_indices,
        top_log_probs_values=pef_extra_infos.top_log_probs_values,
        token_positions=pef_extra_infos.token_positions,
    )

    latex_generator = LatexGenerator.create(
        top_examples_reader=top_examples_reader,
        tokenizer=tokenizer,
        n_top_examples=FLAGS.n_top_examples,
        components_fontsize=FLAGS.components_fontsize,
        component_filter=component_filter,
        **_read_flag_kwargs(FLAGS.latex_generator_kwargs),
    )

    latex_content = latex_generator.generate_components_latex(component_indices)

    # Explicit UTF-8: decoded token text may be non-ASCII, and the locale default encoding
    # could otherwise fail to encode it.
    with open(os.path.expanduser(FLAGS.output_filepath), 'wt', encoding='utf-8') as f:
        f.write(latex_content)


if __name__ == "__main__":
    # Script entry point: hand `main` to absl's app.run, which is responsible for parsing
    # the flags defined above before invoking it.
    app.run(main)
