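"""Computes the quantities for the ratio of Fisher norms of component top examples to a baseline set of examples.

For each NPEFF component, the mean PEF norm of the component's top examples and the mean PEF norm over all
loaded examples (the baseline) are written to a JSON file; the ratio is formed from these two values.
"""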
import json
import os
from typing import List, Optional, Sequence, Tuple

from absl import app
from absl import flags

import h5py
import torch

from npeff_torch.util import hdf5_utils

###############################################################################

FLAGS = flags.FLAGS


flags.DEFINE_string('output_filepath', None,
                    'Path of the JSON file to which the per-component results are written.')


flags.DEFINE_string('npeff_filepath', None, 'The LRM-NPEFF decomposition.')

flags.DEFINE_list('pef_filepaths', None,
                  'The PEF files used to compute the NPEFF coefficients, which MUST be in '
                  'the same order. These MUST have the PEF norms.')

flags.DEFINE_list('n_examples_per_pef', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


flags.DEFINE_list('component_indices', None,
                  'Leave set to None to run on all components. Does not affect semi-orthogonalization.')

flags.DEFINE_integer('n_top_examples', None,
                     'The number of top examples for each component to evaluate on.')
flags.DEFINE_integer('min_n_top_examples', 1,
                     'Examples with a coefficient of zero are excluded from the top examples. Hence, do '
                     'not run evaluations for components with fewer than this number of examples with '
                     'non-zero coefficients (their result is reported as null).')


###############################################################################


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    """Parses the raw value of the --n_examples_per_pef flag.

    Args:
        flag_value: The list of strings produced by the flag, or None / empty
            when the flag was not given.

    Returns:
        None when the flag is unset or empty. Otherwise a list with one entry
        per flag item: the parsed integer, or None for any negative value
        (meaning "use every example from that PEF").
    """
    if not flag_value:
        return None
    parsed_counts = (int(raw_count) for raw_count in flag_value)
    return [count if count >= 0 else None for count in parsed_counts]


def _get_component_indices(n_components: int) -> Sequence[int]:
    """Returns the indices of the NPEFF components to evaluate.

    Uses the integers from --component_indices when that flag is set;
    otherwise every component, i.e. ``range(n_components)``.

    Args:
        n_components: Number of components in the decomposition.

    Raises:
        ValueError: If any requested index lies outside [0, n_components).
    """
    if not FLAGS.component_indices:
        return range(n_components)

    indices = [int(ci) for ci in FLAGS.component_indices]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    out_of_range = [ci for ci in indices if not 0 <= ci < n_components]
    if out_of_range:
        raise ValueError(
            f'--component_indices contains out-of-range values {out_of_range}; '
            f'expected indices in [0, {n_components}).')
    return indices


def _read_pef_norms_from_single_file(filepath: str, n_examples: Optional[int]) -> torch.Tensor:
    """Loads the per-example PEF norms stored in a single PEF file.

    Args:
        filepath: Path to the HDF5 PEF file; a leading ``~`` is expanded.
        n_examples: When not None, only the first ``n_examples`` norms are kept.

    Returns:
        A tensor created with ``torch.from_numpy`` from the
        ``data/pef_frobenius_norms`` dataset.
    """
    full_path = os.path.expanduser(filepath)
    with h5py.File(full_path, "r") as pef_file:
        norms = hdf5_utils.load_h5_ds(pef_file['data/pef_frobenius_norms'])

    kept_norms = norms if n_examples is None else norms[:n_examples]
    return torch.from_numpy(kept_norms)


def _read_all_pef_norms() -> torch.Tensor:
    """Loads and concatenates the PEF norms of every file in --pef_filepaths.

    Files are read in flag order, so the rows of the result line up with the
    example rows of the NPEFF coefficients (the --pef_filepaths help requires
    this ordering). Each file is optionally truncated per --n_examples_per_pef.

    Returns:
        The concatenation (along dim 0) of the norms from every file.

    Raises:
        ValueError: If --pef_filepaths is unset/empty, or if
            --n_examples_per_pef is set with a different length than
            --pef_filepaths.
    """
    pef_filepaths = FLAGS.pef_filepaths
    if not pef_filepaths:
        raise ValueError('--pef_filepaths must list at least one PEF file.')

    n_examples_per_pef = _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef)
    if n_examples_per_pef is None:
        # No per-file limits: read every example from every file.
        n_examples_per_pef = [None] * len(pef_filepaths)
    elif len(n_examples_per_pef) != len(pef_filepaths):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError(
            f'--n_examples_per_pef has {len(n_examples_per_pef)} entries but '
            f'--pef_filepaths has {len(pef_filepaths)}; they must match.')

    norms = [
        _read_pef_norms_from_single_file(filepath, n_examples)
        for filepath, n_examples in zip(pef_filepaths, n_examples_per_pef)
    ]
    return torch.cat(norms, dim=0)


def _read_coeffs() -> torch.Tensor:
    """Loads the NPEFF coefficient matrix ``data/W`` from --npeff_filepath.

    Returns:
        The coefficients converted with ``torch.from_numpy``; ``main`` treats
        them as shaped [n_examples, n_components].
    """
    npeff_path = os.path.expanduser(FLAGS.npeff_filepath)
    with h5py.File(npeff_path, "r") as npeff_file:
        coeffs = hdf5_utils.load_h5_ds(npeff_file['data/W'])
    return torch.from_numpy(coeffs)


###############################################################################


def _compute_top_examples_mean_norm(
    *,
    W: torch.Tensor,
    pef_norms: torch.Tensor,
    component_index: int,
    n_top_examples: Optional[int] = None,
    min_n_top_examples: Optional[int] = None,
) -> Optional[float]:
    """Returns the mean PEF norm over a component's top examples.

    The top examples are the (at most) ``n_top_examples`` rows of ``W`` with
    the largest coefficient in column ``component_index``. Rows whose
    coefficient is exactly zero are then discarded.

    Args:
        W: Coefficient matrix indexed as ``W[example, component]``.
        pef_norms: Per-example norms, indexed by the same example axis as ``W``.
        component_index: Column of ``W`` used to rank the examples.
        n_top_examples: Number of top examples to take; defaults to
            --n_top_examples. Clamped to the number of examples, because
            ``torch.topk`` raises when k exceeds the dimension size.
        min_n_top_examples: Minimum number of non-zero top examples required;
            defaults to --min_n_top_examples.

    Returns:
        The mean norm as a Python float. Returns None when fewer than
        ``min_n_top_examples`` top examples have a non-zero coefficient, or
        when none do at all.
    """
    if n_top_examples is None:
        n_top_examples = FLAGS.n_top_examples
    if min_n_top_examples is None:
        min_n_top_examples = FLAGS.min_n_top_examples

    w = W[:, component_index]

    # torch.topk raises if k is larger than the number of available elements.
    k = min(n_top_examples, w.shape[0])
    values, example_indices = torch.topk(w, k=k)
    example_indices = example_indices[values != 0.0]

    # An empty selection would give a NaN mean, which json.dump emits as the
    # non-standard token `NaN`; report "no result" instead.
    if len(example_indices) == 0 or len(example_indices) < min_n_top_examples:
        return None

    return torch.mean(pef_norms[example_indices]).item()


###############################################################################


@torch.no_grad()
def main(_):
    """Writes per-component top-example and baseline mean PEF norms to JSON.

    The file at --output_filepath receives a JSON list with one object per
    evaluated component. Each object has these keys:
      - 'component_index': the NPEFF component the entry describes.
      - 'top_examples_mean_norm': mean PEF norm over the component's top
        examples, or null if it had too few non-zero-coefficient examples.
      - 'baseline_examples_mean_norm': mean PEF norm over all loaded examples.
        This value is the same in every entry.

    Raises:
        ValueError: If the number of coefficient rows differs from the number
            of loaded PEF norms.
    """
    # W.shape = [n_examples, n_components]
    W = _read_coeffs()
    # pef_norms.shape = [n_examples]
    pef_norms = _read_all_pef_norms()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if W.shape[0] != pef_norms.shape[0]:
        raise ValueError(
            f'The NPEFF decomposition has coefficients for {W.shape[0]} examples, but '
            f'{pef_norms.shape[0]} PEF norms were loaded. Check --pef_filepaths and '
            '--n_examples_per_pef.')

    # The baseline is every loaded example, so it is shared by all components.
    baseline_examples_mean_norm = torch.mean(pef_norms).item()

    results = []
    for component_index in _get_component_indices(W.shape[1]):
        top_examples_mean_norm = _compute_top_examples_mean_norm(
            W=W, pef_norms=pef_norms, component_index=component_index)
        results.append({
            # Recorded so entries stay identifiable when --component_indices
            # selects a subset of components.
            'component_index': int(component_index),
            'top_examples_mean_norm': top_examples_mean_norm,
            'baseline_examples_mean_norm': baseline_examples_mean_norm,
        })

    with open(os.path.expanduser(FLAGS.output_filepath), 'wt') as f:
        json.dump(results, f)


if __name__ == "__main__":
    # absl's app.run parses the command-line flags defined in this module and then calls main(argv).
    app.run(main)
