#!/usr/bin/env python3
"""
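Benchmark the wall-clock training time of a BenchMARL algorithm
for a set of building-types configurations. Each configuration is
trained by running ``run.py`` in a subprocess, and its elapsed time is
appended to ``<algorithm>_time_complexity.csv``.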

Usage
-----
python run_time_complex.py --algorithm mappo --extra "--total-episodes 100"
"""
import argparse, subprocess, time, csv, shlex, sys
from pathlib import Path

# ---------------------------------------------------------------------
# 1) Parse CLI
# ---------------------------------------------------------------------
# Command-line interface. The module docstring doubles as the --help
# description; RawDescriptionHelpFormatter keeps its line breaks intact.
cli = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
cli.add_argument("--algorithm", default="mappo",
                 help="BenchMarl algorithm to benchmark (default: mappo)")
cli.add_argument("--task", default="customenv", help="Task YAML name")
# argparse only accepts a dash-leading value for --extra when it contains a
# space (e.g. "--total-episodes 100") or is attached with "=", so both
# forms are spelled out in the help text.
cli.add_argument("--extra", default="",
                 help='Extra flags to pass through to run.py '
                      '(wrap in quotes, e.g. --extra "--total-episodes 100"; '
                      'a single flag without a value needs the = form, '
                      'e.g. --extra=--some-flag)')
args = cli.parse_args()

# Split --extra with shell-like rules so quoted values stay together.
EXTRA_FLAGS = shlex.split(args.extra)
ALGO        = args.algorithm
TASK        = args.task
# One results file per algorithm, relative to the current working directory.
CSV_PATH    = Path(f"{ALGO}_time_complexity.csv")

# ---------------------------------------------------------------------
# 2) Building‑types grids to test
# ---------------------------------------------------------------------
# 8-element unit that the three largest configurations repeat verbatim.
_BASE_PATTERN = [5, 1, 1, 1, 5, 1, 1, 5]

# Configurations to benchmark, in increasing size (2, 4, 7, 16, 32 and
# 64 entries). Each inner list is passed to run.py as one comma-joined
# --building-types value. Multiplying a list yields a new list, so no two
# configurations share the same list object.
building_sets = [
    [5, 1],
    [5, 1, 1, 5],
    [5, 1, 1, 1, 1, 1, 5],
    _BASE_PATTERN * 2,
    _BASE_PATTERN * 4,
    _BASE_PATTERN * 8,
]

# ---------------------------------------------------------------------
# 3) Run training & time it
# ---------------------------------------------------------------------
# Start a fresh CSV with a header row, or keep appending to an existing one.
# The header goes through csv.writer so it uses the same "\r\n" line
# terminator as the data rows below (Path.write_text would have emitted a
# bare "\n" for the header only). An existing-but-empty file also gets the
# header, so data rows never appear without column names.
if not CSV_PATH.exists() or CSV_PATH.stat().st_size == 0:
    with CSV_PATH.open("w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow(["building_types", "duration_sec"])

for idx, buildings in enumerate(building_sets, 1):
    print(f"\n[{idx}/{len(building_sets)}] buildings = {buildings}")

    cmd = [
        sys.executable, "run.py",   # resolved against the current working directory
        "--algorithm", ALGO,
        "--task",      TASK,
        "--mode",      "train",
        "--building-types",  ",".join(map(str, buildings)),   # <- custom flag
        *EXTRA_FLAGS,
    ]
    # Quote each argument so the echoed command can be pasted into a shell
    # verbatim, even when a path or an --extra value contains spaces.
    print("→", " ".join(shlex.quote(part) for part in cmd))

    # Measures the whole child process, interpreter start-up included.
    # check=True stops the benchmark at the first failing run; rows for runs
    # that already finished remain in the CSV.
    start = time.perf_counter()
    subprocess.run(cmd, check=True)
    duration = time.perf_counter() - start

    print(f"   finished in {duration:,.1f} s")

    # Append after every run (not once at the end) so partial results
    # survive a later failure or interruption. Building types are joined
    # with ";" inside a single CSV field.
    with CSV_PATH.open("a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([";".join(map(str, buildings)), f"{duration:.2f}"])

print("\nAll done ✔  Results in", CSV_PATH.resolve())
