import os
import json
import yaml
from pathlib import Path
import shutil
def main(args):
    """Delete the output folders of instances that did not end in a submission.

    Reads ``run_batch_exit_statuses.yaml`` from ``args.patches_folder``,
    collects every instance whose exit status is neither ``submitted`` nor
    ``skipped (submitted)``, and removes ``<patches_folder>/<instance_id>``
    for each of them. Progress is reported on stdout.

    Args:
        args: Object exposing a ``patches_folder`` attribute (for example the
            ``argparse.Namespace`` built by the command-line entry point).

    Raises:
        FileNotFoundError: If the exit-status file is missing.
        KeyError: If the file has no ``instances_by_exit_status`` mapping.
    """
    patches_folder = Path(args.patches_folder)

    # Exit statuses that count as finished; everything else is rerun.
    done_statuses = {"submitted", "skipped (submitted)"}

    # Load "run_batch_exit_statuses.yaml" from the patches folder.
    run_batch_exit_statuses_path = patches_folder / "run_batch_exit_statuses.yaml"
    with open(run_batch_exit_statuses_path, "r", encoding="utf-8") as f:
        run_batch_exit_statuses = yaml.safe_load(f)

    # Collect the instances whose exit status is not a finished one.
    instances_to_rerun = []
    total_instances = 0
    for exit_status, instances in run_batch_exit_statuses["instances_by_exit_status"].items():
        # A status key written with no value parses as None rather than [].
        instances = instances or []
        # Strip before comparing so stray whitespace around a status name
        # cannot cause finished instances to be selected for deletion.
        if str(exit_status).strip() not in done_statuses:
            instances_to_rerun.extend(instances)
            print(f"Exit status:{exit_status}")
        total_instances += len(instances)
    print(f"Number of instances to rerun: {len(instances_to_rerun)}")
    print(f"Total instances: {total_instances}")

    # Remove "{patches_folder}/{instance_id}" for every instance to rerun.
    for instance in instances_to_rerun:
        if not instance:
            # An empty id would make the path below resolve to the patches
            # folder itself, and rmtree would then delete everything in it.
            print(f"Skipping empty instance id: {instance!r}")
            continue
        instance_folder = patches_folder / instance
        if instance_folder.exists():
            shutil.rmtree(instance_folder)
            print(f"Removed {instance_folder}")
        else:
            print(f"Instance folder {instance_folder} does not exist")


if __name__ == "__main__":
    # Command-line entry point: the only option is the (mandatory) patches
    # folder, which is handed to main() as a parsed namespace.
    from argparse import ArgumentParser

    cli = ArgumentParser(
        description="Remove error instances from the training set",
    )
    cli.add_argument(
        "--patches_folder",
        required=True,
        help="Patches folder",
    )
    main(cli.parse_args())