"""

"""
import json
import os
import pydoc
import random
from typing import List, Optional

from absl import app
from absl import flags

import h5py
import numpy as np

from npeff_torch.examination.top_examples import top_examples_from_coeffs
from npeff_torch.examination.top_examples.human_evals import humev_common
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.util import hdf5_utils
from npeff_torch.util import tokenizer_utils


###############################################################################
FLAGS = flags.FLAGS

# Prefix shared by every output file. main() writes `<prefix>.<i>.tex`,
# `<prefix>.evaluation.<i>.csv` and `<prefix>.labels.<i>.json` for each document index i.
flags.DEFINE_string('output_filepath_prefix', None, 
                    'The directory the files will be written to must already exist.')

flags.DEFINE_list('pef_filepaths', None, 'The PEF files used to compute the NPEFF coefficients.')
flags.DEFINE_list('n_examples_per_pef', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


# HDF5 file whose `data/W` dataset is read as the coefficient matrix.
flags.DEFINE_string('npeff_filepath', None, '')

# Passed as-is to tokenizer_utils.from_pretrained.
flags.DEFINE_string('tokenizer', None, '')

# Dotted path that main() resolves to a class with pydoc.locate.
flags.DEFINE_string('latex_generator_cls_path', None,
                    'Should be subclass of npeff_torch.examination.top_examples.human_evals.humev_top_examples_theme_latex.HumevTopExamplesThemeLatexGeneratorAbc')

# If these are provided, they should be a JSON dict mapping parameter names to their values. Currently, only
# JSON-encodable values can be provided.
#
# Some types of generators will require some flags to be passed in this manner.
flags.DEFINE_string('latex_generator_kwargs', None, '')


# Seed for the random.Random instance used throughout main(). main() requires it to be set.
flags.DEFINE_integer('rng_seed', None, '')


# Number of examples in each group (used for both top-example and random groups).
flags.DEFINE_integer('example_group_size', None, '')

# Number of top-example groups and random-example groups placed in each document.
flags.DEFINE_integer('n_top_example_groups', None, '')
flags.DEFINE_integer('n_random_example_groups', None, '')

# Forwarded to humev_common.make_top_example_groups as `shuffle_examples_in_group`.
flags.DEFINE_bool('shuffle_top_examples', True, '')

# Forwarded to humev_common.make_top_example_groups as `unique_top_examples`.
flags.DEFINE_bool('unique_top_examples', True, '')


flags.DEFINE_integer('n_documents', None, 'Number of human evaluation documents to create.')


###############################################################################

def _check_output_directory(output_filepath_prefix: str):
    """Verify that the directory the output files will be written into exists.

    The prefix is user-expanded first so that the check looks at the same
    location the output files are later written to (they are opened through
    ``os.path.expanduser``).

    Args:
        output_filepath_prefix: Value of --output_filepath_prefix, i.e. a
            directory plus the leading part of the output file names.

    Raises:
        ValueError: If the directory part of the prefix is not an existing
            directory.
    """
    directory = os.path.dirname(os.path.expanduser(output_filepath_prefix))
    # A bare prefix such as "run1" has an empty dirname and refers to the
    # current working directory; os.path.isdir('') would wrongly report False.
    if not os.path.isdir(directory or os.curdir):
        raise ValueError('The directory associated with --output_filepath_prefix must exist.')


def _read_flag_kwargs(flag_value: Optional[str]):
    """Decode a JSON-encoded keyword-arguments flag.

    Args:
        flag_value: Raw flag string, or None/empty when the flag was not set.

    Returns:
        The result of ``json.loads`` on the flag string, or an empty dict when
        the flag is unset or empty.
    """
    return json.loads(flag_value) if flag_value else {}


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    """Parse the --n_examples_per_pef list flag into per-PEF example counts.

    Args:
        flag_value: The flag's list of integer strings, or None/empty when the
            flag was not provided.

    Returns:
        None when the flag is unset or empty. Otherwise a list with one entry
        per input string: the parsed integer, or None where the integer is
        negative (meaning "use all examples" for that PEF).
    """
    if flag_value:
        parsed_counts = map(int, flag_value)
        return [None if count < 0 else count for count in parsed_counts]
    return None


def _read_coeffs(filepath: str, eps=1e-12) -> np.ndarray:
    """Load the coefficient matrix stored under ``data/W`` in an HDF5 file.

    Args:
        filepath: Path to the HDF5 file; a leading ``~`` is expanded.
        eps: Accepted for interface compatibility; not used by this function.

    Returns:
        The value produced by ``hdf5_utils.load_h5_ds`` for the ``data/W``
        dataset.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        return hdf5_utils.load_h5_ds(h5_file['data/W'])


###############################################################################
# Column headers handed to humev_common.make_evaluation_csv when main() builds
# the evaluation CSV for each document.
_CSV_HEADERS = ['Group', 'Themed [yes/maybe/no]', 'Description']
###############################################################################


def main(_):
    """Generate the human-evaluation documents described by the flags.

    Reads the PEF extra infos and the NPEFF coefficients, then for each of
    --n_documents documents builds --n_top_example_groups top-example groups
    and --n_random_example_groups random groups, shuffles them together, and
    writes three files sharing --output_filepath_prefix:

      * ``<prefix>.<i>.tex``: LaTeX produced by the configured generator.
      * ``<prefix>.evaluation.<i>.csv``: evaluation sheet with one row per group.
      * ``<prefix>.labels.<i>.json``: list of booleans in document order,
        true for a top-example group and false for a random group.

    A single RNG seeded with --rng_seed drives all sampling and shuffling.

    Raises:
        ValueError: If --rng_seed is missing, the output directory does not
            exist, the LaTeX generator class cannot be located, the PEF files
            lack the examples or the example count, or the number of examples
            in the coefficients differs from the number in the PEF files.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # Without a seed, random.Random(None) would silently seed itself from the
    # system and the generated documents would not be reproducible.
    if FLAGS.rng_seed is None:
        raise ValueError('--rng_seed must be provided.')
    rng = random.Random(FLAGS.rng_seed)
    _check_output_directory(FLAGS.output_filepath_prefix)

    # pydoc.locate returns None when the dotted path cannot be resolved.
    LatexGenerator = pydoc.locate(FLAGS.latex_generator_cls_path)
    if LatexGenerator is None:
        raise ValueError(
            f'Could not locate the class given by --latex_generator_cls_path: '
            f'{FLAGS.latex_generator_cls_path!r}')

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer)

    n_examples_per_pef = _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef)
    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(FLAGS.pef_filepaths, n_examples_per_pef)
    pei_n_examples = pef_extra_infos.n_examples
    if pef_extra_infos.examples is None:
        raise ValueError('The PEF extra infos read from --pef_filepaths do not contain examples.')
    if pei_n_examples is None:
        raise ValueError('The PEF extra infos read from --pef_filepaths do not provide an example count.')

    coefficients = _read_coeffs(FLAGS.npeff_filepath)
    coeff_n_examples, n_components = coefficients.shape

    if coeff_n_examples != pei_n_examples:
        raise ValueError(
            f'Number of examples in the NPEFF coefficients ({coeff_n_examples}) does not match '
            f'the number of examples in the PEF files ({pei_n_examples}).')

    reader = top_examples_from_coeffs.TopExamplesReaderFromCoeffs.create(
        coefficients=coefficients,
        examples=pef_extra_infos.examples,
        labels=pef_extra_infos.labels,
        logits=pef_extra_infos.logits,
        top_log_probs_class_indices=pef_extra_infos.top_log_probs_class_indices,
        top_log_probs_values=pef_extra_infos.top_log_probs_values,
        token_positions=pef_extra_infos.token_positions,
    )

    latex_generator = LatexGenerator.create(
        tokenizer=tokenizer,
        **_read_flag_kwargs(FLAGS.latex_generator_kwargs),
    )

    # These flag values do not change between documents.
    n_top_example_groups = FLAGS.n_top_example_groups
    n_random_example_groups = FLAGS.n_random_example_groups
    example_group_size = FLAGS.example_group_size

    for document_index in range(FLAGS.n_documents):
        component_indices, top_example_groups = humev_common.make_top_example_groups(
            rng=rng,
            reader=reader,
            n_groups=n_top_example_groups,
            n_examples_per_group=example_group_size,
            shuffle_examples_in_group=FLAGS.shuffle_top_examples,
            unique_top_examples=FLAGS.unique_top_examples,
        )
        random_example_groups = humev_common.make_random_example_groups(
            rng=rng,
            reader=reader,
            n_groups=n_random_example_groups,
            n_examples_per_group=example_group_size,
        )

        # Interleave top and random groups in a random order. Indices below
        # n_top_example_groups refer to top groups, the rest to random groups.
        all_group_indices = list(range(n_top_example_groups + n_random_example_groups))
        rng.shuffle(all_group_indices)

        all_groups = [*top_example_groups, *random_example_groups]
        all_groups = [all_groups[i] for i in all_group_indices]

        # True if the group at that position is a top example group, False if
        # it is a random group. Serialized as JSON true/false.
        group_type_indicator = [i < n_top_example_groups for i in all_group_indices]

        # Render the document contents.
        latex_content = latex_generator.generate_latex(all_groups)
        csv_content = humev_common.make_evaluation_csv(_CSV_HEADERS, len(group_type_indicator))

        # Write the three per-document files.
        latex_filepath = f'{FLAGS.output_filepath_prefix}.{document_index}.tex'
        csv_filepath = f'{FLAGS.output_filepath_prefix}.evaluation.{document_index}.csv'
        group_type_indicator_filepath = f'{FLAGS.output_filepath_prefix}.labels.{document_index}.json'

        # UTF-8 is requested explicitly so that writing does not depend on the
        # platform's locale encoding.
        with open(os.path.expanduser(latex_filepath), 'wt', encoding='utf-8') as f:
            f.write(latex_content)

        with open(os.path.expanduser(csv_filepath), 'wt', encoding='utf-8') as f:
            f.write(csv_content)

        with open(os.path.expanduser(group_type_indicator_filepath), 'wt', encoding='utf-8') as f:
            f.write(json.dumps(group_type_indicator))


if __name__ == "__main__":
    app.run(main)
