"""


"""
import json
import os
import pydoc
import random
from typing import List, Optional

from absl import app
from absl import flags

import h5py
import numpy as np

from npeff_torch.examination.top_examples import top_examples_from_coeffs
from npeff_torch.examination.top_examples.human_evals import humev_common
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.util import hdf5_utils
from npeff_torch.util import tokenizer_utils


###############################################################################
FLAGS = flags.FLAGS


flags.DEFINE_string('output_filepath_prefix', None,
                    'Prefix of the output file paths; a document index and file extension are appended to it. '
                    'The directory the files will be written to must already exist.')

# These two should have the same components but with coefficients computed on disjoint groups of examples.
flags.DEFINE_string('npeff_filepath_1', None,
                    'Path to the first NPEFF HDF5 file; coefficients are read from its data/W dataset.')
flags.DEFINE_string('npeff_filepath_2', None,
                    'Path to the second NPEFF HDF5 file; must have the same components as --npeff_filepath_1.')

flags.DEFINE_list('pef_filepaths_1', None, 'The PEF files used to compute the NPEFF coefficients in --npeff_filepath_1.')
flags.DEFINE_list('n_examples_per_pef_1', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths_1 list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")

flags.DEFINE_list('pef_filepaths_2', None, 'The PEF files used to compute the NPEFF coefficients in --npeff_filepath_2.')
flags.DEFINE_list('n_examples_per_pef_2', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths_2 list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")

flags.DEFINE_string('tokenizer', None, 'Tokenizer identifier passed to tokenizer_utils.from_pretrained.')

flags.DEFINE_string('latex_generator_cls_path', None,
                    'Should be subclass of npeff_torch.examination.top_examples.human_evals.humev_top_examples_same_theme_latex.HumevTopExamplesSameThemeLatexGeneratorAbc')

# If these are provided, they should be a JSON dict mapping parameter names to their values. Currently, only
# JSON-encodable values can be provided.
#
# Some types of generators will require some flags to be passed in this manner.
flags.DEFINE_string('latex_generator_kwargs', None,
                    "Optional JSON dict of extra keyword arguments passed to the LaTeX generator's create() method.")


flags.DEFINE_integer('rng_seed', None, 'Seed for the random number generator. Required.')


flags.DEFINE_integer('example_group_size', None, 'Number of top examples in each example group.')

flags.DEFINE_integer('n_same_component_groups', None,
                     'Number of pairs per document whose example groups come from the same component.')
flags.DEFINE_integer('n_different_component_groups', None,
                     'Number of pairs per document whose example groups come from different components.')

flags.DEFINE_bool('shuffle_top_examples', True, 'Whether to shuffle the order of examples within each group.')
flags.DEFINE_bool('shuffle_pair_order', True,
                  'Passed through as shuffle_pair_order to the humev_common pair-making helpers.')

flags.DEFINE_bool('unique_top_examples', True,
                  'Passed through as unique_top_examples to the humev_common pair-making helpers.')


flags.DEFINE_integer('n_documents', None, 'Number of human evaluation documents to create.')


###############################################################################


def _check_output_directory(output_filepath_prefix: str):
    """Raise if the directory that will hold the output files does not exist.

    The prefix is user-expanded (``~``) before checking, and a prefix with no
    directory component (e.g. ``results``) refers to the current directory.

    Args:
        output_filepath_prefix: Value of --output_filepath_prefix.

    Raises:
        ValueError: If the prefix is missing or its directory does not exist.
    """
    if not output_filepath_prefix:
        raise ValueError('--output_filepath_prefix must be provided.')
    # Output files are opened via os.path.expanduser, so validate the same path.
    directory = os.path.dirname(os.path.expanduser(output_filepath_prefix))
    # dirname("results") == "", which os.path.isdir rejects; it means the cwd.
    if not os.path.isdir(directory or os.curdir):
        raise ValueError('The directory associated with --output_filepath_prefix must exist.')


def _read_flag_kwargs(flag_value: Optional[str]):
    """Decode an optional JSON-encoded keyword-arguments flag.

    Args:
        flag_value: JSON text, or None/empty when the flag was not given.

    Returns:
        The decoded JSON value, or an empty dict when ``flag_value`` is falsy.
    """
    if not flag_value:
        return {}
    return json.loads(flag_value)


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    """Parse a --n_examples_per_pef_* list flag into per-PEF example counts.

    Args:
        flag_value: List of integer strings, or None/empty when not given.

    Returns:
        None when the flag is absent or empty; otherwise one entry per item,
        where negative values become None (meaning "use all examples").

    Raises:
        ValueError: If an item is not a valid integer string.
    """
    if not flag_value:
        return None
    parsed_counts = (int(raw_count) for raw_count in flag_value)
    return [None if count < 0 else count for count in parsed_counts]


def _read_coeffs(filepath: str, eps=1e-12) -> np.ndarray:
    """Load the NPEFF coefficient matrix stored under ``data/W``.

    Args:
        filepath: Path to the NPEFF HDF5 file; ``~`` is expanded.
        eps: Unused; kept so existing call sites remain valid.

    Returns:
        The ``data/W`` dataset as loaded by ``hdf5_utils.load_h5_ds``.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        coeffs = hdf5_utils.load_h5_ds(h5_file['data/W'])
    return coeffs


def _make_reader(
    npeff_filepath: str,
    pef_filepaths: List[str],
    n_examples_per_pef: Optional[List[str]],
):
    """Build a top-examples reader pairing NPEFF coefficients with PEF example info.

    Args:
        npeff_filepath: NPEFF HDF5 file whose ``data/W`` holds the coefficients.
        pef_filepaths: PEF files the coefficients were computed from.
        n_examples_per_pef: Raw --n_examples_per_pef_* flag value (strings), or
            None/empty to use all examples from every PEF.

    Returns:
        A ``TopExamplesReaderFromCoeffs`` built from the coefficients and the
        PEF extra infos.

    Raises:
        ValueError: If the PEF files provide no examples or no example count, or
            if the number of coefficient rows differs from the number of PEF
            examples.
    """
    parsed_n_examples_per_pef = _read_n_examples_per_pef_flag(n_examples_per_pef)
    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(pef_filepaths, parsed_n_examples_per_pef)
    pei_n_examples = pef_extra_infos.n_examples

    # Raise instead of assert so these checks are not stripped under `python -O`.
    if pef_extra_infos.examples is None:
        raise ValueError(f'PEF files {pef_filepaths} do not contain examples.')
    if pei_n_examples is None:
        raise ValueError(f'PEF files {pef_filepaths} do not record the number of examples.')

    coefficients = _read_coeffs(npeff_filepath)
    # Rows are examples, columns are components; unpacking also enforces 2-D.
    coeff_n_examples, _ = coefficients.shape

    if coeff_n_examples != pei_n_examples:
        raise ValueError(
            f'NPEFF file {npeff_filepath} has coefficients for {coeff_n_examples} examples, '
            f'but the PEF files {pef_filepaths} contain {pei_n_examples} examples.'
        )

    return top_examples_from_coeffs.TopExamplesReaderFromCoeffs.create(
        coefficients=coefficients,
        examples=pef_extra_infos.examples,
        labels=pef_extra_infos.labels,
        logits=pef_extra_infos.logits,
        top_log_probs_class_indices=pef_extra_infos.top_log_probs_class_indices,
        top_log_probs_values=pef_extra_infos.top_log_probs_values,
        token_positions=pef_extra_infos.token_positions,
    )


###############################################################################
# Column headers of the evaluation CSV written for each document; passed to
# humev_common.make_evaluation_csv together with the number of pairs.
_CSV_HEADERS = ['Pair', 'Shared Theme [yes/maybe/no]', 'Description']
###############################################################################


def _write_text_file(filepath: str, content: str):
    """Write ``content`` to ``filepath`` (``~`` expanded) as UTF-8 text."""
    # Explicit encoding: example text may contain non-ASCII characters, and the
    # locale default encoding is not guaranteed to handle them.
    with open(os.path.expanduser(filepath), 'wt', encoding='utf-8') as f:
        f.write(content)


def _make_document_pairs(rng, reader_1, reader_2):
    """Build and shuffle the example-group pairs for one document.

    Same-component pairs are requested first and different-component pairs
    second; both helpers and the final shuffle draw from ``rng``.

    Returns:
        ``(pairs, is_same_component)`` where ``is_same_component[i]`` is True
        iff ``pairs[i]`` came from the same-component helper.
    """
    shared_kwargs = dict(
        rng=rng,
        reader_1=reader_1,
        reader_2=reader_2,
        n_examples_per_group=FLAGS.example_group_size,
        shuffle_examples_in_group=FLAGS.shuffle_top_examples,
        shuffle_pair_order=FLAGS.shuffle_pair_order,
        unique_top_examples=FLAGS.unique_top_examples,
    )
    same_pairs = list(humev_common.make_same_component_example_group_pairs(
        n_groups=FLAGS.n_same_component_groups,
        **shared_kwargs,
    ))
    different_pairs = list(humev_common.make_different_component_example_group_pairs(
        n_groups=FLAGS.n_different_component_groups,
        **shared_kwargs,
    ))

    all_pairs = [*same_pairs, *different_pairs]
    # Derive indices from the pairs actually returned so the labels stay
    # aligned with the pairs even if a helper returns fewer than requested.
    order = list(range(len(all_pairs)))
    rng.shuffle(order)

    n_same = len(same_pairs)
    pairs = [all_pairs[i] for i in order]
    is_same_component = [i < n_same for i in order]
    return pairs, is_same_component


def main(_):
    """Write --n_documents human-evaluation documents.

    For each document index ``i``, writes:

      * ``<prefix>.<i>.tex``: LaTeX produced by the configured generator.
      * ``<prefix>.evaluation.<i>.csv``: CSV from
        ``humev_common.make_evaluation_csv``.
      * ``<prefix>.labels.<i>.json``: a JSON list of booleans, true where the
        pair at that position comes from the same component.

    Raises:
        ValueError: If a required flag is missing, the LaTeX generator class
            cannot be located, or the two coefficient sets have different
            numbers of components.
    """
    # Raise instead of assert so validation is not stripped under `python -O`.
    for flag_name in ('rng_seed', 'n_documents', 'output_filepath_prefix', 'latex_generator_cls_path'):
        if getattr(FLAGS, flag_name) is None:
            raise ValueError(f'--{flag_name} must be provided.')

    rng = random.Random(FLAGS.rng_seed)
    _check_output_directory(FLAGS.output_filepath_prefix)

    LatexGenerator = pydoc.locate(FLAGS.latex_generator_cls_path)
    if LatexGenerator is None:
        raise ValueError(f'Could not locate --latex_generator_cls_path={FLAGS.latex_generator_cls_path!r}.')

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer)

    reader_1 = _make_reader(
        npeff_filepath=FLAGS.npeff_filepath_1,
        pef_filepaths=FLAGS.pef_filepaths_1,
        n_examples_per_pef=FLAGS.n_examples_per_pef_1,
    )
    reader_2 = _make_reader(
        npeff_filepath=FLAGS.npeff_filepath_2,
        pef_filepaths=FLAGS.pef_filepaths_2,
        n_examples_per_pef=FLAGS.n_examples_per_pef_2,
    )
    if reader_1.n_components != reader_2.n_components:
        raise ValueError(
            f'The NPEFF files have different numbers of components: '
            f'{reader_1.n_components} vs {reader_2.n_components}.'
        )

    latex_generator = LatexGenerator.create(
        tokenizer=tokenizer,
        **_read_flag_kwargs(FLAGS.latex_generator_kwargs),
    )

    prefix = FLAGS.output_filepath_prefix
    for document_index in range(FLAGS.n_documents):
        pairs, is_same_component = _make_document_pairs(rng, reader_1, reader_2)

        latex_content = latex_generator.generate_latex(pairs)
        csv_content = humev_common.make_evaluation_csv(_CSV_HEADERS, len(is_same_component))

        _write_text_file(f'{prefix}.{document_index}.tex', latex_content)
        _write_text_file(f'{prefix}.evaluation.{document_index}.csv', csv_content)
        # Serialized as a JSON list of true/false values (not 1/0).
        _write_text_file(f'{prefix}.labels.{document_index}.json', json.dumps(is_same_component))


if __name__ == "__main__":
    app.run(main)
