R"""Creates a copy of an LRM-PEFs file containing only the examples with wrong predictions.

"""
import os

from absl import app
from absl import flags

import numpy as np

from em.fishers import lrm_pefs

from em.util.color_util import cu

FLAGS = flags.FLAGS

# Path of the LRM-PEFs file to filter. `~` is expanded before it is loaded.
flags.DEFINE_string("pef_path", None, "")
# Destination of the filtered LRM-PEFs file. `~` is expanded before saving.
flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

# Optional dataset-specific label adjustment applied before labels are compared
# with predictions. Must be left unset or be one of SPECIAL_PROCESSING_TYPES;
# any other value makes main() raise a ValueError.
flags.DEFINE_string("special_processing", None, "")

# Accepted values of --special_processing.
#   'HF_MNLI': labels are remapped as (label + 1) % 3 before the comparison.
SPECIAL_PROCESSING_TYPES = ('HF_MNLI',)

###############################################################################


def get_kept_example_indices(og_pefs: lrm_pefs.SparseLrmPefs) -> np.ndarray:
    """Return the indices of the examples that were predicted incorrectly.

    The predicted class of each example is the argmax of its logits over the
    last axis. When the `special_processing` flag is 'HF_MNLI', labels are
    first remapped as `(label + 1) % 3` so that they can be compared against
    the predictions.

    As a side effect, the fraction of correctly predicted examples is printed
    so the user can check that the right special processing was selected.

    Args:
        og_pefs: The loaded PEFs; its `labels` and `logits` attributes are
            read.

    Returns:
        A 1-D array with the indices at which label and prediction differ.
    """
    targets = og_pefs.labels
    predicted = np.argmax(og_pefs.logits, axis=-1)

    remap_labels = FLAGS.special_processing == 'HF_MNLI'
    if remap_labels:
        targets = (targets + 1) % 3

    mismatch_mask = targets != predicted
    # Single-element unpacking: np.nonzero returns one index array per
    # dimension, so this also insists that the mask is one-dimensional.
    (mismatch_inds,) = np.nonzero(mismatch_mask)

    # Sanity check for the user: an implausibly low value here suggests the
    # wrong special processing was chosen.
    num_wrong = mismatch_inds.shape[0]
    num_total = targets.shape[0]
    correct_frac = 1 - num_wrong / num_total
    print(cu.hlg(f"Fraction of correct predictions: {correct_frac}"))

    return mismatch_inds


def main(_):
    """Write a copy of the input LRM-PEFs file restricted to wrong predictions.

    Reads the file at `--pef_path`, selects the examples whose prediction
    disagrees with the label, and saves the resulting subset to
    `--output_path`. Both paths have `~` expanded.

    Raises:
        ValueError: If `--special_processing` is set to a value that is not in
            SPECIAL_PROCESSING_TYPES.
    """
    processing = FLAGS.special_processing
    processing_is_valid = processing is None or processing in SPECIAL_PROCESSING_TYPES
    if not processing_is_valid:
        raise ValueError(f'Invalid special processing flag value: {FLAGS.special_processing}')

    source_path = os.path.expanduser(FLAGS.pef_path)
    og_pefs = lrm_pefs.SparseLrmPefs.load(source_path)

    keep_inds = get_kept_example_indices(og_pefs)
    print(cu.hlg(f"Creating LRM-PEFS with {keep_inds.shape[0]} examples."))

    wrong_only_pefs = og_pefs.create_for_subset(keep_inds)

    # The output path is resolved only now, after filtering, as before.
    destination_path = os.path.expanduser(FLAGS.output_path)
    wrong_only_pefs.save(destination_path)


if __name__ == "__main__":
    # Both path flags default to None and main() hands them directly to
    # os.path.expanduser, which fails with an opaque TypeError on None (and,
    # for output_path, only after the input has been loaded and filtered).
    # Marking them as required makes absl reject the invocation up front with
    # a usage error that names the missing flag.
    flags.mark_flags_as_required(["pef_path", "output_path"])
    app.run(main)
