"""Prints a subset of top example indices for NPEFF."""
import os

from absl import app
from absl import flags

import torch

from npeff_torch.decomps.npeff import lrm_npeff_decomps


###############################################################################

# Command-line flags for this script; values are read through FLAGS after
# absl has parsed argv (i.e. inside main).
FLAGS = flags.FLAGS

flags.DEFINE_string('npeff_filepath', None, 'Path of the saved LRM-NPEFF decomposition to load.')

flags.DEFINE_list('component_indices', [], 'Comma-separated indices of the components to print top examples for. If empty, all components are used.')

flags.DEFINE_integer('n_top_examples', None, 'Number of top examples (largest W coefficients) to print for each selected component.')


###############################################################################

def _read_component_indices(decomp):
    """Return the component indices to report on as an int64 tensor.

    Uses the entries of --component_indices (each converted with int()) when
    the flag is non-empty; otherwise falls back to every index in
    range(decomp.n_components).
    """
    requested = FLAGS.component_indices
    if not requested:
        # No explicit selection given: cover all components.
        requested = range(decomp.n_components)
    return torch.tensor(list(map(int, requested)), dtype=torch.int64)


@torch.no_grad()
def main(_):
    """Print the top example indices of the selected components on one line.

    Loads the decomposition named by --npeff_filepath (W only, not G), takes
    the --n_top_examples largest entries along dim 0 of W for each component,
    keeps the components chosen by --component_indices, and prints the example
    indices comma-separated. The output is grouped by component, in the order
    the components were requested, with each component's examples in the
    order returned by torch.topk.

    Args:
        _: Positional argv passed by absl's app.run; unused.

    Raises:
        ValueError: If --n_top_examples was not provided, is negative, or
            exceeds the size of W along dim 0.
    """
    n_top_examples = FLAGS.n_top_examples
    if n_top_examples is None:
        # Fail early with a clear message instead of a TypeError from topk.
        raise ValueError('--n_top_examples must be provided.')

    decomp = lrm_npeff_decomps.LrmNpeffDecomposition.load(FLAGS.npeff_filepath, load_W=True, load_G=False)
    component_indices = _read_component_indices(decomp)

    # topk below runs along dim 0, so k must lie within that dimension's size.
    n_examples = decomp.W.shape[0]
    if not 0 <= n_top_examples <= n_examples:
        raise ValueError(
            f'--n_top_examples must be between 0 and {n_examples} '
            f'(the size of W along dim 0), got {n_top_examples}.'
        )

    # top_indices.shape = [n_top_examples, n_components]
    _, top_indices = torch.topk(decomp.W, k=n_top_examples, dim=0)

    # Keep the requested columns, then transpose so that flattening yields
    # all top examples of the first requested component, then the next, etc.
    selected = top_indices[:, component_indices].t()

    print(','.join(str(i) for i in selected.reshape(-1).tolist()))


# Script entry point: hand control to absl's app.run, which invokes main.
if __name__ == "__main__":
    app.run(main)
