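"""Compares Fisher norms of each component's top examples to a baseline set of examples.

For each k-means cluster (component), computes the mean PEF Frobenius norm of its top
examples (the examples closest to the cluster centroid) and the mean norm over all loaded
examples (the baseline); the ratio of the two can be taken from the output.

The inputs should be ordinary PEF files that contain the PEF norms, but the decomposition
must be a k-means clustering rather than an NPEFF decomposition.
"""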
import json
import os
from typing import List, Optional, Sequence, Tuple

from absl import app
from absl import flags

import h5py
import torch

from npeff_torch.decomps.kmeans import kmeans
from npeff_torch.util import hdf5_utils

###############################################################################

FLAGS = flags.FLAGS


flags.DEFINE_string('output_filepath', None,
                    'Path of the JSON file that the per-component results are written to.')


flags.DEFINE_string('kmeans_filepath', None, 'The kmeans decomposition.')

flags.DEFINE_list('pef_filepaths', None,
                  'The PEF files used to compute the kmeans decomposition, which MUST be in '
                  'the same order. These MUST have the PEF norms.')

flags.DEFINE_list('n_examples_per_pef', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


flags.DEFINE_list('component_indices', None,
                  'Leave set to None to run on all components.')
flags.DEFINE_integer('max_non_empty_components', None,
                     'Leave None to run on all selected components. If provided, then computes norms '
                     'for at most this many non-empty components.')

flags.DEFINE_integer('n_top_examples', None,
                     'The number of top examples for each component to evaluate on.')
flags.DEFINE_integer('min_n_top_examples', 1,
                     'Examples not assigned to a component are excluded from its top examples. Hence, do '
                     'not run evaluations for components with fewer than this number of assigned '
                     'examples.')


###############################################################################


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    """Parses the raw --n_examples_per_pef flag value.

    Args:
        flag_value: The flag's list of integer strings, or None / an empty list.

    Returns:
        None if the flag is unset or empty. Otherwise, one entry per PEF file: the
        parsed integer, or None wherever the value was negative ("use every example").
    """
    if not flag_value:
        return None
    parsed_counts = (int(raw_count) for raw_count in flag_value)
    return [count if count >= 0 else None for count in parsed_counts]


def _get_component_indices(n_components: int) -> Sequence[int]:
    """Returns the component indices to evaluate.

    Uses --component_indices when it is set; otherwise selects every component.

    Args:
        n_components: Number of components in the decomposition.

    Returns:
        The requested indices (in command-line order) as a list, or
        ``range(n_components)`` when the flag is unset.

    Raises:
        ValueError: If any requested index lies outside ``[0, n_components)``.
    """
    if not FLAGS.component_indices:
        return range(n_components)

    ret = [int(ci) for ci in FLAGS.component_indices]
    # Raise instead of assert: this validates user input and must survive `python -O`.
    invalid = [ci for ci in ret if not 0 <= ci < n_components]
    if invalid:
        raise ValueError(
            f'--component_indices contains out-of-range indices {invalid}; '
            f'valid indices are in [0, {n_components}).')
    return ret


def _read_pef_norms_from_single_file(filepath: str, n_examples: Optional[int]) -> torch.Tensor:
    """Loads the per-example PEF Frobenius norms stored in a single PEF file.

    Args:
        filepath: Path to the HDF5 PEF file; a leading ``~`` is expanded.
        n_examples: If not None, only the first ``n_examples`` norms are kept.

    Returns:
        A tensor created (via ``torch.from_numpy``) from the
        ``data/pef_frobenius_norms`` dataset.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as pef_file:
        norms = hdf5_utils.load_h5_ds(pef_file['data/pef_frobenius_norms'])

    kept_norms = norms if n_examples is None else norms[:n_examples]
    return torch.from_numpy(kept_norms)


def _read_all_pef_norms() -> torch.Tensor:
    """Reads and concatenates the PEF norms from every file in --pef_filepaths.

    Files are read in flag order. --n_examples_per_pef, when set, limits how
    many norms are taken from each file (None entries mean "all").

    Returns:
        The norms from all files concatenated along dimension 0.

    Raises:
        ValueError: If --pef_filepaths is missing, or --n_examples_per_pef does not
            have one entry per PEF file.
    """
    pef_filepaths = FLAGS.pef_filepaths
    if not pef_filepaths:
        raise ValueError('--pef_filepaths must be provided.')

    n_examples_per_pef = _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef)
    if n_examples_per_pef is None:
        n_examples_per_pef = [None] * len(pef_filepaths)
    elif len(n_examples_per_pef) != len(pef_filepaths):
        # Raise instead of assert: this validates user input and must survive `python -O`.
        raise ValueError(
            f'--n_examples_per_pef has {len(n_examples_per_pef)} entries but '
            f'--pef_filepaths has {len(pef_filepaths)}; they must match.')

    norms = [
        _read_pef_norms_from_single_file(filepath, n_examples)
        for filepath, n_examples in zip(pef_filepaths, n_examples_per_pef)
    ]
    return torch.cat(norms, dim=0)


###############################################################################


def _get_top_example_indices(decomp, component_index: int,
                             n_top_examples: Optional[int] = None) -> torch.Tensor:
    """Returns the indices of a component's top examples, closest to the centroid first.

    The top examples are the examples assigned to the component (cluster) that have
    the smallest centroid distance. If the cluster has fewer than ``n_top_examples``
    members, all of its members are returned.

    Args:
        decomp: The k-means decomposition. Reads ``cluster_assignments`` and
            ``centroid_distances``; presumably both are 1-D tensors with one entry per
            example — TODO confirm against KmeansClusteringTorch.
        component_index: The cluster whose top examples to select.
        n_top_examples: Maximum number of examples to return. Defaults to the
            --n_top_examples flag.

    Returns:
        A 1-D tensor of example indices sorted by increasing centroid distance
        (empty if the cluster has no members).

    Raises:
        ValueError: If the number of top examples is unset or negative.
    """
    if n_top_examples is None:
        n_top_examples = FLAGS.n_top_examples
    if n_top_examples is None or n_top_examples < 0:
        raise ValueError('--n_top_examples must be set to a non-negative integer.')

    # Restrict to the cluster's members first rather than masking non-members with
    # a large sentinel distance; this needs no magic constant and lets k be capped
    # at the member count (torch.topk raises if k exceeds the input size).
    member_indices = torch.nonzero(decomp.cluster_assignments == component_index, as_tuple=True)[0]
    k = min(n_top_examples, member_indices.numel())
    if k == 0:
        return member_indices[:0]

    member_distances = decomp.centroid_distances[member_indices]
    _, order = torch.topk(member_distances, k=k, largest=False)
    return member_indices[order]


def _compute_top_examples_mean_norm(*, decomp, pef_norms: torch.Tensor, component_index: int) -> Optional[float]:
    """Returns the mean PEF norm over a component's top examples.

    Args:
        decomp: The k-means decomposition the top examples are taken from.
        pef_norms: Per-example PEF norms, indexed by example index.
        component_index: The component to evaluate.

    Returns:
        The mean norm as a Python float, or None if the component has fewer top
        examples than --min_n_top_examples (and always None if it has none).
    """
    examples_indices = _get_top_example_indices(decomp, component_index)

    # Require at least one example even if --min_n_top_examples <= 0: the mean of an
    # empty selection is NaN, which json.dump would emit as a non-standard NaN token.
    min_n_examples = max(FLAGS.min_n_top_examples, 1)
    if examples_indices.numel() < min_n_examples:
        return None

    return pef_norms[examples_indices].mean().item()


###############################################################################


@torch.no_grad()
def main(_):
    """Writes per-component top-example and baseline mean PEF norms as JSON.

    For each selected k-means component with enough top examples, records the
    mean PEF norm of its top examples next to the mean norm over all loaded
    examples (the baseline). The resulting list of dicts is written to
    --output_filepath.

    Raises:
        ValueError: If the number of PEF norms read does not match the number of
            examples in the decomposition.
    """
    decomp = kmeans.KmeansClusteringTorch.load(FLAGS.kmeans_filepath)

    pef_norms = _read_all_pef_norms()

    # Raise instead of assert: a mismatch is a user configuration error and the
    # check must survive `python -O`.
    if decomp.n_examples != pef_norms.shape[0]:
        raise ValueError(
            f'The decomposition has {decomp.n_examples} examples, but {pef_norms.shape[0]} '
            'PEF norms were read. Check --pef_filepaths and --n_examples_per_pef.')

    # The baseline is the mean norm over every loaded example.
    baseline_examples_mean_norm = pef_norms.mean().item()

    results = []
    for component_index in _get_component_indices(decomp.n_components):
        # Early exit if --max_non_empty_components is set.
        if FLAGS.max_non_empty_components is not None and len(results) >= FLAGS.max_non_empty_components:
            break

        top_examples_mean_norm = _compute_top_examples_mean_norm(
            decomp=decomp, pef_norms=pef_norms, component_index=component_index)
        if top_examples_mean_norm is None:
            # Too few top examples for this component; skip it.
            continue

        results.append({
            # Skipped components leave gaps, so record which component each entry is for.
            'component_index': int(component_index),
            'top_examples_mean_norm': top_examples_mean_norm,
            'baseline_examples_mean_norm': baseline_examples_mean_norm,
        })

    with open(os.path.expanduser(FLAGS.output_filepath), 'wt') as f:
        json.dump(results, f)


if __name__ == "__main__":
    # Script entry point: hand main() to absl's app runner, which handles flag parsing.
    app.run(main)
