#!/usr/bin/env python3
"""Prune Chemprop run directories so each seed keeps only one dated run with a model checkpoint."""
from __future__ import annotations

import argparse
import shutil
from pathlib import Path
from typing import Iterable, List, Optional


def has_model_checkpoint(run_dir: Path) -> bool:
    """Report whether a regular file named ``model.pt`` exists under ``run_dir``.

    The search is recursive, so a checkpoint at any depth below ``run_dir``
    counts. Entries named ``model.pt`` that are not files (for example a
    directory with that name) are ignored.
    """
    matches = run_dir.rglob("model.pt")
    return any(match.is_file() for match in matches)


def pick_run_dir(run_dirs: Iterable[Path]) -> Optional[Path]:
    """Select the run directory to keep from ``run_dirs``.

    Only directories that contain a model checkpoint are eligible. Among
    those, the greatest path under ``Path`` ordering is returned; ``None`` is
    returned when no directory qualifies.
    """
    with_checkpoint = (path for path in run_dirs if has_model_checkpoint(path))
    # ``default`` covers the case where nothing has a checkpoint.
    return max(with_checkpoint, default=None)


def prune_seed(seed_dir: Path, apply: bool) -> None:
    """Keep one checkpointed run under ``seed_dir`` and remove the other runs.

    The immediate subdirectories of ``seed_dir`` are treated as runs. If there
    is more than one, the run chosen by ``pick_run_dir`` is kept and every
    other run is removed (or merely reported in dry-run mode). When no run has
    a checkpoint, a warning is printed and nothing is removed.

    Symbolic links are never treated as run directories. ``shutil.rmtree``
    refuses to operate on a symlink, and keeping a link (for example
    ``latest``) while deleting the real directory it points to would destroy
    the very checkpoint that was meant to be preserved.

    Args:
        seed_dir: Directory whose immediate subdirectories are individual runs.
        apply: When True, delete the unwanted runs; when False, only print
            what would be deleted.
    """
    # ``is_dir()`` follows symlinks, so links must be filtered out explicitly.
    run_dirs = sorted(
        p for p in seed_dir.iterdir() if p.is_dir() and not p.is_symlink()
    )
    if len(run_dirs) <= 1:
        # Zero or one real run: there is nothing to choose between.
        return

    keep_dir = pick_run_dir(run_dirs)
    if keep_dir is None:
        print(f"[warn] No checkpoint found under {seed_dir}; nothing removed.")
        return

    for run_dir in run_dirs:
        if run_dir == keep_dir:
            continue
        if apply:
            shutil.rmtree(run_dir)
            print(f"[del] Removed {run_dir}")
        else:
            print(f"[dry-run] Would remove {run_dir}")


def prune_dataset(dataset_dir: Path, apply: bool) -> None:
    """Prune every seed directory directly under ``dataset_dir``.

    Seed directories are visited in sorted order, and ``apply`` is forwarded
    unchanged to ``prune_seed``. A dataset with no subdirectories is a no-op.
    """
    subdirs = [entry for entry in dataset_dir.iterdir() if entry.is_dir()]
    subdirs.sort()
    for seed_dir in subdirs:
        prune_seed(seed_dir, apply)


def parse_args() -> argparse.Namespace:
    """Parse the command line into a namespace with ``root`` and ``apply``.

    ``root`` is a required positional ``Path``. ``apply`` is a flag that
    defaults to False, so the script performs a dry run unless ``--apply`` is
    given.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    root_help = "Root directory that contains dataset/seed/run hierarchy."
    cli.add_argument("root", type=Path, help=root_help)

    apply_help = "Actually delete directories (otherwise dry-run)."
    cli.add_argument("--apply", action="store_true", help=apply_help)

    namespace = cli.parse_args()
    return namespace


def main() -> None:
    """Run the pruner over every dataset directory beneath the given root.

    The root from the command line is user-expanded and resolved to an
    absolute path. Its immediate subdirectories are then processed in sorted
    order, with the ``--apply`` flag passed through.
    """
    options = parse_args()
    base = options.root.expanduser().resolve()
    dataset_dirs = [child for child in base.iterdir() if child.is_dir()]
    for dataset_dir in sorted(dataset_dirs):
        prune_dataset(dataset_dir, options.apply)


# Run the pruner only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
