"""
Script to sample the hardest/easiest neurons across the network that should be
made available for experiments. The units are written to a JSON file, which can then
be read by get_optimized_stimuli.py and extract_exemplars.py.
"""

import json
from argparse import ArgumentParser

import pandas as pd


def main(args):
    """
    Samples the n_units top-ranked units and writes them to a json file.

    The machine interpretability DataFrame is sorted by
    ``args.machine_interpretability_key`` (ascending or descending according to
    ``args.mode``) and the first ``args.n_units`` rows are selected. Each unit is
    stored as ``"<layer>__<unit>"`` in a list under the ``"units"`` key of the
    json file ``args.filename``.

    :param args: the CLI arguments; reads ``machine_interpretability_fn``,
        ``machine_interpretability_key``, ``mode``, ``n_units``, ``filename``
        and ``verbose``.
    :raises ValueError: if ``n_units`` is negative or larger than the number
        of available units.
    """
    key = args.machine_interpretability_key

    # NOTE: read_pickle can execute arbitrary code on load; only pass trusted files.
    df = pd.read_pickle(args.machine_interpretability_fn)
    df = df.sort_values(by=key, ascending=args.mode == "ascending")

    # Raise instead of ``assert`` so the checks survive ``python -O``.
    if args.n_units < 0:
        raise ValueError(f"n_units must be non-negative, got {args.n_units}.")
    if args.n_units > len(df):
        raise ValueError("Not enough units available for sampling.")

    top = df.iloc[: args.n_units]

    if args.verbose:
        # Diagnostic that was previously printed unconditionally (and then
        # aborted the script before anything was written).
        print(f"Mean {key} of sampled units: {top[key].mean()}")

    # Iterate columns rather than rows so each column keeps its own dtype.
    units = []
    for rank, (layer, unit, score) in enumerate(
        zip(top["layer"], top["unit"], top[key]), start=1
    ):
        if args.verbose:
            print(f"{rank}. {layer}:{unit} ({score})")
        units.append(f"{layer}__{unit}")

    # store the chosen units
    data = {"units": units}
    with open(args.filename, "w", encoding="utf-8") as f:
        json.dump(data, f)

    if args.verbose:
        print(f"Sampled {args.n_units} neurons.")


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the sampling.
    parser = ArgumentParser()
    parser.add_argument(
        "--n_units", type=int, required=True, help="How many units to sample"
    )
    parser.add_argument(
        "--mode",
        choices=["ascending", "descending"],
        required=True,
        help="Sort order applied to the machine interpretability key; the "
        "first n_units units after sorting are sampled.",
    )
    parser.add_argument(
        "--machine_interpretability_fn",
        type=str,
        required=True,
        help="Path to the machine interpretability DataFrame.",
    )
    parser.add_argument(
        "--filename",
        type=str,
        required=True,
        help="Filename of the json file to which results are written.",
    )
    parser.add_argument(
        "--machine_interpretability_key",
        type=str,
        required=True,
        help="Which key/column name to use for machine interpretability.",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Print more information."
    )

    arguments = parser.parse_args()

    main(arguments)
