import os
import shutil
import sys

def cleanup_checkpoints(parent_dir):
    """Delete every ``global_step_*`` directory in *parent_dir* except the latest.

    The checkpoint to keep is named by the iteration number stored in
    ``latest_checkpointed_iteration.txt`` inside *parent_dir*: the directory
    ``global_step_<number>`` is preserved, and every other ``global_step_*``
    subdirectory is removed recursively. Entries that do not start with
    ``global_step_``, plain files, and symlinks are left untouched.

    Nothing is deleted (an error is printed to stderr instead) when the
    tracker file is missing or empty, or when the checkpoint directory it
    names does not exist. Without these checks, a truncated or stale tracker
    file would cause every checkpoint, including the newest, to be wiped.

    Args:
        parent_dir: Directory containing the tracker file and the
            ``global_step_*`` checkpoint subdirectories.
    """
    latest_file = os.path.join(parent_dir, "latest_checkpointed_iteration.txt")
    if not os.path.isfile(latest_file):
        print(f"Error: {latest_file} not found.", file=sys.stderr)
        return

    with open(latest_file, "r", encoding="utf-8") as f:
        latest_num = f.read().strip()

    # An empty tracker would make the keep-name "global_step_", which matches
    # no real checkpoint directory, so every checkpoint would be deleted.
    if not latest_num:
        print(f"Error: {latest_file} is empty; nothing removed.", file=sys.stderr)
        return

    latest_checkpoint_dir = f"global_step_{latest_num}"
    # Refuse to prune unless the checkpoint we intend to keep actually exists;
    # otherwise a stale tracker value would leave zero checkpoints behind.
    if not os.path.isdir(os.path.join(parent_dir, latest_checkpoint_dir)):
        print(
            f"Error: latest checkpoint {latest_checkpoint_dir} not found in "
            f"{parent_dir}; nothing removed.",
            file=sys.stderr,
        )
        return

    print(f"Keeping checkpoint: {latest_checkpoint_dir}")

    # Sorted so the removal log is deterministic across filesystems.
    for name in sorted(os.listdir(parent_dir)):
        if not name.startswith("global_step_") or name == latest_checkpoint_dir:
            continue
        path = os.path.join(parent_dir, name)
        if os.path.islink(path):
            # shutil.rmtree refuses symbolic links (raises OSError), which
            # would abort the cleanup midway; leave links alone.
            print(f"Skipping symlink {path}")
            continue
        if os.path.isdir(path):
            print(f"Removing {path}")
            shutil.rmtree(path)

if __name__ == "__main__":
    # The script takes exactly one positional argument: the directory that
    # holds the checkpoint tracker file and the global_step_* directories.
    cli_args = sys.argv[1:]
    if len(cli_args) == 1:
        cleanup_checkpoints(cli_args[0])
    else:
        print(f"Usage: {sys.argv[0]} <parent_directory>")
        sys.exit(1)
