#!/usr/bin/env python3
"""
Experiment pipeline:
1) Train protagonist (PPO or ATLA)
2) Evaluate against 4 adversaries: clean, random, mad, value
Runs across multiple seeds and GPUs.

RS special case: "rs" is not in the default ATTACKS list. When it is enabled, the RS critic is
trained first (eval/train_rs.py, see --rs-train-script), and eval then receives --rs_ckpt_path
and --epsilon.
"""

import os
import sys
import json
import argparse
import subprocess
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List

# --------------------------------------------------------------------------------------
# Path setup: REPO_ROOT is the directory that contains this file (Path(__file__).parent).
# It is put at the front of sys.path so utils.py (expected at REPO_ROOT/utils.py) can be
# imported regardless of the current working directory.
# NOTE(review): this file was once described as living in project-root/pipeline/ with the
# repo root one directory up, which does not match `.parent`. The defaults below
# (train_protagonist.py, eval/eval.py, checkpoints/, results/) all hang off REPO_ROOT, i.e.
# they assume this file sits at the repo root — confirm the intended layout.
# --------------------------------------------------------------------------------------
REPO_ROOT = Path(__file__).resolve().parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Utilities available in utils.py:
from utils import _format_ckpt_dir, _arch_str, _prune_str  # noqa: F401

# Attacks evaluated for every seed, in this order. "rs" is deliberately absent, so the RS
# branch inside run_one_seed only runs if "rs" is added here.
ATTACKS = tuple("clean random mad value".split())  # RS excluded here

# Per-environment epsilon, forwarded to training (--adv_epsilon / --adv_eps) and to eval
# (--epsilon). An env missing from this table cannot be trained by this pipeline.
ENV_EPSILON = dict(
    hopper=0.075,
    walker2d=0.05,
    ant=0.15,
    halfcheetah=0.15,
)

# ----------------------------- CLI -----------------------------

def parse_args(argv=None):
    """Parse the pipeline command line.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]`` exactly as
            before; passing a list makes the parser usable from other code.

    Returns:
        argparse.Namespace. argparse maps dashed option names to underscored attributes
        (``--env-name`` -> ``args.env_name``); ``--use_lth`` / ``--use-lth`` -> ``args.use_lth``.

    Exits via ``ArgumentParser.error`` when ``--num-seeds`` < 1, or when ``--env-name`` has no
    entry in ``ENV_EPSILON``. build_train_cmd refuses such an env anyway, so this fails
    before any worker process is spawned.
    """
    p = argparse.ArgumentParser(description="Full training+evaluation pipeline")
    # Core experiment knobs
    p.add_argument("--env-name", type=str, required=True)
    p.add_argument("--total-timesteps", type=str, default="5e7")
    p.add_argument("--num-seeds", type=int, default=5)
    p.add_argument("--start-seed", type=int, default=1,
                   help="First seed value; seeds will be start_seed..start_seed+num_seeds-1")
    p.add_argument("--algo", type=str, choices=["PPO", "ATLA"], default="PPO")
    p.add_argument("--arch", type=str, choices=["mlp", "rnn"], default="mlp")

    # SA-PPO (state adversarial regularization). On by default. A store_true flag whose
    # default is already True can never be switched off, so a paired --no-use-sa flag
    # writes False to the same destination.
    p.add_argument(
        "--use-sa",
        action="store_true",
        default=True,
        help="Enable SA-PPO adversarial regularization during training (propagates to --use_sa_ppo)."
    )
    p.add_argument(
        "--no-use-sa",
        dest="use_sa",
        action="store_false",
        help="Disable SA-PPO adversarial regularization (propagates to --no-use_sa_ppo)."
    )

    # Lottery Ticket Hypothesis rewind (dashed alias added for consistency with other flags)
    p.add_argument("--use_lth", "--use-lth", action="store_true", default=False,
                   help="Enable LTH rewind (propagates to --use_lth). Can be combined with --use-sa.")

    # Pruning
    p.add_argument("--use-pruning", action="store_true", default=False)
    p.add_argument("--pruner-type", type=str, default="none")
    p.add_argument("--prune-percentage", type=float, default=0.0)

    # Training/eval script paths (defaults anchored to repo root)
    p.add_argument("--train-script", type=str, default=str(REPO_ROOT / "train_protagonist.py"))
    p.add_argument("--eval-script", type=str, default=str(REPO_ROOT / "eval" / "eval.py"))

    # RS critic training (used only when attack='rs', which is not in the default ATTACKS)
    p.add_argument("--rs-train-script", type=str, default=str(REPO_ROOT / "eval" / "train_rs.py"),
                   help="Path to RS critic training script (eval/train_rs.py)")
    p.add_argument("--rs-lr", type=float, default=3e-4,
                   help="Learning rate passed to train_rs.py")
    p.add_argument("--rs-lambda", type=float, default=100.0,
                   help="RS regularization weight passed to train_rs.py")
    p.add_argument("--rs-epsilon", type=float, default=None,
                   help="Override epsilon for RS eval; defaults to ENV_EPSILON[env_name]")

    # Parallelism / GPUs
    p.add_argument("--gpus", type=str, default="0,1,2,3",
                   help="Comma-separated GPU IDs, e.g. '1,2,3'")

    # Output roots (defaults anchored to repo root).
    # NOTE: main() replaces checkpoints_root with the root derived by _derive_ckpt_root,
    # so a value given here does not reach evaluation.
    p.add_argument("--checkpoints-root", type=str, default=str(REPO_ROOT / "checkpoints"))
    p.add_argument("--results-root", type=str, default=str(REPO_ROOT / "results"))

    # Nice-to-haves
    # NOTE(review): --wandb is parsed but build_train_cmd always passes --use_wandb.
    p.add_argument("--wandb", action="store_true", default=False)
    p.add_argument("--layer-size", type=int, default=128 * 2)
    p.add_argument("--num-envs", type=int, default=2048)
    p.add_argument("--num-steps", type=int, default=10)
    p.add_argument("--update-epochs", type=int, default=4)
    p.add_argument("--num-minibatches", type=int, default=32)
    p.add_argument("--gamma", type=float, default=0.99)
    p.add_argument("--gae-lambda", type=float, default=0.95)
    p.add_argument("--lr", type=float, default=6e-4)
    p.add_argument("--adv-lr", type=float, default=1e-4)
    # On by default; the paired --no-normalize-env flag is the only way to turn it off.
    p.add_argument("--normalize-env", action="store_true", default=True)
    p.add_argument("--no-normalize-env", dest="normalize_env", action="store_false",
                   help="Do not pass --normalize_env to the training script.")

    # Extra args passthrough (JSON dict) if you need to forward more flags
    p.add_argument("--extra-train-args", type=str, default="{}",
                   help='JSON dict of extra flags to pass to training script, e.g. \'{"anneal_lr": true}\'')

    args = p.parse_args(argv)

    # Fail fast in the parent process instead of inside a worker.
    if args.num_seeds < 1:
        p.error("--num-seeds must be >= 1")
    if args.env_name not in ENV_EPSILON:
        p.error(f"No epsilon configured for env '{args.env_name}'. "
                f"Known envs: {', '.join(sorted(ENV_EPSILON))}")
    return args

# ------------------------- Helper builders -------------------------

def _extra_train_flags(raw_json: str) -> List[str]:
    """Translate the ``--extra-train-args`` JSON object into training-script flags.

    ``{"k": true}`` -> ``--k``; ``{"k": false}`` -> ``--no-k``; any other value ``v`` ->
    ``--k str(v)``. Invalid JSON, or JSON that is not an object, is reported as a warning
    on stderr and contributes no flags (best effort; the pipeline keeps going).
    """
    try:
        extra = json.loads(raw_json)
    except json.JSONDecodeError as e:
        print(f"Warning: could not parse --extra-train-args JSON ({e}). Skipping.", file=sys.stderr)
        return []
    if not isinstance(extra, dict):
        print(f"Warning: --extra-train-args must be a JSON object, got {type(extra).__name__}. Skipping.",
              file=sys.stderr)
        return []

    flags: List[str] = []
    for key, value in extra.items():
        if isinstance(value, bool):
            flags.append(f"--{key}" if value else f"--no-{key}")
        else:
            flags.extend([f"--{key}", str(value)])
    return flags


def build_train_cmd(python_bin: str, train_script: str, seed: int, args) -> List[str]:
    """Compose the argv list for the protagonist training script.

    Args:
        python_bin: Interpreter used to launch the script.
        train_script: Path of the training script.
        seed: Seed forwarded as ``--seed``.
        args: Parsed CLI namespace from ``parse_args``.

    Returns:
        The complete command list, suitable for ``subprocess.run`` without a shell.

    Raises:
        ValueError: if ``args.env_name`` has no entry in ``ENV_EPSILON``.
    """
    use_atla = args.algo.upper() == "ATLA"
    use_rnn = args.arch.lower() == "rnn"
    # Pruning counts as enabled when requested explicitly or implied by a pruner/percentage.
    use_pruning = args.use_pruning or (args.pruner_type.lower() != "none" or args.prune_percentage > 0.0)
    eps = ENV_EPSILON.get(args.env_name)
    if eps is None:
        raise ValueError(f"No epsilon configured for env '{args.env_name}'. Please add to ENV_EPSILON or pass overrides.")

    cmd = [
        python_bin, train_script,
        "--env_name", args.env_name,
        "--total_timesteps", str(args.total_timesteps),
        "--seed", str(seed),
        # NOTE(review): always passed; args.wandb is not consulted here — confirm intended.
        "--use_wandb",
        "--use_rnn" if use_rnn else "--no-use_rnn",
        "--use_ATLA" if use_atla else "--no-use_ATLA",
        "--use_sa_ppo" if args.use_sa else "--no-use_sa_ppo",
        "--use_lth" if args.use_lth else "--no-use_lth",
        "--use_pruning" if use_pruning else "--no-use_pruning",
        "--pruner_type", str(args.pruner_type),
        "--prune_percentage", str(args.prune_percentage),
        "--layer_size", str(args.layer_size),
        "--num_envs", str(args.num_envs),
        "--num_steps", str(args.num_steps),
        "--update_epochs", str(args.update_epochs),
        "--num_minibatches", str(args.num_minibatches),
        "--gamma", str(args.gamma),
        "--gae_lambda", str(args.gae_lambda),
        "--lr", str(args.lr),
        "--adv_lr", str(args.adv_lr),
    ]
    # Append the optional flag directly rather than inserting "" and filtering afterwards.
    # Such a filter would also drop legitimately empty values (e.g. --pruner-type ""),
    # leaving a flag without its value.
    if args.normalize_env:
        cmd.append("--normalize_env")
    cmd += [
        "--save_policy",
        # The same epsilon is forwarded under both spellings.
        "--adv_epsilon", str(eps),
        "--adv_eps", str(eps),
    ]

    # Optional extra args passthrough
    cmd.extend(_extra_train_flags(args.extra_train_args))
    return cmd


def expected_ckpt_dir(args, seed: int) -> Path:
    """Return the checkpoint directory that training for *seed* is expected to create.

    The leaf name is delegated to ``utils._format_ckpt_dir``, which receives a config dict
    built from the CLI args. The result is placed under ``args.checkpoints_root``.
    """
    pruning_on = (
        args.use_pruning
        or args.pruner_type.lower() != "none"
        or args.prune_percentage > 0.0
    )
    ckpt_cfg = dict(
        ENV_NAME=args.env_name,
        SEED=seed,
        USE_PRUNING=pruning_on,
        PRUNER_TYPE=args.pruner_type,
        PRUNE_PERCENTAGE=args.prune_percentage,
        USE_RNN=args.arch.lower() == "rnn",
        USE_SA_PPO=args.use_sa,
        USE_LTH=args.use_lth,
    )
    root = Path(args.checkpoints_root)
    return root / _format_ckpt_dir(ckpt_cfg)


def build_eval_cmd(python_bin: str, eval_script: str, ckpt_dir: Path, out_dir: Path,
                   attack: str, use_rnn: bool, episodes: int = 1000, env_num: int = 2048) -> List[str]:
    """Return the argv list for a single eval run.

    Side effect: creates *out_dir* (and any missing parents) first. The eval script is told
    to write ``metrics.pkl`` and ``metrics.csv`` inside it.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    rnn_flag = "--use_rnn" if use_rnn else "--no-use_rnn"
    return [
        python_bin, eval_script,
        "--checkpoint-dir", str(ckpt_dir),
        "--num-episodes", str(episodes),
        "--env-num", str(env_num),
        "--attack", attack,
        "--out-pkl", str(out_dir / "metrics.pkl"),
        "--out-csv", str(out_dir / "metrics.csv"),
        rnn_flag,
    ]


def result_dir(args, seed: int, attack: str) -> Path:
    """Return the results directory for one (config, seed, attack) combination.

    Layout: ``<results_root>/<env>/<ALGO>/<arch>/sa=<on|off>/lth=<on|off>/
    prune=<pruner>-<pct:.2f>/seed<seed>/<attack>``. The pruner label comes from
    ``utils._prune_str``.
    """
    prune_flag = args.use_pruning or args.pruner_type.lower() != "none" or args.prune_percentage > 0.0
    pruner_label = _prune_str({"USE_PRUNING": prune_flag, "PRUNER_TYPE": args.pruner_type})

    def on_off(flag):
        return "on" if flag else "off"

    segments = [
        args.env_name,
        args.algo.upper(),
        args.arch.lower(),
        f"sa={on_off(args.use_sa)}",
        f"lth={on_off(args.use_lth)}",
        f"prune={pruner_label}-{args.prune_percentage:.2f}",
        f"seed{seed}",
        attack,
    ]
    return Path(args.results_root).joinpath(*segments)


def _derive_ckpt_root(args) -> Path:
    """Return the checkpoint root mirroring train_protagonist.py's naming.

    ``[rnn_][no_sa_]checkpoints`` under REPO_ROOT: the ``rnn_`` prefix is used for the RNN
    architecture, and ``no_sa_`` when SA-PPO is disabled.
    """
    arch_prefix = "rnn_" if args.arch.lower() == "rnn" else ""
    sa_prefix = "" if args.use_sa else "no_sa_"
    return REPO_ROOT / f"{arch_prefix}{sa_prefix}checkpoints"

# ----------------------------- Worker -----------------------------

def _train_rs_critic(python_bin: str, ckpt_dir: Path, use_rnn: bool, seed: int,
                     gpu_id: str, env: Dict[str, str], args) -> Path:
    """Train an RS critic against the victim in *ckpt_dir*; return the RS checkpoint dir to use.

    Raises:
        FileNotFoundError: if the RS train script, ``sarsa_checkpoints/``, or any directory
            inside it is missing.
        subprocess.CalledProcessError: if the RS training script exits non-zero.
    """
    rs_train_script = Path(args.rs_train_script)
    if not rs_train_script.exists():
        raise FileNotFoundError(f"RS train script not found: {rs_train_script}")

    rs_train_cmd = [
        python_bin, str(rs_train_script),
        "--env_name", str(args.env_name),
        "--victim-checkpoint-dir", str(ckpt_dir),
        "--lr", str(args.rs_lr),
        "--rs_lambda", str(args.rs_lambda),
        "--save_policy",
        "--use_wandb", "false",
        "--use_rnn" if use_rnn else "--no-use_rnn",
    ]
    print(f"[GPU {gpu_id}] RS critic train seed={seed} -> {' '.join(map(str, rs_train_cmd))}")
    subprocess.run(rs_train_cmd, env=env, check=True)

    sarsa_root = REPO_ROOT / "sarsa_checkpoints"
    if not sarsa_root.exists():
        raise FileNotFoundError(f"{sarsa_root} not found after RS training")
    candidates = [d for d in sarsa_root.iterdir() if d.is_dir()]
    if not candidates:
        raise FileNotFoundError("No RS ckpt directories in sarsa_checkpoints/")
    # NOTE(review): "newest by mtime" is racy when several seeds train RS critics at the same
    # time — another worker's directory can be newer. A path reported by train_rs.py would be safer.
    return max(candidates, key=lambda d: d.stat().st_mtime)


def run_one_seed(seed: int, gpu_id: str, args) -> Dict[str, str]:
    """Train one seed, then evaluate it once per attack in ``ATTACKS``, all on one GPU.

    Args:
        seed: Seed forwarded to the training script.
        gpu_id: GPU id exported as ``CUDA_VISIBLE_DEVICES`` to every child process.
        args: Parsed CLI namespace from ``parse_args``.

    Returns:
        Strings for logging: ``seed``, ``gpu``, ``ckpt_dir``, and ``results_root``. The last
        is this seed's results directory, which holds one sub-directory per attack.

    Raises:
        subprocess.CalledProcessError: if training, RS-critic training, or an eval exits non-zero.
        FileNotFoundError: for missing RS training artifacts (RS attack only).
        ValueError: if no epsilon is available for the eval.
    """
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # Turn off XLA's up-front GPU memory preallocation unless the caller already chose a policy.
    env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

    python_bin = sys.executable  # use same interpreter

    # 1) Train
    train_cmd = build_train_cmd(python_bin, args.train_script, seed, args)
    print(f"[GPU {gpu_id}] Training seed={seed} -> {' '.join(train_cmd)}")
    subprocess.run(train_cmd, env=env, check=True)

    # 2) Evaluate, one run per attack
    ckpt_dir = expected_ckpt_dir(args, seed)
    use_rnn = args.arch.lower() == "rnn"
    for attack in ATTACKS:
        out_dir = result_dir(args, seed, attack)
        eval_cmd = build_eval_cmd(python_bin, args.eval_script, ckpt_dir, out_dir, attack, use_rnn)

        if attack == "rs":
            # RS needs a trained critic first.
            rs_ckpt = _train_rs_critic(python_bin, ckpt_dir, use_rnn, seed, gpu_id, env, args)
            eval_cmd.extend(["--rs_ckpt_path", str(rs_ckpt)])
            # Honour --rs-epsilon. Previously it was computed here and then overwritten
            # by the ENV_EPSILON lookup below.
            eps = args.rs_epsilon if args.rs_epsilon is not None else ENV_EPSILON.get(args.env_name)
            if eps is None:
                raise ValueError(f"No epsilon for env '{args.env_name}'. Provide --rs-epsilon explicitly.")
        else:
            eps = ENV_EPSILON.get(args.env_name)
            if eps is None:
                raise ValueError(f"No epsilon configured for env '{args.env_name}'. Please add to ENV_EPSILON or pass overrides.")

        eval_cmd.extend(["--epsilon", str(eps)])
        print(f"[GPU {gpu_id}] Eval seed={seed}, attack={attack} -> {' '.join(eval_cmd)}")
        subprocess.run(eval_cmd, env=env, check=True)

    return {
        "seed": str(seed),
        "gpu": str(gpu_id),
        "ckpt_dir": str(ckpt_dir),
        # result_dir(...) ends in .../seed<seed>/<attack>, so its parent is this seed's
        # results directory. The former .parents[2] pointed at the lth=... level instead.
        "results_root": str(result_dir(args, seed, ATTACKS[0]).parent),
    }

# ----------------------------- Main -----------------------------

def _run_seed_on_free_gpu(seed: int, gpu_pool, args) -> Dict[str, str]:
    """Run ``run_one_seed`` on a GPU id borrowed from *gpu_pool*, then give the id back.

    *gpu_pool* is a shared queue holding one token per GPU id. A worker takes a token before
    its run and returns it in ``finally``. As a result, a GPU listed once in --gpus never
    hosts two seeds at the same time, and a freed GPU is reused right away.
    """
    gpu_id = gpu_pool.get()
    try:
        return run_one_seed(seed, gpu_id, args)
    finally:
        gpu_pool.put(gpu_id)


def main():
    """Entry point: parse args, then train and evaluate every seed across the given GPUs.

    At most ``len(gpus)`` seeds run concurrently, each on its own GPU. The earlier static
    round-robin assignment could start two seeds on one GPU while another GPU sat idle.
    On the first failed seed, seeds that have not started yet are cancelled and the
    process exits with status 1. Leaving the executor block still waits for seeds that
    are already running.
    """
    import multiprocessing  # local: only needed for the shared GPU pool

    args = parse_args()
    # _derive_ckpt_root mirrors where train_protagonist.py writes checkpoints (the train command
    # passes no root), so this intentionally replaces any --checkpoints-root value.
    args.checkpoints_root = str(_derive_ckpt_root(args))

    # Parse GPU list
    gpus = [g.strip() for g in args.gpus.split(",") if g.strip()]
    if not gpus:
        raise ValueError("No GPUs provided. Use --gpus like '0' or '1,3'.")

    # Prepare seeds
    seeds = list(range(args.start_seed, args.start_seed + args.num_seeds))

    print(f"Running {len(seeds)} seeds on GPUs={gpus}")
    print(f"Env={args.env_name} | Algo={args.algo} | Arch={args.arch} | "
          f"SA={'on' if args.use_sa else 'off'} | "
          f"LTH={'on' if args.use_lth else 'off'} | "
          f"Pruning={'on' if (args.use_pruning or args.pruner_type.lower() != 'none' or args.prune_percentage > 0.0) else 'off'} "
          f"({args.pruner_type}, {args.prune_percentage:.2f})")
    print(f"Checkpoints root: {args.checkpoints_root}")
    print(f"Results root    : {args.results_root}")

    # The manager must outlive the executor: running workers return their GPU tokens to it.
    with multiprocessing.Manager() as manager:
        gpu_pool = manager.Queue()
        for gpu_id in gpus:
            gpu_pool.put(gpu_id)

        with ProcessPoolExecutor(max_workers=len(gpus)) as ex:
            futures = [ex.submit(_run_seed_on_free_gpu, seed, gpu_pool, args) for seed in seeds]

            # Progress / results
            for fut in as_completed(futures):
                try:
                    info = fut.result()
                except Exception as e:  # top-level boundary: report, stop queued work, exit
                    if isinstance(e, subprocess.CalledProcessError):
                        print(f"✖ A subprocess failed (seed run). Return code {e.returncode}")
                        print(f"  Command: {getattr(e, 'cmd', None)}")
                    else:
                        print(f"✖ Unexpected error: {e}")
                    # Fail fast: drop seeds that have not started. Running seeds cannot be
                    # cancelled; the executor waits for them when the block exits.
                    for pending in futures:
                        pending.cancel()
                    sys.exit(1)
                print(f"✔ Finished seed={info['seed']} on GPU {info['gpu']}")
                print(f"  ├─ ckpt: {info['ckpt_dir']}")
                print(f"  └─ results: {info['results_root']}")


if __name__ == "__main__":
    main()
