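R"""Writes one file per component containing auxiliary information about its top examples.

Each file is a comma-separated table with one row per top example: its index,
its position in the example grid, its label, and its top predicted classes
with their scores.
"""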
import csv
import io
import os
from typing import Optional, Tuple

from absl import app
from absl import flags

import numpy as np
import tensorflow as tf

from em.datasets.imagenet import imagenet_classes
from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.projects.imagenet import imagenet_components_context as imagenet_cc

ImageNetCC = imagenet_cc.ImageNetCC

FLAGS = flags.FLAGS

# Flags describing what files to read.
flags.DEFINE_string("pef_path", None, "Path to the PEF file, passed to ImageNetCC as `pef_filepath`. '~' is expanded.")
flags.DEFINE_string("nmf_path", None, "Path to the NMF file, passed to ImageNetCC as `nmf_filepath`. '~' is expanded.")

flags.DEFINE_string("ds_split", None, "Dataset split, passed to ImageNetCC as `ds_split`.")

# Layout of the top-example grid; n_rows * n_cols examples are described per component.
flags.DEFINE_integer("n_rows", None, "Number of rows in the grid of top examples.")
flags.DEFINE_integer("n_cols", None, "Number of columns in the grid of top examples.")

flags.DEFINE_integer("k_classes", 5, "Number of top predicted classes (from the logits) to list for each example.")

# Outputs
flags.DEFINE_string("output_dir", None, "Path to directory to write output to. Must already exist.")
flags.DEFINE_string("output_prefix", '', "Optional prefix to prepend to output file names.")
flags.DEFINE_string("output_file_extension", 'txt', "Extension for the output files, e.g. 'txt' or 'csv'.")


def index_to_coords(index: int, n_cols: Optional[int] = None) -> Tuple[int, int]:
    """Convert a flat grid index into zero-based (row, column) coordinates.

    The grid is filled left-to-right, then top-to-bottom. Index 0 is the
    top-left cell, and index ``n_cols`` is the first cell of the second row.

    Args:
        index: Zero-based position in the flattened grid.
        n_cols: Number of columns in the grid. Defaults to ``FLAGS.n_cols``.
            Pass it explicitly to use this helper without parsed flags.

    Returns:
        A ``(row, col)`` tuple.
    """
    if n_cols is None:
        n_cols = FLAGS.n_cols
    # divmod(a, b) == (a // b, a % b) for ints, i.e. (row, col).
    return divmod(index, n_cols)


def make_aux_info_for_component(cc: ImageNetCC, component_index: int) -> str:
    """Build a comma-separated table describing one component's top examples.

    The top ``FLAGS.n_rows * FLAGS.n_cols`` examples are laid out in a grid
    ordered left-to-right, top-to-bottom. Each data row contains:
      * the example's flat index,
      * its grid row and column,
      * its label name,
      * alternating (class name, score) pairs for the predicted classes
        returned by ``imagenet_classes.logits_to_top_classes`` with
        ``FLAGS.k_classes``.

    Args:
        cc: Components context that supplies the top examples' labels and logits.
        component_index: Index of the component whose top examples are described.

    Returns:
        The table as a string: a header line followed by one line per example.
        Every line, including the last, ends with '\\n'.
    """
    # Ordering in grid: L->R, T->B
    n_rows, n_cols = FLAGS.n_rows, FLAGS.n_cols
    n_examples = n_rows * n_cols

    labels = cc.get_top_example_labels(component_index, n_examples)
    label_names = imagenet_classes.labels_to_classes(labels)

    logits = cc.get_top_example_logits(component_index, n_examples)
    # Per example, an iterable of (name, score) pairs, as unpacked in the loop below.
    class_scores = imagenet_classes.logits_to_top_classes(logits, FLAGS.k_classes)

    # Use the csv module instead of a bare ','.join. Any field containing a
    # comma or quote (class names may) is quoted rather than silently shifting
    # every later column.
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator='\n')
    writer.writerow(['Index', 'Row', 'Column', 'Label', 'Preds...'])
    for i in range(n_examples):
        row = [i, *index_to_coords(i), label_names[i]]
        for name, prob in class_scores[i]:
            row.extend([name, prob])
        writer.writerow(row)
    return buffer.getvalue()


def get_filename(component_index: int, extension: str, prefix: Optional[str] = None) -> str:
    """Return the output file name for one component.

    Names have the form ``<prefix>comp<NNN>.<extension>``. The component index
    is zero-padded to three digits, so names sort in component order for up to
    1000 components.

    Args:
        component_index: Index of the component the file describes.
        extension: File extension, with or without a leading '.'.
        prefix: String prepended to the name. Defaults to ``FLAGS.output_prefix``.

    Returns:
        The bare file name (no directory).
    """
    if prefix is None:
        prefix = FLAGS.output_prefix
    # Accept both 'txt' and '.txt' so the name never contains a double dot.
    extension = extension.lstrip('.')
    return f'{prefix}comp{component_index:03d}.{extension}'


def write_aux_info_for_component(output_dir: str, cc: ImageNetCC, component_index: int) -> None:
    """Write the auxiliary-info table for one component into ``output_dir``.

    The file name comes from ``get_filename`` with ``FLAGS.output_file_extension``.
    An existing file with the same name is overwritten.

    Args:
        output_dir: Existing directory to write into.
        cc: Components context passed through to ``make_aux_info_for_component``.
        component_index: Index of the component to describe.
    """
    content = make_aux_info_for_component(cc, component_index)
    filepath = os.path.join(output_dir, get_filename(component_index, FLAGS.output_file_extension))
    # Explicit UTF-8 so any non-ASCII class names are encoded the same way
    # regardless of the platform's default locale encoding.
    with open(filepath, 'wt', encoding='utf-8') as f:
        f.write(content)


def main(_):
    """Write one auxiliary-info file per component of the loaded ImageNetCC.

    Raises:
        app.UsageError: If a required flag is unset, the grid dimensions are
            not positive, or ``--output_dir`` is not an existing directory.
    """
    # These flags default to None; without them the script would otherwise
    # fail later with an opaque TypeError (expanduser(None), None * None).
    required = ('pef_path', 'nmf_path', 'output_dir', 'n_rows', 'n_cols')
    missing = [name for name in required if getattr(FLAGS, name) is None]
    if missing:
        raise app.UsageError(
            'Missing required flag(s): ' + ', '.join(f'--{name}' for name in missing))
    if FLAGS.n_rows <= 0 or FLAGS.n_cols <= 0:
        raise app.UsageError('--n_rows and --n_cols must be positive.')

    output_dir = os.path.expanduser(FLAGS.output_dir)
    # Raise explicitly instead of using `assert`, which is stripped under
    # `python -O`. isdir (rather than exists) also rejects a path that names
    # a regular file.
    if not os.path.isdir(output_dir):
        raise app.UsageError(f'Please create the directory: {output_dir}')

    cc = ImageNetCC(
        pef_filepath=os.path.expanduser(FLAGS.pef_path),
        nmf_filepath=os.path.expanduser(FLAGS.nmf_path),
        ds_split=FLAGS.ds_split,
        include_images=False
    )
    for component_index in range(cc.n_components):
        write_aux_info_for_component(output_dir, cc, component_index)


if __name__ == '__main__':
    # absl parses the flags defined above, then invokes main(argv).
    app.run(main=main)
