"""

"""
import os
from typing import Dict, Set

from absl import app
from absl import flags

import h5py
import torch

from npeff_torch.util import hdf5_utils


###############################################################################
# Command-line flags. Both filepaths are read as HDF5 files by this script and
# may start with `~`, which is expanded before opening.
FLAGS = flags.FLAGS

flags.DEFINE_string(
    name='smaller_npeff_filepath',
    default=None,
    help='NPEFF decomposition with fewer components.',
)
flags.DEFINE_string(
    name='larger_npeff_filepath',
    default=None,
    help='NPEFF decomposition with more components.',
)

###############################################################################


def _read_and_normalize_coeffs(filepath: str, device: torch.device) -> torch.Tensor:
    """Load the coefficient matrix of an NPEFF decomposition and normalize it.

    Args:
        filepath: Path to the HDF5 file of the decomposition; `~` is expanded.
        device: Device the returned tensor is moved to.

    Returns:
        Tensor of shape [n_examples, n_components] in which the [n_examples]
        vector associated to each component has unit l2 norm.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        # Read the dataset fully while the file is still open.
        coeffs = hdf5_utils.load_h5_ds(h5_file['data/W'])

    coeffs_on_device = torch.from_numpy(coeffs).to(device)
    # dim=0 normalizes along the examples axis, i.e. per component column.
    return torch.nn.functional.normalize(coeffs_on_device, dim=0)


def _compute_matches(*, smaller_W: torch.Tensor, larger_W: torch.Tensor) -> Dict[int, Set[int]]:
    """Assign each larger-decomposition component to its most similar smaller one.

    Similarity is the inner product between coefficient columns, which equals
    the cosine similarity when the columns have unit l2 norm.

    Args:
        smaller_W: Coefficients of the decomposition with fewer components,
            shape [n_examples, smaller_n_components].
        larger_W: Coefficients of the decomposition with more components,
            shape [n_examples, larger_n_components].

    Returns:
        Map from each smaller component index to the set of larger component
        indices matched to it. Every smaller component index has an entry; the
        set is empty when no larger component picked it.
    """
    n_smaller = smaller_W.shape[1]
    n_larger = larger_W.shape[1]
    assert n_smaller < n_larger

    # similarities.shape = [smaller_n_components, larger_n_components]
    similarities = torch.einsum('ec,ek->ck', smaller_W, larger_W)

    # One entry per larger component: the index of the smaller component with
    # the highest similarity to it.
    best_smaller_per_larger = similarities.argmax(dim=0).tolist()

    matches: Dict[int, Set[int]] = {smaller_index: set() for smaller_index in range(n_smaller)}
    for larger_index, smaller_index in enumerate(best_smaller_per_larger):
        matches[smaller_index].add(larger_index)
    return matches


def _print_match_statistics(matches: Dict[int, Set[int]]):
    """Print how many smaller components got zero, exactly one, or several matches.

    Args:
        matches: Map from smaller component index to the set of larger
            component indices matched to it.
    """
    match_sizes = [len(matched) for matched in matches.values()]

    n_zero_matches = match_sizes.count(0)
    n_one_matches = match_sizes.count(1)
    # Sizes are non-negative, so whatever is neither 0 nor 1 is greater than 1.
    n_multi_matches = len(match_sizes) - n_zero_matches - n_one_matches

    for label, count in (
        ('n_zero_matches', n_zero_matches),
        ('n_one_matches', n_one_matches),
        ('n_multi_matches', n_multi_matches),
    ):
        print(f'{label}: {count}')


###############################################################################


@torch.no_grad()
def main(_):
    """Load both decompositions, match their components, and print statistics.

    Runs on CUDA when available and on the CPU otherwise. The decomposition
    given by `--smaller_npeff_filepath` must have strictly fewer components
    than the one given by `--larger_npeff_filepath`.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    small_coeffs = _read_and_normalize_coeffs(FLAGS.smaller_npeff_filepath, device)
    large_coeffs = _read_and_normalize_coeffs(FLAGS.larger_npeff_filepath, device)

    # Component counts are the second dimension of each coefficient matrix.
    n_small, n_large = small_coeffs.shape[1], large_coeffs.shape[1]
    assert n_small < n_large

    _print_match_statistics(
        _compute_matches(smaller_W=small_coeffs, larger_W=large_coeffs)
    )


if __name__ == "__main__":
    # Both filepath flags default to None; without them the script would fail
    # with an opaque TypeError when expanding the path. Marking them required
    # makes absl report a usage error instead. This is done only in the script
    # entry point so that importing this module does not impose the requirement.
    flags.mark_flags_as_required(['smaller_npeff_filepath', 'larger_npeff_filepath'])
    app.run(main)
