"""Counts the number of components with tunings, as determined by component filters applied to their top examples."""
import json
import os
import pydoc
from typing import List, Optional, Tuple

from absl import app
from absl import flags
import h5py
import numpy as np

from npeff_torch.examination.top_examples import top_examples_from_clusters
from npeff_torch.examination.top_examples import top_examples_from_coeffs
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.util import hdf5_utils
from npeff_torch.util import tokenizer_utils


###############################################################################
# Decomposition kinds accepted by --decomposition_type; _get_top_examples_reader() dispatches on these values.
_DECOMPOSITION_TYPES = ['npeff', 'kmeans']

###############################################################################
FLAGS = flags.FLAGS


# HDF5 file holding the decomposition. For 'npeff' the coefficients are read from `data/W`; for 'kmeans'
# `data/cluster_assignments`, `data/centroid_distances` and the shape of `data/centroids` are read.
flags.DEFINE_string('decomposition_filepath', None, 'Either an NPEFF or kmeans decomposition.')

# Selects how --decomposition_filepath is interpreted; must be one of _DECOMPOSITION_TYPES.
flags.DEFINE_enum('decomposition_type', None, _DECOMPOSITION_TYPES, '')

# Works with gradients as well.
# The number of examples read from these files must match the number of rows of the decomposition.
flags.DEFINE_list('pef_filepaths', None, 'The PEF files used to compute the NPEFF coefficients.')
# Parsed by _read_n_examples_per_pef_flag(); negative entries are converted to None ("use all examples").
flags.DEFINE_list('n_examples_per_pef', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


# No default value: main() asserts this is positive, so it must be supplied on the command line.
flags.DEFINE_integer('n_top_examples', None, 
                     'This number of top examples will be used to represent each component.')
# Empty components are skipped by every filter and reported separately as `empty_component_count`.
flags.DEFINE_integer('min_top_examples', 1, 
                     'Components with fewer than this number of top examples will be considered empty.')

# Passed to tokenizer_utils.from_pretrained(); the resulting tokenizer is handed to each component filter.
flags.DEFINE_string('tokenizer', None, '')


# Component filtering.
#   NOTE: Unlike when generating the LaTeX of component top examples, providing multiple filters here analyzes
#   each filter independently instead of combining them into a single filter that is their logical AND.
flags.DEFINE_list('component_filter_cls_paths', [],
                  'List of pydoc.locate-compatible class paths used to create component filters. The '
                  'classes should be subclasses of component_filtering.ComponentFilterAbc. If more than one '
                  'class path is provided, a filter will be created for each entry. We will perform analysis '
                  'for each filter independently.')
# Decoded with json.loads; entry i is passed as **kwargs to the constructor of the i-th filter class.
flags.DEFINE_string('component_filter_kwargs', '', 
                    'JSON list of JSON kwargs for the component filters. Must be provided if --component_filter_cls_paths '
                    'is provided. Is a parallel list to --component_filter_cls_paths, so must have the same number of entries.')


###############################################################################


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    if not flag_value:
        return None
    ret = []
    for n_examples in flag_value:
        n_examples = int(n_examples)
        if n_examples < 0:
            ret.append(None)
        else:
            ret.append(n_examples)
    return ret


def _read_pef_extra_infos():
    """Reads the per-example metadata stored in the --pef_filepaths files.

    Honors --n_examples_per_pef (parsed by _read_n_examples_per_pef_flag).

    Returns:
        The object returned by pef_format_common.PefExtraInfos.read_from_files, with its
        `examples` and `n_examples` attributes verified to be set.

    Raises:
        ValueError: if --n_examples_per_pef is given with a different number of entries than
            --pef_filepaths, or if the PEF files did not provide examples / an example count.
    """
    n_examples_per_pef = _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef)

    # The flag help promises the two lists are parallel; check it here so a mismatch is reported clearly.
    n_pef_filepaths = len(FLAGS.pef_filepaths or [])
    if n_examples_per_pef is not None and len(n_examples_per_pef) != n_pef_filepaths:
        raise ValueError(
            f'--n_examples_per_pef has {len(n_examples_per_pef)} entries but --pef_filepaths has '
            f'{n_pef_filepaths}; they must be the same length.')

    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(FLAGS.pef_filepaths, n_examples_per_pef)

    # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
    if pef_extra_infos.examples is None:
        raise ValueError(f'The PEF files {FLAGS.pef_filepaths} do not contain any examples.')
    if pef_extra_infos.n_examples is None:
        raise ValueError(f'The PEF files {FLAGS.pef_filepaths} do not report a number of examples.')
    return pef_extra_infos


def _read_clusters(filepath: str) -> Tuple[np.ndarray, np.ndarray, int]:
    """Reads a k-means decomposition from an HDF5 file.

    Args:
        filepath: path to the HDF5 file; a leading `~` is expanded.

    Returns:
        A tuple (cluster_assignments, centroid_distances, n_components). The first two are
        loaded from `data/cluster_assignments` and `data/centroid_distances`; n_components
        is the number of rows of `data/centroids`.
    """
    with h5py.File(os.path.expanduser(filepath), "r") as f:
        cluster_assignments = hdf5_utils.load_h5_ds(f['data/cluster_assignments'])
        centroid_distances = hdf5_utils.load_h5_ds(f['data/centroid_distances'])
        # Only the number of centroids is needed, so just the dataset's shape is read.
        n_components = f['data/centroids'].shape[0]
    return cluster_assignments, centroid_distances, n_components


def _read_coeffs(filepath: str) -> np.ndarray:
    """Loads the NPEFF coefficient matrix stored under `data/W` of an HDF5 file.

    A leading `~` in filepath is expanded. The dataset is read with hdf5_utils.load_h5_ds
    while the file is open; the file is closed before returning.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        coefficients = hdf5_utils.load_h5_ds(h5_file['data/W'])
    return coefficients


def _check_n_examples_match(decomposition_n_examples: int, pei_n_examples: int) -> None:
    """Raises ValueError if the decomposition and the PEF files cover different numbers of examples."""
    if decomposition_n_examples != pei_n_examples:
        raise ValueError(
            f'The decomposition at --decomposition_filepath covers {decomposition_n_examples} examples, '
            f'but {pei_n_examples} examples were read from --pef_filepaths. Check --pef_filepaths and '
            f'--n_examples_per_pef.')


def _get_top_examples_reader():
    """Builds a top-examples reader for the decomposition selected by the flags.

    Reads the PEF extra infos and the decomposition at --decomposition_filepath, verifies that both
    cover the same number of examples, and wraps them in the reader class that matches
    --decomposition_type.

    Returns:
        A tuple (reader, n_components).

    Raises:
        ValueError: if --decomposition_type is unset/unknown, or if the decomposition and the PEF
            files disagree on the number of examples.
    """
    # Validate before reading any files so that a missing/invalid flag fails immediately.
    if FLAGS.decomposition_type not in _DECOMPOSITION_TYPES:
        raise ValueError(f'Invalid --decomposition_type: {FLAGS.decomposition_type}')

    pef_extra_infos = _read_pef_extra_infos()
    pei_n_examples = pef_extra_infos.n_examples

    # Per-example metadata passed to both reader types.
    common_kwargs = {
        'examples': pef_extra_infos.examples,
        'labels': pef_extra_infos.labels,
        'logits': pef_extra_infos.logits,
        'top_log_probs_class_indices': pef_extra_infos.top_log_probs_class_indices,
        'top_log_probs_values': pef_extra_infos.top_log_probs_values,
        'token_positions': pef_extra_infos.token_positions,
    }

    if FLAGS.decomposition_type == 'npeff':
        coefficients = _read_coeffs(FLAGS.decomposition_filepath)
        # Rows are examples, columns are components.
        coeff_n_examples, n_components = coefficients.shape
        _check_n_examples_match(coeff_n_examples, pei_n_examples)

        reader = top_examples_from_coeffs.TopExamplesReaderFromCoeffs.create(
            coefficients=coefficients,
            **common_kwargs,
        )
        return reader, n_components

    elif FLAGS.decomposition_type == 'kmeans':
        cluster_assignments, centroid_distances, n_components = _read_clusters(FLAGS.decomposition_filepath)
        # One cluster assignment per example.
        coeff_n_examples, = cluster_assignments.shape
        _check_n_examples_match(coeff_n_examples, pei_n_examples)

        reader = top_examples_from_clusters.TopExamplesReaderFromClusters.create(
            cluster_assignments=cluster_assignments,
            centroid_distances=centroid_distances,
            **common_kwargs,
        )
        return reader, n_components

    else:
        # Guards against a type being added to _DECOMPOSITION_TYPES without a matching branch here.
        raise ValueError(f'Invalid --decomposition_type: {FLAGS.decomposition_type}')


def _read_flag_kwargs(flag_value: Optional[str]):
    if flag_value:
        return json.loads(flag_value)
    else:
        return {}


def _get_component_filters():
    """Instantiates the component filters described by the component-filter flags.

    Each entry of --component_filter_cls_paths is resolved with pydoc.locate and constructed
    with the kwargs dict at the same position of the JSON list in --component_filter_kwargs.

    Returns:
        A list with one filter instance per class path, or None if no class paths were given.

    Raises:
        ValueError: if --component_filter_kwargs is missing, is not a JSON list of objects, has a
            different length than --component_filter_cls_paths, or a class path cannot be located.
    """
    cls_paths = FLAGS.component_filter_cls_paths
    if not cls_paths:
        return None

    # The flag defaults to '', which json.loads would reject with an unhelpful JSONDecodeError.
    if not FLAGS.component_filter_kwargs:
        raise ValueError(
            '--component_filter_kwargs must be provided when --component_filter_cls_paths is provided.')

    all_kwargs = json.loads(FLAGS.component_filter_kwargs)
    if not isinstance(all_kwargs, list):
        raise ValueError('--component_filter_kwargs must be a JSON list of JSON objects.')
    if len(all_kwargs) != len(cls_paths):
        raise ValueError(
            f'--component_filter_kwargs has {len(all_kwargs)} entries but --component_filter_cls_paths '
            f'has {len(cls_paths)}; they must be the same length.')

    filters = []
    for cls_path, kwargs in zip(cls_paths, all_kwargs):
        # pydoc.locate returns None (rather than raising) when the path cannot be resolved.
        component_filter_cls = pydoc.locate(cls_path)
        if component_filter_cls is None:
            raise ValueError(f'Could not locate component filter class: {cls_path}')
        if not isinstance(kwargs, dict):
            raise ValueError(f'The kwargs for {cls_path} must be a JSON object, got: {kwargs!r}')
        filters.append(component_filter_cls(**kwargs))

    return filters


def _count_components_matching_filter(tokenizer, top_examples_reader, n_components, component_filter):
    count = 0
    empty_component_count = 0
    for component_index in range(n_components):
        top_examples = top_examples_reader.get_top_examples_for_component(component_index, FLAGS.n_top_examples)

        # TODO: Figure out what to do with these.
        if len(top_examples) < FLAGS.min_top_examples:
            empty_component_count += 1
            continue

        passes = component_filter.does_component_pass(
            tokenizer=tokenizer,
            component_index=component_index,
            top_examples=top_examples,
        )
        if passes:
            count += 1
    return count, empty_component_count


###############################################################################


def main(_):
    """Counts, for each configured component filter, how many components pass it.

    Prints one line per filter (`<FilterClassName>: <count>`), followed by the number of components
    with fewer than --min_top_examples top examples; those components are skipped by every filter.
    """
    # Explicit check instead of `assert`: the flag defaults to None (where `None > 0` raises
    # TypeError), and asserts are stripped under -O.
    if FLAGS.n_top_examples is None or FLAGS.n_top_examples <= 0:
        raise app.UsageError(f'--n_top_examples must be a positive integer, got {FLAGS.n_top_examples}.')

    # _get_component_filters() returns None when no class paths are given; iterating None would raise
    # TypeError, and with no filters empty_component_count would never be assigned.
    component_filters = _get_component_filters() or []
    if not component_filters:
        raise app.UsageError('At least one --component_filter_cls_paths entry must be provided.')

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer)
    top_examples_reader, n_components = _get_top_examples_reader()

    empty_component_count = 0
    for component_filter in component_filters:
        # The empty count is computed without consulting the filter, so the last value is reported.
        n_matching_components, empty_component_count = _count_components_matching_filter(
            tokenizer, top_examples_reader, n_components, component_filter)
        print(f'{component_filter.__class__.__name__}: {n_matching_components}')

    print(f'empty_component_count: {empty_component_count}')

    # TODO: Do something, print output somewhere?


if __name__ == "__main__":
    app.run(main)
