#!/usr/bin/env python3
import subprocess
import sys
import os
import csv
from pathlib import Path
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor, as_completed

# ====== SWEEP CONFIGURATION (edit these lists) ======
# Learning rates to try; every value is paired with every RS_LAMBDAS entry.
LRS = [1e-4, 3e-4, 4e-4, 5e-4, 6e-4, 7e-4, 1e-3]
# RS lambda values to try.
RS_LAMBDAS = [0, 0.1, 1, 10, 100]
# Host GPU IDs to use; the number of entries also sets how many jobs run in parallel.
gpus = [0, 1, 2, 3]
# ====================================================

# Summary file that receives one row per sweep job (path is relative to the CWD).
RESULTS_CSV = "logging_hps.csv"


def run(cmd, env=None, **kwargs):
    """Run a command (argument list, no shell), streaming output live.

    Args:
        cmd: Sequence of program arguments; items are stringified only for
            the echoed banner and are otherwise passed through as given.
        env: Environment mapping for the child process, or None to inherit
            the current environment.
        **kwargs: Forwarded unchanged to ``subprocess.run``.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    # Flush so the banner is emitted before the child writes anything; with a
    # redirected (block-buffered) stdout it would otherwise appear after the
    # child's output, or not until the buffer fills.
    print(f"\n==> Running: {' '.join(map(str, cmd))}", flush=True)
    subprocess.run(cmd, check=True, env=env, **kwargs)


def mean_return_from_eval_csv(csv_path: Path) -> float:
    """Return the mean of the 'return' column of an eval CSV.

    Rows are skipped when their 'episode' field is 'summary' (compared
    case-insensitively with surrounding whitespace stripped), or when their
    'return' field is absent or cannot be parsed as a float.

    Args:
        csv_path: Path of a UTF-8 CSV file with a header row.

    Returns:
        Arithmetic mean of all usable 'return' values.

    Raises:
        RuntimeError: If no row yields a usable return value.
    """
    total, count = 0.0, 0
    with open(csv_path, "r", newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            ep = row.get("episode", "")
            # DictReader fills missing trailing fields with None, so the
            # value is not guaranteed to be a string.
            if isinstance(ep, str) and ep.strip().lower() == "summary":
                continue
            try:
                value = float(row["return"])
            except (KeyError, TypeError, ValueError):
                # KeyError: file has no 'return' column.
                # TypeError: short row, field filled with None.
                # ValueError: field is not numeric.
                continue
            total += value
            count += 1
    if count == 0:
        raise RuntimeError(f"No episode returns found in {csv_path}")
    return total / count


def _error_row(lr, rs_lambda, message, out_csv="", ckpt=""):
    """Build the 6-tuple result row that reports a failed sweep job."""
    return ("ERROR", lr, rs_lambda, message, out_csv, ckpt)


def _newest_subdir(root: Path):
    """Return the most recently modified direct subdirectory of root, or None."""
    subdirs = [d for d in root.iterdir() if d.is_dir()]
    if not subdirs:
        return None
    return max(subdirs, key=lambda d: d.stat().st_mtime)


def worker(job_idx, lr, rs_lambda, repo_root: Path, base_env: dict, gpu_id: int):
    """Train an RS critic and evaluate it, pinned to one GPU.

    Args:
        job_idx: Index of this job in the sweep; used only in the output tag.
        lr: Learning rate passed to train_rs.py as ``--lr``.
        rs_lambda: Value passed to train_rs.py as ``--rs_lambda``.
        repo_root: Directory containing train_rs.py and eval.py; outputs go
            to ``sweep_outputs/`` beneath it.
        base_env: Environment to copy for the child processes.
        gpu_id: Host GPU ID exported to the children via CUDA_VISIBLE_DEVICES.

    Returns:
        A 6-tuple ``(status, lr, rs_lambda, mean_or_message, eval_csv, rs_ckpt)``.
        ``status`` is "OK" (4th item is the mean return formatted to six
        decimals, or "nan" if the eval CSV could not be read) or "ERROR"
        (4th item is a human-readable message). Failures of the train/eval
        subprocesses are reported as "ERROR" rows rather than raised.
    """
    train_script = repo_root / "train_rs.py"
    eval_script = repo_root / "eval.py"
    victim_ckpt = repo_root / "../checkpoints/hopper-seed5-prune=none-0.90-mlp"

    # Fail fast on missing inputs, before spending any GPU time.
    if not train_script.exists():
        return _error_row(lr, rs_lambda, f"train_rs.py not found at {train_script}")

    if not eval_script.exists():
        return _error_row(lr, rs_lambda, f"eval.py not found at {eval_script}")

    if not victim_ckpt.exists():
        return _error_row(lr, rs_lambda, f"victim checkpoint dir not found: {victim_ckpt}")

    out_dir = repo_root / "sweep_outputs"
    out_dir.mkdir(exist_ok=True)

    # Pin to specific GPU. Note: within the process this becomes "device 0".
    env = dict(base_env)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # (Optional) keep JAX from preallocating all memory:
    # env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

    # ------------------- TRAINING -------------------
    # Banners are flushed so they precede the child's output even when stdout
    # is redirected to a file or pipe.
    print("\n" + "=" * 80, flush=True)
    print(f"[GPU {gpu_id}] SWEEP COMBO: LR={lr} | RS_LAMBDA={rs_lambda}", flush=True)
    print("=" * 80, flush=True)

    # Unique tag so checkpoint/eval files don't collide across workers
    tag = datetime.now().strftime("%Y%m%d_%H%M%S") + f"_g{gpu_id}_j{job_idx}"

    train_args = [
        sys.executable, str(train_script),
        "--victim-checkpoint-dir", str(victim_ckpt),
        "--lr", str(lr),
        "--rs_lambda", str(rs_lambda),
        "--save_policy",
        "--use_wandb", "false",
        # e.g., shorten if needed:
        # "--total_timesteps", "1e6",
        # If train_rs.py supports an explicit output dir, pass one here (recommended):
        # "--sarsa_out_dir", str(repo_root / "sarsa_checkpoints" / f"ckpt_{tag}")
    ]
    # Remember when training began so a leftover checkpoint can be flagged.
    train_started = datetime.now().timestamp()
    try:
        run(train_args, env=env)
    except subprocess.CalledProcessError as e:
        return _error_row(lr, rs_lambda, f"training failed: {e}")

    # ------------------- FIND LATEST RS CKPT -------------------
    # NOTE(review): the checkpoint is chosen as the most recently modified
    # directory under sarsa_checkpoints/, which every worker shares. With
    # several jobs running in parallel this can pick another job's checkpoint.
    # Passing a per-job output dir to train_rs.py (if it supports one) would
    # remove the ambiguity.
    sarsa_root = repo_root / "sarsa_checkpoints"
    if not sarsa_root.exists():
        return _error_row(
            lr, rs_lambda, "sarsa_checkpoints/ not found — did training save the critic?"
        )

    latest_ckpt = _newest_subdir(sarsa_root)
    if latest_ckpt is None:
        return _error_row(lr, rs_lambda, "No checkpoint directories found in sarsa_checkpoints/")

    # Warn (but keep going, as before) when the newest directory was last
    # modified before this training run began. The 2 s slack tolerates
    # filesystems with coarse mtime resolution.
    if latest_ckpt.stat().st_mtime < train_started - 2.0:
        print(
            f"[GPU {gpu_id}] [WARN] Newest checkpoint dir predates this training run; "
            f"it may be stale: {latest_ckpt}",
            flush=True,
        )
    print(
        f"\n[GPU {gpu_id}] ==> Using RS critic checkpoint directory:\n    {latest_ckpt}",
        flush=True,
    )

    # ------------------- EVALUATION -------------------
    out_csv = out_dir / f"eval_rs_lr{lr}_lam{rs_lambda}_{tag}.csv"
    out_pkl = out_dir / f"eval_rs_lr{lr}_lam{rs_lambda}_{tag}.pkl"

    eval_args = [
        sys.executable, str(eval_script),
        "--checkpoint-dir", str(victim_ckpt),
        "--attack", "rs",
        "--rs_ckpt_path", str(latest_ckpt),
        "--epsilon", "0.075",
        "--num-episodes", "500",
        "--env-num", "2048",
        "--out-csv", str(out_csv),
        "--out-pkl", str(out_pkl),
        # Optional, for harsher/less noisy eval:
        # "--rs_pgd_steps", "100",
        # "--deterministic_eval",
    ]
    try:
        run(eval_args, env=env)
    except subprocess.CalledProcessError as e:
        return _error_row(
            lr, rs_lambda, f"evaluation failed: {e}", str(out_csv), str(latest_ckpt)
        )

    # ------------------- COLLECT RESULT -------------------
    # Best effort: an unreadable eval CSV yields a NaN mean instead of an error row.
    try:
        m = mean_return_from_eval_csv(out_csv)
    except Exception as e:
        print(f"[GPU {gpu_id}] [WARN] Failed to read mean from {out_csv}: {e}", flush=True)
        m = float("nan")

    print(
        f"[GPU {gpu_id}] ==> Mean return for LR={lr}, RS_LAMBDA={rs_lambda}: {m:.3f}",
        flush=True,
    )
    return ("OK", lr, rs_lambda, f"{m:.6f}", str(out_csv), str(latest_ckpt))


def _run_on_leased_gpu(free_gpus, job_idx, lr, rs_lambda, repo_root, base_env):
    """Run one sweep job on a GPU ID leased from free_gpus; return worker's row.

    Blocks until a GPU ID is available and always hands it back afterwards,
    so two concurrently running jobs never hold the same queue entry.
    """
    gpu_id = free_gpus.get()
    try:
        return worker(job_idx, lr, rs_lambda, repo_root, base_env, gpu_id)
    finally:
        free_gpus.put(gpu_id)


def main():
    """Run the full LRS x RS_LAMBDAS sweep and append one row per job to RESULTS_CSV.

    Jobs run in a process pool with one slot per entry in ``gpus``. Each job
    leases a GPU ID from a shared queue when it starts. A fixed
    ``j % len(gpus)`` mapping would not match the pool's scheduling: a freed
    process takes the next queued job whatever GPU it was assigned, which can
    put two jobs on one GPU while another GPU sits idle.

    Result rows are written and flushed as each job finishes. Only this
    parent process writes, so there are no concurrent writes, and results
    collected so far survive a crash or interrupt.

    Raises:
        ValueError: If ``gpus`` is empty.
    """
    import multiprocessing  # local import: only main() needs the manager

    if not gpus:
        raise ValueError("gpus must list at least one GPU ID")

    repo_root = Path(__file__).resolve().parent

    # Prepare results CSV header once
    if not Path(RESULTS_CSV).exists():
        with open(RESULTS_CSV, "w", newline="", encoding="utf-8") as rf:
            csv.writer(rf).writerow(
                ["timestamp", "lr", "rs_lambda", "mean_return", "eval_csv", "rs_ckpt"]
            )

    # Full grid; the position in this list is the job index.
    jobs = [(lr, rs_lambda) for lr in LRS for rs_lambda in RS_LAMBDAS]

    base_env = dict(os.environ)

    # The manager is entered first so it is shut down last: pool workers need
    # the queue until the executor has finished.
    with multiprocessing.Manager() as manager:
        free_gpus = manager.Queue()
        for gpu_id in gpus:
            free_gpus.put(gpu_id)

        with ProcessPoolExecutor(max_workers=len(gpus)) as ex, \
                open(RESULTS_CSV, "a", newline="", encoding="utf-8") as rf:
            writer = csv.writer(rf)

            # Map each future back to its hyperparameters so a crashed job
            # can still be recorded against the right combination.
            futures = {
                ex.submit(
                    _run_on_leased_gpu, free_gpus, job_idx, lr, rs_lambda, repo_root, base_env
                ): (lr, rs_lambda)
                for job_idx, (lr, rs_lambda) in enumerate(jobs)
            }

            for fut in as_completed(futures):
                try:
                    status, lr, rs_lambda, mean_or_err, out_csv, ckpt = fut.result()
                except Exception as e:
                    # An unexpected failure in one job must not abort the sweep.
                    lr, rs_lambda = futures[fut]
                    status, mean_or_err, out_csv, ckpt = "ERROR", f"worker crashed: {e!r}", "", ""

                ts = datetime.now().isoformat()
                if status == "OK":
                    writer.writerow([ts, lr, rs_lambda, mean_or_err, out_csv, ckpt])
                else:
                    # record errors with mean_return = 'NaN' and message in eval_csv column
                    writer.writerow([ts, lr, rs_lambda, "NaN", f"[{status}] {mean_or_err}", ckpt])
                # Persist each row immediately so partial results are not lost.
                rf.flush()

    print("\n==> Sweep complete.")
    print(f"Results: {RESULTS_CSV}")
    print(f"Per-run outputs: {repo_root / 'sweep_outputs'}")


# The guard also keeps pool worker processes from re-running the sweep when
# they import this module (as happens under the "spawn" start method).
if __name__ == "__main__":
    main()
