"""
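Compares prediction tunings of matched components across two NPEFF decompositions.

Each component of the larger decomposition (--npeff_filepath_2) is matched to the
component of the smaller decomposition (--npeff_filepath_1) whose normalized
coefficient vector has the largest inner product with its own. For every component
of the smaller decomposition that is tuned to a single prediction, the script counts
how many of its matched components are tuned to that same prediction and prints
the totals.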
"""
import collections
import os
from typing import Dict, List, Optional, Set, Union

from absl import app
from absl import flags

import h5py
import numpy as np
import torch

from npeff_torch.examination.top_examples import top_examples_common
from npeff_torch.examination.top_examples import top_examples_from_coeffs
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.util import hdf5_utils


###############################################################################
FLAGS = flags.FLAGS


# Decomposition 1 must have strictly fewer components than decomposition 2; this is
# checked once both decompositions have been loaded.
flags.DEFINE_string('npeff_filepath_1', None, 'NPEFF decomposition with fewer components.')
flags.DEFINE_string('npeff_filepath_2', None, 'NPEFF decomposition with more components.')

flags.DEFINE_list('pef_filepaths_1', None,
                  'The PEF files used to compute the NPEFF coefficients in --npeff_filepath_1.')
flags.DEFINE_list('n_examples_per_pef_1', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths_1 list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")

flags.DEFINE_list('pef_filepaths_2', None,
                  'The PEF files used to compute the NPEFF coefficients in --npeff_filepath_2.')
flags.DEFINE_list('n_examples_per_pef_2', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths_2 list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


flags.DEFINE_integer('n_top_examples', None,
                     'Number of top examples requested for each component when computing its prediction tuning.')
flags.DEFINE_float('tuning_fraction', None,
                   "Minimum fraction of a component's top examples that must share the same prediction "
                   "for the component to be considered tuned to that prediction.")

flags.DEFINE_integer('min_top_examples', 1,
                     'Components with fewer top examples than this are skipped when computing prediction tunings.')

###############################################################################


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    """Parses the raw value of an `--n_examples_per_pef_*` list flag.

    Args:
        flag_value: The string entries of the flag, or None/empty when the flag was not set.

    Returns:
        None when `flag_value` is None or empty. Otherwise a list with one entry per input
        string: the parsed integer when it is non-negative, or None when it is negative.
    """
    if not flag_value:
        return None
    parsed_counts = (int(raw_count) for raw_count in flag_value)
    # Negative values are mapped to None.
    return [None if count < 0 else count for count in parsed_counts]


def _read_and_normalize_coeffs(filepath: str) -> np.ndarray:
    """Loads the `data/W` coefficients from an NPEFF file and L2-normalizes them along axis 0.

    Args:
        filepath: Path to the HDF5 file; a leading `~` is expanded.

    Returns:
        Array of shape [n_examples, n_components] in which the [n_examples] vector
        associated with each component has been scaled by torch.nn.functional.normalize
        to unit l2 norm.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        raw_coeffs = hdf5_utils.load_h5_ds(h5_file['data/W'])

    # dim=0 normalizes over the examples axis, i.e. one norm per component.
    normalized = torch.nn.functional.normalize(torch.from_numpy(raw_coeffs), dim=0)
    return normalized.detach().cpu().numpy()


def _make_reader(
    npeff_filepath: str,
    pef_filepaths: List[str],
    n_examples_per_pef: Optional[List[str]],
):
    """Builds a top-examples reader for one NPEFF decomposition.

    Args:
        npeff_filepath: Path to the NPEFF decomposition whose coefficients are read.
        pef_filepaths: Paths of the PEF files from which the per-example information is read.
        n_examples_per_pef: Raw value of the matching `--n_examples_per_pef_*` flag, or
            None/empty when the flag was not set.

    Returns:
        The reader produced by `TopExamplesReaderFromCoeffs.create` from the normalized
        coefficients and the per-example information of the PEF files.

    Raises:
        ValueError: If the number of examples in the coefficients differs from the number
            of examples reported for the PEF files.
    """
    n_examples_per_pef = _read_n_examples_per_pef_flag(n_examples_per_pef)
    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(pef_filepaths, n_examples_per_pef)
    pei_n_examples = pef_extra_infos.n_examples
    assert pef_extra_infos.examples is not None, 'The PEF files did not provide examples.'
    assert pei_n_examples is not None, 'The PEF files did not provide an example count.'

    coefficients = _read_and_normalize_coeffs(npeff_filepath)
    # Unpacking also enforces that the coefficients are a 2-D [n_examples, n_components] array.
    coeff_n_examples, _ = coefficients.shape

    if coeff_n_examples != pei_n_examples:
        raise ValueError(
            f'Number of examples in the NPEFF coefficients ({coeff_n_examples}) does not match '
            f'the number of examples read from the PEF files ({pei_n_examples}).'
        )

    return top_examples_from_coeffs.TopExamplesReaderFromCoeffs.create(
        coefficients=coefficients,
        examples=pef_extra_infos.examples,
        labels=pef_extra_infos.labels,
        logits=pef_extra_infos.logits,
        top_log_probs_class_indices=pef_extra_infos.top_log_probs_class_indices,
        top_log_probs_values=pef_extra_infos.top_log_probs_values,
        token_positions=pef_extra_infos.token_positions,
    )

###############################################################################


def _compute_matches(
    reader_1: 'top_examples_from_coeffs.TopExamplesReaderFromCoeffs',
    reader_2: 'top_examples_from_coeffs.TopExamplesReaderFromCoeffs',
    device: torch.device
) -> Dict[int, Set[int]]:
    """Assigns every component of the larger decomposition to a component of the smaller one.

    Args:
        reader_1: Reader whose `coefficients` belong to the decomposition with fewer components.
        reader_2: Reader whose `coefficients` belong to the decomposition with more components.
        device: Device on which the similarity matrix is computed.

    Returns:
        A dict with one key per smaller component index. Each value is the set of larger
        component indices whose highest-scoring smaller component is that key; the set
        may be empty.
    """
    coeffs_small = torch.from_numpy(reader_1.coefficients).to(device)
    coeffs_large = torch.from_numpy(reader_2.coefficients).to(device)
    n_small, n_large = coeffs_small.shape[1], coeffs_large.shape[1]
    assert n_small < n_large

    # similarity[c, k] is the inner product over examples of smaller component c and
    # larger component k; shape = [smaller_n_components, larger_n_components].
    similarity = torch.einsum('ec,ek->ck', coeffs_small, coeffs_large)

    # One entry per larger component: the index of its best-scoring smaller component.
    best_small_per_large = similarity.argmax(dim=0).tolist()

    matches: Dict[int, Set[int]] = {}
    for small_index in range(n_small):
        matches[small_index] = set()
    for large_index, small_index in enumerate(best_small_per_large):
        matches[small_index].add(large_index)

    return matches


def _get_tuning(x: List[int]) -> Union[int, None]:
    """Returns the value that `x` is tuned for, or None when there is no tuning.

    Args:
        x: The values (e.g. predictions of a component's top examples) to inspect.

    Returns:
        The most frequent value of `x` when its count is at least
        `FLAGS.tuning_fraction * len(x)`; ties for most frequent are resolved in favor of
        the value that appears first in `x`. None otherwise, including when `x` is empty.
    """
    # Without any values there is nothing to be tuned for (and no maximum to take).
    if len(x) == 0:
        return None

    most_seen_value, most_seen_count = collections.Counter(x).most_common(1)[0]
    if most_seen_count >= FLAGS.tuning_fraction * len(x):
        return most_seen_value
    return None


def _compute_prediction_tunings(
    reader: 'top_examples_from_coeffs.TopExamplesReaderFromCoeffs',
) -> Dict[int, Union[int, None]]:
    """Computes the prediction tuning of each component that has enough top examples.

    Args:
        reader: Reader from which the top examples of every component are retrieved.

    Returns:
        A dict keyed by component index. Only components with at least
        `FLAGS.min_top_examples` top examples get an entry. The value is the prediction
        the component is tuned for, or None when it is not tuned for any prediction.
    """
    tunings = {}

    for index in range(reader.n_components):
        examples = reader.get_top_examples_for_component(index, FLAGS.n_top_examples)
        if len(examples) >= FLAGS.min_top_examples:
            example_predictions = [example.get_prediction() for example in examples]
            # Every top example is expected to carry a prediction.
            assert not any(prediction is None for prediction in example_predictions)

            tunings[index] = _get_tuning(example_predictions)

    return tunings


###############################################################################


def _require_tuning_for_every_component(prediction_tunings, reader, label: str) -> None:
    """Raises ValueError unless `prediction_tunings` has an entry for every component of `reader`."""
    if len(prediction_tunings) != reader.n_components:
        raise ValueError(
            f'Decomposition {label}: only {len(prediction_tunings)} of {reader.n_components} components '
            f'have at least --min_top_examples={FLAGS.min_top_examples} top examples.'
        )


@torch.no_grad()
def main(_):
    """Prints how often tuned components of decomposition 1 match same-tuned components of decomposition 2.

    Only components of decomposition 1 that are tuned to a prediction are counted. They are
    split by how many components of decomposition 2 were matched to them (zero, one, or
    several), and for each group the number of matched components sharing the tuning is
    accumulated.

    Raises:
        ValueError: If decomposition 1 does not have fewer components than decomposition 2,
            or if some component of either decomposition has no prediction tuning entry.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    reader_1 = _make_reader(
        npeff_filepath=FLAGS.npeff_filepath_1,
        pef_filepaths=FLAGS.pef_filepaths_1,
        n_examples_per_pef=FLAGS.n_examples_per_pef_1,
    )
    reader_2 = _make_reader(
        npeff_filepath=FLAGS.npeff_filepath_2,
        pef_filepaths=FLAGS.pef_filepaths_2,
        n_examples_per_pef=FLAGS.n_examples_per_pef_2,
    )
    if reader_1.n_components >= reader_2.n_components:
        raise ValueError(
            f'--npeff_filepath_1 must have fewer components than --npeff_filepath_2, '
            f'got {reader_1.n_components} and {reader_2.n_components}.'
        )

    # The counting loop below indexes the tunings by every component index, so each
    # decomposition must have an entry for all of its components.
    prediction_tunings_1 = _compute_prediction_tunings(reader_1)
    _require_tuning_for_every_component(prediction_tunings_1, reader_1, '1')

    prediction_tunings_2 = _compute_prediction_tunings(reader_2)
    _require_tuning_for_every_component(prediction_tunings_2, reader_2, '2')

    matches = _compute_matches(reader_1, reader_2, device)

    #

    tuned_1_zero_match_count = 0

    tuned_1_single_match_count = 0
    tuned_1_single_match_matching_tunings_count = 0

    tuned_1_multi_match_count = 0
    tuned_1_multi_match_matches_count = 0
    tuned_1_multi_match_matching_tunings_count = 0
    #

    for component_index_1 in range(reader_1.n_components):
        component_tuning_1 = prediction_tunings_1[component_index_1]

        # Only look at cases where the component is tuned.
        if component_tuning_1 is None:
            continue

        component_matches = matches[component_index_1]
        n_matching_component_tunings = sum(
            prediction_tunings_2[component_index_2] == component_tuning_1
            for component_index_2 in component_matches
        )

        if len(component_matches) == 0:
            tuned_1_zero_match_count += 1

        elif len(component_matches) == 1:
            tuned_1_single_match_count += 1
            tuned_1_single_match_matching_tunings_count += n_matching_component_tunings

        else:
            tuned_1_multi_match_count += 1
            tuned_1_multi_match_matches_count += len(component_matches)
            tuned_1_multi_match_matching_tunings_count += n_matching_component_tunings

    #

    print(f'tuned_1_zero_match_count: {tuned_1_zero_match_count}')
    print('')
    print(f'tuned_1_single_match_count: {tuned_1_single_match_count}')
    print(f'tuned_1_single_match_matching_tunings_count: {tuned_1_single_match_matching_tunings_count}')
    print('')
    print(f'tuned_1_multi_match_count: {tuned_1_multi_match_count}')
    print(f'tuned_1_multi_match_matches_count: {tuned_1_multi_match_matches_count}')
    print(f'tuned_1_multi_match_matching_tunings_count: {tuned_1_multi_match_matching_tunings_count}')
    print('')


if __name__ == "__main__":
    # These flags default to None, which only fails deep inside the run (e.g. when the
    # file path is expanded or the tuning threshold is multiplied); requiring them makes
    # the script fail at flag-parsing time with a clear message instead.
    # NOTE(review): --n_examples_per_pef_* are optional by design. --n_top_examples is
    # left optional because whether None is accepted depends on the reader — TODO confirm.
    flags.mark_flags_as_required([
        'npeff_filepath_1',
        'npeff_filepath_2',
        'pef_filepaths_1',
        'pef_filepaths_2',
        'tuning_fraction',
    ])
    app.run(main)
