"""Meant for a section of the paper."""
import collections
import os
from typing import Dict, List, Optional, Union

from absl import app
from absl import flags

import h5py
import numpy as np
import torch

from npeff_torch.decomps.kmeans import kmeans
from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples import top_examples_from_clusters
from npeff_torch.examination.top_examples import top_examples_from_coeffs
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.util import hdf5_utils

###############################################################################
_DECOMPOSITION_TYPES = ['npeff', 'kmeans']
###############################################################################

FLAGS = flags.FLAGS


# These two should have the same components but with coefficients computed on disjoint groups of examples.
# The "*_1" version is considered to hold the held-in examples, and the "*_2" version is considered to hold
# the held-out examples.
flags.DEFINE_string('npeff_filepath_1', None,
                    'Path to the decomposition (see --decomposition_type) computed on the held-in examples.')
flags.DEFINE_string('npeff_filepath_2', None,
                    'Path to the decomposition (see --decomposition_type) with the same components as '
                    '--npeff_filepath_1, but computed on the held-out examples.')

flags.DEFINE_enum('decomposition_type', 'npeff', _DECOMPOSITION_TYPES,
                  'Kind of decomposition stored in --npeff_filepath_1/2. "kmeans" is a hack that lets this '
                  'script also be used for k-means clusterings.')

flags.DEFINE_list('pef_filepaths_1', None,
                  'The PEF files used to compute the coefficients in --npeff_filepath_1 (held-in examples).')
flags.DEFINE_list('n_examples_per_pef_1', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths_1 list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")

flags.DEFINE_list('pef_filepaths_2', None,
                  'The PEF files used to compute the coefficients in --npeff_filepath_2 (held-out examples).')
flags.DEFINE_list('n_examples_per_pef_2', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths_2 list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


flags.DEFINE_integer('n_top_examples', None,
                     'Number of top examples to request per component from each decomposition.')
# Bounded so that nonsensical thresholds are rejected at flag-parsing time.
flags.DEFINE_float('tuning_fraction', None,
                   'A component is considered tuned to a label (or prediction) when at least this fraction of '
                   'its top examples share it. Must be in [0, 1].',
                   lower_bound=0.0, upper_bound=1.0)

flags.DEFINE_integer('min_top_examples', 1,
                     'Skip components with fewer than this many top examples in either decomposition.')

###############################################################################


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    if not flag_value:
        return None
    ret = []
    for n_examples in flag_value:
        n_examples = int(n_examples)
        if n_examples < 0:
            ret.append(None)
        else:
            ret.append(n_examples)
    return ret


def _read_coeffs(filepath: str, eps=1e-12) -> np.ndarray:
    """Loads the coefficient matrix stored in the `data/W` dataset of an HDF5 file.

    Args:
        filepath: path to the HDF5 file; a leading `~` is expanded.
        eps: not used by this function (presumably kept for signature
            compatibility).

    Returns:
        Whatever `hdf5_utils.load_h5_ds` returns for `data/W`.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        coeffs = hdf5_utils.load_h5_ds(h5_file['data/W'])
    return coeffs


def _make_reader(
    npeff_filepath: str,
    pef_filepaths: List[str],
    n_examples_per_pef: Optional[List[str]],
):
    """Builds a top-examples reader that pairs a decomposition with its PEF metadata.

    The kind of reader depends on `--decomposition_type`:
      - 'npeff': coefficients are read from `data/W` in `npeff_filepath` and
        wrapped in a `TopExamplesReaderFromCoeffs`.
      - 'kmeans': cluster assignments, centroid distances and the number of
        clusters are loaded from `npeff_filepath` and wrapped in a
        `TopExamplesReaderFromClusters`.

    Args:
        npeff_filepath: path to the decomposition file.
        pef_filepaths: PEF files providing per-example infos (examples, labels,
            logits, top log-probs, token positions).
        n_examples_per_pef: raw `--n_examples_per_pef_*` flag value; None or
            empty means use all examples from every PEF.

    Returns:
        The constructed reader.

    Raises:
        ValueError: if the PEF infos lack examples or an example count, if the
            decomposition's number of examples disagrees with the PEFs, or if
            `--decomposition_type` is unknown.
    """
    parsed_n_examples_per_pef = _read_n_examples_per_pef_flag(n_examples_per_pef)
    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(pef_filepaths, parsed_n_examples_per_pef)
    pei_n_examples = pef_extra_infos.n_examples

    # Explicit raises instead of `assert` so these checks survive `python -O`.
    if pef_extra_infos.examples is None:
        raise ValueError(f'PEF files {pef_filepaths} do not contain examples.')
    if pei_n_examples is None:
        raise ValueError(f'PEF files {pef_filepaths} do not report a number of examples.')

    common_kwargs = {
        'examples': pef_extra_infos.examples,
        'labels': pef_extra_infos.labels,
        'logits': pef_extra_infos.logits,
        'top_log_probs_class_indices': pef_extra_infos.top_log_probs_class_indices,
        'top_log_probs_values': pef_extra_infos.top_log_probs_values,
        'token_positions': pef_extra_infos.token_positions,
    }

    decomposition_type = FLAGS.decomposition_type

    if decomposition_type == 'npeff':
        coefficients = _read_coeffs(npeff_filepath)
        # Unpacking also enforces that the coefficients form a 2-D matrix.
        coeff_n_examples, _ = coefficients.shape

        if coeff_n_examples != pei_n_examples:
            raise ValueError(
                f'{npeff_filepath} has coefficients for {coeff_n_examples} examples, '
                f'but the PEF files contain {pei_n_examples} examples.')

        return top_examples_from_coeffs.TopExamplesReaderFromCoeffs.create(
            coefficients=coefficients,
            **common_kwargs,
        )

    elif decomposition_type == 'kmeans':
        cluster_assignments = kmeans.KmeansClusteringTorch.load_cluster_assignments(npeff_filepath)
        centroid_distances = kmeans.KmeansClusteringTorch.load_centroid_distances(npeff_filepath)
        n_components = kmeans.KmeansClusteringTorch.load_n_clusters(npeff_filepath)

        n_assignments = cluster_assignments.numel()
        n_distances = centroid_distances.numel()
        if not (n_assignments == n_distances == pei_n_examples):
            raise ValueError(
                f'{npeff_filepath} has {n_assignments} cluster assignments and {n_distances} '
                f'centroid distances, but the PEF files contain {pei_n_examples} examples.')

        return top_examples_from_clusters.TopExamplesReaderFromClusters(
            cluster_assignments=cluster_assignments.detach().cpu().numpy(),
            centroid_distances=centroid_distances.detach().cpu().numpy(),
            n_components=n_components,
            **common_kwargs,
        )

    else:
        raise ValueError(f'Unknown decomposition type: {decomposition_type}')

###############################################################################


def _get_tuning(x: List[int]) -> Union[int, None]:
    # A None indicates no tuning, otherwise returns what the component is tuned for.
    counter = collections.defaultdict(lambda: 0)
    for y in x:
        counter[y] += 1

    most_seen_value, most_seen_count = max(counter.items(), key=lambda z: z[1])
    if most_seen_count >= FLAGS.tuning_fraction * len(x):
        return most_seen_value
    else:
        return None


# def _determine_tunings(top_examples: List['top_examples_common.TopExampleInfo']) -> Dict[str, Union[int, None]]:
#     # A None indicates no tuning, otherwise returns what the component is tuned for.
#     labels = [e.label for e in top_examples]
#     predictions = [e.get_prediction() for e in top_examples]
#     assert all(x is not None for x in labels) and all(x is not None for x in predictions)

#     return {
#         'label': _get_tuning(labels),
#         'prediction': _get_tuning(predictions),
#     }


def _determine_tunings(top_examples: List['top_examples_common.TopExampleInfo']) -> Dict[str, Union[int, None]]:
    # A None indicates no tuning, otherwise returns what the component is tuned for.
    labels = [e.label for e in top_examples]
    predictions = [e.get_prediction() for e in top_examples]

    ret = {
        'label': 0,
        'prediction': 0,
    }

    if all(x is not None for x in labels):
        ret['label'] = _get_tuning(labels)

    if all(x is not None for x in predictions):
        ret['prediction'] = _get_tuning(predictions)

    return ret


###############################################################################


@torch.no_grad()
def main(_):
    """Compares component tunings between the held-in and held-out decompositions.

    The loop considers each component whose top examples number at least
    `--min_top_examples` in both readers. For each such component it determines
    the label and prediction tuning of the top examples in each set. It then
    tallies these counters per key ('label' / 'prediction'):
      - tuned_1 / tuned_2: the component is tuned in set 1 / set 2;
      - both_tuned: the component is tuned in both sets;
      - both_tuned_and_same: the component is tuned in both sets to the same value.

    The number of considered components and all counters are printed to stdout.

    Raises:
        ValueError: if the two decompositions have different numbers of components.
    """
    reader_1 = _make_reader(
        npeff_filepath=FLAGS.npeff_filepath_1,
        pef_filepaths=FLAGS.pef_filepaths_1,
        n_examples_per_pef=FLAGS.n_examples_per_pef_1,
    )
    reader_2 = _make_reader(
        npeff_filepath=FLAGS.npeff_filepath_2,
        pef_filepaths=FLAGS.pef_filepaths_2,
        n_examples_per_pef=FLAGS.n_examples_per_pef_2,
    )
    # Explicit raise rather than `assert` so the check survives `python -O`.
    if reader_1.n_components != reader_2.n_components:
        raise ValueError(
            f'The two decompositions have different numbers of components: '
            f'{reader_1.n_components} vs {reader_2.n_components}.')
    n_components = reader_1.n_components

    infos = {
        'label': collections.defaultdict(int),
        'prediction': collections.defaultdict(int),
    }
    total_count = 0

    for component_index in range(n_components):
        top_examples_1 = reader_1.get_top_examples_for_component(component_index, FLAGS.n_top_examples)
        top_examples_2 = reader_2.get_top_examples_for_component(component_index, FLAGS.n_top_examples)

        if len(top_examples_1) < FLAGS.min_top_examples or len(top_examples_2) < FLAGS.min_top_examples:
            continue

        total_count += 1

        tunings_1 = _determine_tunings(top_examples_1)
        tunings_2 = _determine_tunings(top_examples_2)

        for key, info in infos.items():
            t1 = tunings_1[key]
            t2 = tunings_2[key]
            is_tuned_1 = t1 is not None
            is_tuned_2 = t2 is not None

            info['tuned_1'] += int(is_tuned_1)
            info['tuned_2'] += int(is_tuned_2)
            info['both_tuned'] += int(is_tuned_1 and is_tuned_2)
            info['both_tuned_and_same'] += int(is_tuned_1 and is_tuned_2 and t1 == t2)

    print(f'total: {total_count}')
    for key in sorted(infos.keys()):
        for subkey in ['tuned_1', 'tuned_2', 'both_tuned', 'both_tuned_and_same']:
            print(f'{key} {subkey}: {infos[key][subkey]}')


if __name__ == "__main__":
    # These flags default to None but are needed for any meaningful run: the
    # decomposition paths are passed to os.path.expanduser / the k-means loaders,
    # and tuning_fraction is the tuning threshold. Fail fast at flag-parsing time
    # with a clear message instead of an obscure error deep inside main().
    flags.mark_flags_as_required([
        'npeff_filepath_1',
        'npeff_filepath_2',
        'pef_filepaths_1',
        'pef_filepaths_2',
        'tuning_fraction',
    ])
    app.run(main)
