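R"""For each NMF component, computes the ratio of the mean PEF norm of its top-k examples to the dataset-wide mean PEF norm.

Examples are ranked per component by decreasing weight in the NMF ``W`` matrix;
results are printed as comma-separated rows (one header row, then one row per
component).
"""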
import os

from absl import app
from absl import flags
import h5py
import numpy as np
from tqdm import tqdm

from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util import hdf5_util

FLAGS = flags.FLAGS

# All three flags default to None and must be supplied on the command line.
flags.DEFINE_string(
    "pef_path", None,
    "Path to the per-example Fishers (PEF) HDF5 file; it must contain "
    "'data/pef_frobenius_norms' or 'data/dense_fisher_norms'.")
flags.DEFINE_string(
    "nmf_path", None,
    "Path to the NMF decomposition HDF5 file containing the 'data/W' dataset.")

flags.DEFINE_list(
    "top_ks", None,
    "Comma-separated positive integers k. For each k, the mean PEF norm of a "
    "component's k highest-weighted examples is divided by the dataset mean.")


# def load_fisher_norms(pef_path):
#     pef = per_example.PerExampleFlatFishers.load(
#         pef_path,
#         n_examples=None,
#         # This leads to the Fishers not being loaded, which ends up being much faster.
#         start_fisher_index=0,
#         end_fisher_index=0,
#     )
#     return pef.dense_fisher_norms

def load_fisher_norms(pef_path):
    """Reads per-example norms from a PEF HDF5 file.

    The 'data/pef_frobenius_norms' dataset is preferred; 'data/dense_fisher_norms'
    is used as a fallback when the former is absent.

    Args:
        pef_path: Path of the HDF5 file to open read-only.

    Returns:
        The result of ``hdf5_util.load_h5_ds`` on the first matching dataset
        (presumably a 1-D NumPy array -- TODO confirm against hdf5_util).

    Raises:
        ValueError: If neither dataset exists in the file.
    """
    # Checked in priority order; the first dataset present wins.
    candidate_keys = ('data/pef_frobenius_norms', 'data/dense_fisher_norms')
    with h5py.File(pef_path, "r") as h5_file:
        for key in candidate_keys:
            if key in h5_file:
                return hdf5_util.load_h5_ds(h5_file[key])
    raise ValueError('No PEF norms found in PEFs file.')


# def load_W(nmf_path):
#     nmf = nmf_common.SparseNmfDecomposition.load(nmf_path)
#     nmf.normalize_components_to_unit_norm()
#     return nmf.W


def load_W(nmf_path):
    """Reads the 'data/W' dataset from an NMF HDF5 file.

    No component normalization is applied here. NOTE(review): a positive
    per-column rescale would not change the per-column ranking done by
    compute_ratios. Presumably that is why the nmf_common-based loader
    (which normalized components) is not needed -- confirm.

    Args:
        nmf_path: Path of the HDF5 file to open read-only.

    Returns:
        The result of ``hdf5_util.load_h5_ds`` on 'data/W'.
    """
    with h5py.File(nmf_path, "r") as nmf_file:
        w_dataset = nmf_file['data/W']
        return hdf5_util.load_h5_ds(w_dataset)


def compute_ratios(W, norms, top_ks, *, progress=None):
    """Compares each component's top-k mean norm to the overall mean norm.

    For every column ``i`` of ``W``, examples (rows) are ranked by decreasing
    ``W[:, i]``. For each ``k`` in ``top_ks``, the mean norm of the ``k``
    highest-ranked examples is divided by the mean norm over all examples.

    Args:
        W: 2-D array-like of shape (n_examples, n_components).
        norms: 1-D array-like of per-example norms. Only the first
            ``W.shape[0]`` entries are used, so it may be longer than ``W``.
        top_ks: Iterable of positive ints. A ``k`` larger than ``n_examples``
            averages over all examples (ratio 1.0).
        progress: Optional callable that wraps the component iterable, e.g.
            for a progress bar. Defaults to ``tqdm``; pass ``iter`` to disable.

    Returns:
        A list of rows: the header ``['Component', *top_ks]`` followed by one
        ``[component_index, ratio_for_k1, ratio_for_k2, ...]`` row per component.

    Raises:
        ValueError: If ``W`` is not 2-D, if ``norms`` has fewer entries than
            ``W`` has rows, or if any ``k`` is smaller than 1.
    """
    if progress is None:
        progress = tqdm
    W = np.asarray(W)
    norms = np.asarray(norms)
    # Materialize so a generator is not exhausted by the header row.
    top_ks = list(top_ks)

    if W.ndim != 2:
        raise ValueError(f"W must be 2-D, got shape {W.shape}.")
    n_examples = W.shape[0]
    if len(norms) < n_examples:
        raise ValueError(
            f"norms has {len(norms)} entries but W has {n_examples} rows.")
    bad_ks = [k for k in top_ks if k < 1]
    if bad_ks:
        raise ValueError(f"top_ks entries must be >= 1, got {bad_ks}.")

    norms = norms[:n_examples]
    avg_norm = np.mean(norms)
    max_k = max(top_ks, default=0)

    rows = [['Component', *top_ks]]
    for i in progress(range(W.shape[1])):
        # Stable sort: ties in W are broken deterministically by lower index.
        order = np.argsort(-W[:, i], kind="stable")
        # Gather the largest prefix once; every k is a prefix of it.
        top_norms = norms[order[:max_k]]
        rows.append([i, *(np.mean(top_norms[:k]) / avg_norm for k in top_ks)])
    return rows


def main(_):
    """Prints the top-k PEF-norm ratios of every NMF component as CSV-style rows.

    Reads per-example norms from --pef_path and the W matrix from --nmf_path.
    It then prints each row from compute_ratios with cells joined by ", ".

    Raises:
        app.UsageError: If --pef_path, --nmf_path or --top_ks is missing or
            empty, or if --top_ks has a non-integer or non-positive entry.
    """
    # The flags default to None; report a usage error rather than letting
    # os.path.expanduser(None) / iteration over None raise a TypeError.
    for flag_name in ("pef_path", "nmf_path", "top_ks"):
        if not getattr(FLAGS, flag_name):
            raise app.UsageError(f"--{flag_name} is required.")

    try:
        top_ks = [int(k) for k in FLAGS.top_ks]
    except ValueError as err:
        raise app.UsageError(
            "--top_ks must be a comma-separated list of integers, "
            f"got {FLAGS.top_ks!r}.") from err
    # k < 1 would average an empty selection and print NaN.
    if any(k < 1 for k in top_ks):
        raise app.UsageError(f"--top_ks entries must be positive, got {top_ks}.")

    norms = load_fisher_norms(os.path.expanduser(FLAGS.pef_path))
    W = load_W(os.path.expanduser(FLAGS.nmf_path))

    for row in compute_ratios(W, norms, top_ks):
        print(', '.join(str(cell) for cell in row))


if __name__ == "__main__":
    # absl's app.run parses command-line flags into FLAGS, then calls main(argv).
    app.run(main)
