"""Generates a .tex file containing the top examples information for k-means clusters."""
import json
import os
import pydoc
from typing import List, Optional, Sequence, Tuple

from absl import app
from absl import flags

import h5py
import numpy as np

from npeff_torch.examination.top_examples import top_examples_from_clusters
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.util import hdf5_utils
from npeff_torch.util import tokenizer_utils


###############################################################################
FLAGS = flags.FLAGS


# Path the generated LaTeX content is written to. `~` is expanded before opening.
flags.DEFINE_string('output_filepath', None, '')

# Each gradient file is read for its PefExtraInfos; the per-file infos are concatenated in the order given here.
flags.DEFINE_list('gradient_filepaths', None, 'The gradient files used to compute the clusters.')
flags.DEFINE_list('n_examples_per_gradient_file', None,
                  "Comma-separated list of integers indicating the number of examples to use from each gradient file. "
                  "If provided, the list must be the same length as the --gradient_filepaths list. "
                  "Leave empty to use all examples from all files. "
                  "Use a value of -1 for a particular file to use all examples from that particular file.")

# HDF5 file from which `data/cluster_assignments`, `data/centroid_distances` and the shape of
# `data/centroids` are read.
flags.DEFINE_string('clusters_filepath', None, '')

# Passed as-is to `tokenizer_utils.from_pretrained`.
flags.DEFINE_string('tokenizer', None, '')

# Path of the LaTeX generator class, resolved with `pydoc.locate`. The resolved class must expose a
# `create(...)` factory whose result has a `generate_components_latex(component_indices)` method.
flags.DEFINE_string('latex_generator_cls_path', None, '')

# Extra keyword arguments forwarded to the generator's `create(...)` factory. If provided, this should be a
# JSON dict mapping parameter names to their values. Currently, only JSON-encodable values can be provided.
#
# Some types of generators will require some parameters to be passed in this manner.
flags.DEFINE_string('latex_generator_kwargs', None, '')


# Must be set to a value greater than zero; forwarded to the generator's `create(...)` factory.
flags.DEFINE_integer('n_top_examples', None, '')
flags.DEFINE_list('component_indices', None, 'Leave unset to compute top examples for all components.')

# Forwarded to the generator's `create(...)` factory as `components_fontsize`.
flags.DEFINE_string('components_fontsize', 'footnotesize', '')


###############################################################################


def _read_flag_kwargs(flag_value: Optional[str]):
    if flag_value:
        return json.loads(flag_value)
    else:
        return {}


def _read_n_examples_per_gradient_file_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    if not flag_value:
        return None
    ret = []
    for n_examples in flag_value:
        n_examples = int(n_examples)
        if n_examples < 0:
            ret.append(None)
        else:
            ret.append(n_examples)
    return ret


def _validate_n_examples_per_gradient_file(gradient_filepaths: List[str], n_examples_per_gradient_file: Optional[List[Optional[int]]]):
    if n_examples_per_gradient_file is not None and len(gradient_filepaths) != len(n_examples_per_gradient_file):
        raise ValueError('If --n_examples_per_gradient_file is provided, its number of entries must match that of the --gradient_filepaths flag.')


def _read_pef_extra_infos(gradient_filepaths: List[str], n_examples_per_gradient_file: Optional[List[Optional[int]]]):
    """Read the extra infos of every gradient file and concatenate them.

    Args:
        gradient_filepaths: Files to read `PefExtraInfos` from, in order.
        n_examples_per_gradient_file: Optional per-file example limits. Assumed to already be
            validated to contain one entry per gradient file. A None entry (or a None list)
            keeps every example of the corresponding file.

    Returns:
        The result of `PefExtraInfos.concat` over the (possibly sliced) per-file infos.

    Raises:
        ValueError: If more examples are requested from a file than it contains.
    """
    # Assumes n_examples_per_gradient_file is valid.

    pef_extra_infos = []
    for i, gradient_filepath in enumerate(gradient_filepaths):
        pei = pef_format_common.PefExtraInfos.read_from_file(gradient_filepath)
        pei_n_examples = pei.n_examples
        assert pei_n_examples is not None, f'Could not determine the number of examples in {gradient_filepath!r}.'

        n_examples = n_examples_per_gradient_file[i] if n_examples_per_gradient_file is not None else None
        if n_examples is not None:
            if n_examples < pei_n_examples:
                # Keep only the leading n_examples entries of this file.
                pei = pei.get_slice(0, n_examples)
            elif n_examples > pei_n_examples:
                raise ValueError(
                    f'Requested {n_examples} examples from {gradient_filepath!r}, '
                    f'but it only contains {pei_n_examples}.'
                )

        pef_extra_infos.append(pei)

    return pef_format_common.PefExtraInfos.concat(pef_extra_infos, ensure_same_fields=True)


def _read_clusters(filepath: str) -> Tuple[np.ndarray, np.ndarray, int]:
    """Load the clustering results from an HDF5 file.

    Args:
        filepath: Path of the clusters file; `~` is expanded.

    Returns:
        A `(cluster_assignments, centroid_distances, n_components)` tuple, where the first
        two entries are loaded from the `data/cluster_assignments` and
        `data/centroid_distances` datasets and `n_components` is the size of the first
        dimension of the `data/centroids` dataset.
    """
    with h5py.File(os.path.expanduser(filepath), "r") as f:
        cluster_assignments = hdf5_utils.load_h5_ds(f['data/cluster_assignments'])
        centroid_distances = hdf5_utils.load_h5_ds(f['data/centroid_distances'])
        # Only the number of centroids is needed, so the centroids themselves are not loaded.
        n_components = f['data/centroids'].shape[0]
    return cluster_assignments, centroid_distances, n_components


def _get_component_indices(n_components: int) -> Sequence[int]:
    """Return the component indices to generate LaTeX for.

    Args:
        n_components: Total number of components available.

    Returns:
        The indices listed in --component_indices converted to ints, or
        `range(n_components)` when that flag is unset or empty.
    """
    requested = FLAGS.component_indices
    if not requested:
        # No explicit selection: cover every component.
        return range(n_components)
    return [int(index) for index in requested]


###############################################################################


def main(_):
    """Generate the top-examples LaTeX for the requested clusters and write it to disk.

    Reads the extra infos of the gradient files and the cluster assignments, builds a
    top-examples reader from them, hands it to the LaTeX generator class named by
    --latex_generator_cls_path, and writes the generated content to --output_filepath.

    Raises:
        ValueError: If --n_top_examples is unset or not positive, if --gradient_filepaths is
            unset, if the LaTeX generator class cannot be located, or if the number of
            clustered examples differs from the number of examples read from the gradient files.
    """
    # Explicit checks rather than asserts: these validate user input and must survive `python -O`.
    n_top_examples = FLAGS.n_top_examples
    if n_top_examples is None or n_top_examples <= 0:
        raise ValueError(f'--n_top_examples must be a positive integer, got {n_top_examples!r}.')

    gradient_filepaths = FLAGS.gradient_filepaths
    if gradient_filepaths is None:
        raise ValueError('--gradient_filepaths must be provided.')
    n_examples_per_gradient_file = _read_n_examples_per_gradient_file_flag(FLAGS.n_examples_per_gradient_file)
    _validate_n_examples_per_gradient_file(gradient_filepaths, n_examples_per_gradient_file)

    latex_generator_cls_path = FLAGS.latex_generator_cls_path
    if not latex_generator_cls_path:
        raise ValueError('--latex_generator_cls_path must be provided.')
    # pydoc.locate returns None (instead of raising) when the path cannot be resolved.
    LatexGenerator = pydoc.locate(latex_generator_cls_path)
    if LatexGenerator is None:
        raise ValueError(f'Could not locate the LaTeX generator class {latex_generator_cls_path!r}.')

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer)

    pef_extra_infos = _read_pef_extra_infos(gradient_filepaths, n_examples_per_gradient_file)
    pei_n_examples = pef_extra_infos.n_examples
    assert pef_extra_infos.examples is not None, 'The gradient files do not contain examples.'
    assert pei_n_examples is not None, 'Could not determine the number of examples in the gradient files.'

    cluster_assignments, centroid_distances, n_components = _read_clusters(FLAGS.clusters_filepath)
    # Unpacking also enforces that the assignments array is one-dimensional.
    coeff_n_examples, = cluster_assignments.shape

    if coeff_n_examples != pei_n_examples:
        raise ValueError(
            f'The clusters file contains assignments for {coeff_n_examples} examples, '
            f'but {pei_n_examples} examples were read from the gradient files.'
        )

    component_indices = _get_component_indices(n_components)

    top_examples_reader = top_examples_from_clusters.TopExamplesReaderFromClusters.create(
        cluster_assignments=cluster_assignments,
        centroid_distances=centroid_distances,
        examples=pef_extra_infos.examples,
        labels=pef_extra_infos.labels,
        logits=pef_extra_infos.logits,
        top_log_probs_class_indices=pef_extra_infos.top_log_probs_class_indices,
        top_log_probs_values=pef_extra_infos.top_log_probs_values,
        token_positions=pef_extra_infos.token_positions,
    )

    latex_generator = LatexGenerator.create(
        top_examples_reader=top_examples_reader,
        tokenizer=tokenizer,
        n_top_examples=n_top_examples,
        components_fontsize=FLAGS.components_fontsize,
        **_read_flag_kwargs(FLAGS.latex_generator_kwargs),
    )

    latex_content = latex_generator.generate_components_latex(component_indices)

    with open(os.path.expanduser(FLAGS.output_filepath), 'wt') as f:
        f.write(latex_content)


if __name__ == "__main__":
    app.run(main)
