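R"""Filters a saved set of per-example Fishers (PEFs) to enrich the concentration of a subtype of examples.

Every example of the chosen subtype is kept, and only as many of the remaining
examples are kept as are needed for the subtype to make up `--desired_fraction`
of the output. The desired fraction therefore cannot be lower than the
subtype's current fraction in the input.

Example usage:
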
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 scripts1/data_gen/enrich_pefs.py \
    --pef_path="/home/owner/Desktop/projects_data/extract_merge1/feather_berts_0.hans_lone.no_embeddings.5k.32k.h5" \
    --example_subtype=incorrect_prediction \
    --desired_fraction=0.5 \
    --output_path=/tmp/asdf.h5


"""
import os

from absl import app
from absl import flags
import numpy as np

from em.fishers import per_example

FLAGS = flags.FLAGS


# Example subtypes that `get_subtype_mask` knows how to select.
_EXAMPLE_SUBTYPES = ['incorrect_prediction']

flags.DEFINE_string(
    "pef_path", None,
    "Path to the saved PEF (h5 file) to filter; `~` is expanded.")

flags.DEFINE_enum(
    'example_subtype', None, _EXAMPLE_SUBTYPES,
    "Subtype of examples to enrich. 'incorrect_prediction' selects examples "
    "whose argmax predicted logit differs from the label.")
flags.DEFINE_float(
    "desired_fraction", None,
    "Fraction of the output examples that should belong to the subtype. Must "
    "be in (0, 1] and at least the subtype's current fraction in the input.")

flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")


def get_subtype_mask(pef, subtype=None):
    """Returns a boolean mask marking which examples of `pef` belong to a subtype.

    Args:
        pef: The loaded PEF. For the 'incorrect_prediction' subtype it must expose
            `predicted_logits` (class logits along the last axis) and `labels`.
        subtype: Name of the subtype, one of `_EXAMPLE_SUBTYPES`. Defaults to the
            value of the `--example_subtype` flag.

    Returns:
        A boolean numpy array where True marks examples of the subtype. Presumably
        one entry per example, assuming `predicted_logits` is
        (num_examples, num_classes) — TODO confirm for all saved PEFs.

    Raises:
        ValueError: If `subtype` is not a supported subtype.
    """
    if subtype is None:
        subtype = FLAGS.example_subtype

    if subtype == 'incorrect_prediction':
        # TODO: I think I have to shift around logits/labels for some model/dataset
        # pairs, so support that as needed.
        preds = np.argmax(pef.predicted_logits, axis=-1)
        return preds != pef.labels

    # The original code *returned* the exception object, so callers received a
    # ValueError instance instead of a mask. Raise it instead.
    raise ValueError(f'Unsupported example subtype: {subtype!r}')


def _enriched_indices(subtype_mask, desired_fraction):
    """Picks example indices so that the subtype makes up `desired_fraction` of them.

    Every subtype example is kept. A random subset of the complement is added so
    that the subtype accounts for (at least, up to float rounding)
    `desired_fraction` of the result.

    Args:
        subtype_mask: 1-D boolean array, True for examples of the subtype.
        desired_fraction: Target fraction of subtype examples, in (0, 1].

    Returns:
        A shuffled 1-D integer array of indices into the original examples.

    Raises:
        ValueError: If `desired_fraction` is outside (0, 1], if there are no
            subtype examples, or if the desired fraction is lower than the
            subtype's current fraction (so it cannot be reached by dropping
            complement examples).
    """
    # Guard against division by zero and against fractions > 1. With a fraction
    # > 1, the complement slice count below would go negative and silently keep
    # most of the complement.
    if not 0.0 < desired_fraction <= 1.0:
        raise ValueError(
            f'desired_fraction must be in (0, 1], got {desired_fraction}.')

    subtype_mask = np.asarray(subtype_mask, dtype=bool)
    subtype_inds, = np.nonzero(subtype_mask)
    complement_inds, = np.nonzero(~subtype_mask)

    n_total = subtype_mask.shape[0]
    n_subtype = subtype_inds.shape[0]
    if n_subtype == 0:
        raise ValueError(
            'No examples of the requested subtype exist; cannot enrich them.')

    # Flooring keeps the achieved fraction >= desired_fraction.
    new_total = int(n_subtype / desired_fraction)

    if new_total > n_total:
        # This only happens when the desired fraction is LOWER than the current
        # fraction. Dropping complement examples can only raise the fraction.
        raise ValueError(
            f'The desired fraction ({desired_fraction}) is lower than the current '
            f'subtype fraction ({n_subtype / n_total:.4f}); it cannot be reached '
            'by dropping non-subtype examples.')

    # Sample the kept complement examples at random rather than taking the first
    # ones, which would bias the output toward the start of the saved PEF.
    n_complement = new_total - n_subtype
    kept_complement = np.random.permutation(complement_inds)[:n_complement]

    all_inds = np.concatenate([subtype_inds, kept_complement], axis=0)
    np.random.shuffle(all_inds)
    return all_inds


def main(_):
    """Loads the PEF, keeps a subset that enriches the subtype, and saves it."""
    print('Loading PEF...')
    pef = per_example.PerExampleFlatFishers.load(
        os.path.expanduser(FLAGS.pef_path),
        normalize_fishers=False,
    )
    print('Finished loading PEF.')

    subtype_mask = get_subtype_mask(pef)
    all_inds = _enriched_indices(subtype_mask, FLAGS.desired_fraction)

    out = pef.create_for_subset(all_inds)

    n_subtype = int(np.count_nonzero(subtype_mask))
    print(f'Kept {all_inds.shape[0]} of {subtype_mask.shape[0]} examples; '
          f'subtype fraction is now {n_subtype / all_inds.shape[0]:.4f}.')

    print('Saving PEF...')
    # Expand `~` for consistency with how pef_path is handled above.
    out.save(os.path.expanduser(FLAGS.output_path))
    print('Finished saving PEF.')


if __name__ == "__main__":
    # Without these, a missing flag only surfaces later as an opaque error
    # (e.g. os.path.expanduser(None) raising TypeError). They are marked here,
    # rather than at import time, so importing this module does not impose the
    # requirement.
    flags.mark_flags_as_required(
        ["pef_path", "example_subtype", "desired_fraction", "output_path"])
    app.run(main)
