#!/usr/bin/env python
import argparse
import os
import re
import subprocess
import sys
from pathlib import Path

def _fail(message):
    """Print *message* as an ``[ERROR]`` line on stderr and exit with status 1."""
    print(f"[ERROR] {message}", file=sys.stderr)
    sys.exit(1)


def _natural_key(path):
    """Return a sort key for *path* that orders embedded integers numerically.

    ``checkpoint_900.pt`` sorts before ``checkpoint_1200.pt``. A plain string
    sort would put it after, because ``"9" > "1"``. The full file name is
    appended as a tie-breaker, so names such as ``ckpt_01`` and ``ckpt_1``
    still compare deterministically.
    """
    # re.split with a capturing group alternates text (even indices) and digit
    # runs (odd indices). Every key therefore has str/int at the same positions,
    # and two keys never compare an int against a str.
    parts = re.split(r"(\d+)", path.name)
    return [int(p) if i % 2 else p for i, p in enumerate(parts)], path.name


def _find_scenario_dir(base, algo):
    """Return the ``<algo>_navigation_*`` sub-directory of *base*.

    Exits with status 1 if there is none. If several match, the
    lexicographically first is used and a warning is printed, because
    ``Path.iterdir`` order is arbitrary.
    """
    prefix = f"{algo}_navigation_"
    matches = sorted(
        d for d in base.iterdir() if d.is_dir() and d.name.startswith(prefix)
    )
    if not matches:
        _fail(f"No subfolder starting with '{prefix}' found in {base}")
    if len(matches) > 1:
        print(
            f"[WARN] Several '{prefix}*' folders in {base}; using {matches[0].name}",
            file=sys.stderr,
        )
    return matches[0]


def _find_checkpoint(scen):
    """Return the latest ``*.pt`` file in ``<scen>/checkpoints``.

    "Latest" means the highest name under a natural (numeric-aware) sort. This
    assumes checkpoint names encode training progress as a number, e.g.
    ``checkpoint_1200.pt`` (TODO confirm against the trainer's naming). Exits
    with status 1 if the directory is missing or holds no ``.pt`` file.
    """
    ckpt_dir = scen / "checkpoints"
    if not ckpt_dir.is_dir():
        _fail(f"No 'checkpoints' dir in {scen}")
    ckpts = [c for c in ckpt_dir.glob("*.pt") if c.is_file()]
    if not ckpts:
        _fail(f"No .pt files in {ckpt_dir}")
    return max(ckpts, key=_natural_key)


def main():
    """Evaluate a trained BenchMARL run by launching ``run.py`` once per seed.

    Expected layout, as checked below::

        <folder>/                  e.g. mappo_1 (algorithm = text before first "_")
            <algo>_navigation_*/
                checkpoints/*.pt

    Each of the ``--runs`` evaluations launches ``run.py`` as a child process.
    ``run.py`` is resolved relative to the current working directory. The child
    receives:

    * the selected checkpoint;
    * ``--episode-length``;
    * seed ``--seed + i``.

    Its environment additionally carries ``BM_ALGO=<algo>`` and
    ``SMARTGRID_LOG_DIR=eval/<folder name>/run_<i+1>``. That directory is
    created before the child starts.

    Exits with status 2 on invalid arguments and 1 if the run layout is
    incomplete. If an evaluation fails, exits with that child's return code
    (or 1 if the child was killed by a signal).
    """
    p = argparse.ArgumentParser(
        description="Wrapper to evaluate a trained BenchMarl run."
    )
    p.add_argument(
        "--folder",
        required=True,
        help="Name of the parent folder containing a trained run (e.g. mappo_1).",
    )
    p.add_argument(
        "--timesteps",
        type=int,
        default=1500,
        help="Episode length used *only* for evaluation (defaults to 1500).",
    )
    # Help strings are kept ASCII-only: non-ASCII dashes can raise
    # UnicodeEncodeError when --help is written to a non-UTF-8 stream.
    p.add_argument(
        "--runs",
        type=int,
        default=1,
        help="How many independent evaluation roll-outs to perform.",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Base seed - each run adds +i to it.",
    )
    args = p.parse_args()
    if args.timesteps < 1:
        p.error("--timesteps must be >= 1")
    if args.runs < 1:
        p.error("--runs must be >= 1")

    base = Path(args.folder)
    if not base.is_dir():
        _fail(f"Folder not found: {base}")

    # Normalise ".", ".." and trailing separators so the run name is never
    # empty. The run name feeds both the algorithm prefix and the eval/
    # sub-folder. abspath is used instead of resolve() so that a symlinked run
    # folder keeps its own name.
    run_name = Path(os.path.abspath(base)).name

    # 1) Derive the algorithm from the folder name, e.g. "mappo_1" -> "mappo".
    algo = run_name.split("_", 1)[0]

    # 2) and 3) Locate the scenario folder and the checkpoint inside it.
    scen = _find_scenario_dir(base, algo)
    ckpt = _find_checkpoint(scen)

    # 4-6) Run the requested number of evaluations.
    for i in range(args.runs):
        run_out = Path("eval") / run_name / f"run_{i + 1}"
        run_out.mkdir(parents=True, exist_ok=True)

        # Tell SmartGridScenario where to write the CSVs. The variables go in
        # the child's own environment instead of mutating this process's
        # os.environ.
        env = dict(os.environ, SMARTGRID_LOG_DIR=str(run_out), BM_ALGO=algo)

        cmd = [
            sys.executable, "run.py",
            "--algorithm", algo,
            "--task", "customenv",
            "--mode", "eval",
            "--checkpoint", str(ckpt),
            "--episode-length", str(args.timesteps),
            "--seed", str(args.seed + i),
        ]
        label = f"[{i + 1}/{args.runs}]"
        # Flush so this line precedes the child's output even when stdout is
        # piped or redirected to a file.
        print(label, "Running:", " ".join(cmd), flush=True)
        try:
            subprocess.run(cmd, check=True, env=env)
        except subprocess.CalledProcessError as err:
            print(
                f"[ERROR] {label} run.py exited with status {err.returncode}",
                file=sys.stderr,
            )
            # A negative return code means the child died from a signal; map
            # it to a generic failure status.
            sys.exit(err.returncode if err.returncode > 0 else 1)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import. main() returns
    # None on success, which sys.exit maps to exit status 0.
    sys.exit(main())
