#!/usr/bin/env python3
"""Shard prompts across GPUs, run Stage-2 per shard, then merge outputs."""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import List


def _load_prompts(path: Path) -> List[object]:
    """Read the list of prompts stored as JSON at ``path``.

    Two layouts are accepted: a bare JSON list, or a JSON object carrying the
    list under its ``"prompts"`` key (a null/empty value there yields ``[]``).

    Raises:
        ValueError: if the resulting payload is not a JSON list.
    """
    decoded = json.loads(path.read_text(encoding="utf-8"))
    wrapped = isinstance(decoded, dict) and "prompts" in decoded
    prompts = (decoded.get("prompts") or []) if wrapped else decoded
    if not isinstance(prompts, list):
        raise ValueError(f"prompts_file must be JSON list (or dict with 'prompts'): {path}")
    return prompts


def _atomic_write_json(path: Path, obj: object) -> None:
    """Write ``obj`` to ``path`` as indented UTF-8 JSON, replacing it in one step.

    Missing parent directories are created first. The text goes to a sibling
    ``<name><suffix>.tmp`` file, which is then moved onto ``path`` with
    ``Path.replace`` so the destination is never left half-written.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    staging = path.with_suffix(f"{path.suffix}.tmp")
    serialized = json.dumps(obj, ensure_ascii=False, indent=2)
    staging.write_text(serialized, encoding="utf-8")
    staging.replace(path)


def _split_round_robin(items: List[object], n: int) -> List[List[object]]:
    """Deal ``items`` into ``n`` shards round-robin.

    Item ``i`` lands in shard ``i % n``. Each shard keeps the relative order
    of its items, and shard sizes differ by at most one. Exactly ``n`` lists
    are always returned; some are empty when ``len(items) < n``.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        # Fail clearly instead of a ZeroDivisionError/IndexError from ``idx % n``
        # (or a silent ``[]`` when ``items`` happens to be empty).
        raise ValueError(f"number of shards must be positive (got {n})")
    pool = list(items)
    # A stride slice starting at ``start`` picks exactly the items whose index % n == start.
    return [pool[start::n] for start in range(n)]


def _merge_jsonl(shard_outputs: List[Path], output_path: Path) -> None:
    """Concatenate per-shard JSONL files into ``output_path``, atomically.

    Shards are copied in the order given, and blank or whitespace-only lines
    are dropped. Every emitted line is newline-terminated. This means a shard
    whose final record lacks a trailing ``\\n`` cannot fuse with the first
    record of the next shard. The merge is staged in ``<output>.tmp`` and
    moved onto ``output_path`` only after every shard has been copied. On any
    failure, the staging file is removed.

    Raises:
        FileNotFoundError: if any shard output is missing. This is checked
            before anything is written.
    """
    # Check every shard up front so a missing one cannot leave a partial merge behind.
    for shard_path in shard_outputs:
        if not shard_path.exists():
            raise FileNotFoundError(f"missing shard output: {shard_path}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    tmp = output_path.with_suffix(output_path.suffix + ".tmp")
    try:
        with tmp.open("w", encoding="utf-8") as out_f:
            for shard_path in shard_outputs:
                with shard_path.open("r", encoding="utf-8") as in_f:
                    for line in in_f:
                        if not line.strip():
                            continue
                        # The last line of a file may lack "\n"; add it so records stay one per line.
                        out_f.write(line if line.endswith("\n") else line + "\n")
        tmp.replace(output_path)
    except BaseException:
        # Do not leave a half-written staging file around (also on Ctrl-C).
        try:
            tmp.unlink()
        except FileNotFoundError:
            pass
        raise


_DEFAULT_NUM_SHARDS = 8  # shard count when neither --num_shards nor any GPU list is available


def _option_given(argv: List[str], flag: str) -> bool:
    """Return True if ``flag`` occurs in ``argv`` as ``flag`` or as ``flag=value``."""
    prefix = flag + "="
    return any(tok == flag or tok.startswith(prefix) for tok in argv)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the sharding driver."""
    parser = argparse.ArgumentParser(description="Shard Stage-2 runs across multiple GPUs.")
    parser.add_argument("--runner", default="scripts/run_stage2_joint_cbf.py")
    parser.add_argument("--prompts_file", required=True)
    parser.add_argument("--output_path", required=True)
    # default=None rather than 8. This lets an explicit value (including the
    # ``--num_shards=N`` spelling) be told apart from "not given"; see main().
    parser.add_argument(
        "--num_shards",
        type=int,
        default=None,
        help=f"Number of shards (default: number of GPUs, else {_DEFAULT_NUM_SHARDS})",
    )
    parser.add_argument("--gpus", default=None, help="Comma-separated GPU ids (default: 0..num_shards-1)")
    parser.add_argument("--summary_script", default="scripts/summarize_stage2_anygate_v1.py")
    parser.add_argument("--merge_only", action="store_true")
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--detectors_config", default=None)
    parser.add_argument("--controller_config", default=None)
    parser.add_argument("--selected_dims", default=None)
    parser.add_argument("--scorer_dir", default=None)
    parser.add_argument("--base_model", default=None)
    parser.add_argument("--max_prompts", type=int, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=None)
    parser.add_argument("--continuous_steps", type=int, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--device", default=None)
    parser.add_argument("--any_gate_enabled", action="store_true")
    parser.add_argument("--gate_fpr", default=None)
    parser.add_argument("--refusal_gate_enabled", action="store_true")
    parser.add_argument("--min_tstar_resp", type=int, default=None)
    parser.add_argument("--terminal_thresholds_path", default=None)
    parser.add_argument("--kterm_topm", type=int, default=None)
    parser.add_argument("--kterm_min_prob", type=float, default=None)
    parser.add_argument("runner_args", nargs=argparse.REMAINDER)
    return parser


def _forward_runner_flags(args: argparse.Namespace) -> None:
    """Append top-level runner options to ``args.runner_args`` (in place).

    A flag the caller already placed in ``runner_args`` wins, whether written
    as ``--flag value`` or ``--flag=value``. Top-level values are appended only
    when the flag is absent. Valued options that are ``None`` and switches that
    are off are not forwarded.
    """
    valued = [
        ("--detectors_config", args.detectors_config),
        ("--controller_config", args.controller_config),
        ("--selected_dims", args.selected_dims),
        ("--scorer_dir", args.scorer_dir),
        ("--base_model", args.base_model),
        ("--max_prompts", args.max_prompts),
        ("--max_new_tokens", args.max_new_tokens),
        ("--continuous_steps", args.continuous_steps),
        ("--seed", args.seed),
        ("--device", args.device),
        ("--gate_fpr", args.gate_fpr),
        ("--min_tstar_resp", args.min_tstar_resp),
        ("--terminal_thresholds_path", args.terminal_thresholds_path),
        ("--kterm_topm", args.kterm_topm),
        ("--kterm_min_prob", args.kterm_min_prob),
    ]
    for flag, value in valued:
        if value is not None and not _option_given(args.runner_args, flag):
            args.runner_args.extend([flag, str(value)])

    switches = [
        ("--any_gate_enabled", args.any_gate_enabled),
        ("--refusal_gate_enabled", args.refusal_gate_enabled),
    ]
    for flag, enabled in switches:
        if enabled and not _option_given(args.runner_args, flag):
            args.runner_args.append(flag)


def _resolve_gpus(args: argparse.Namespace) -> List[str]:
    """Return the GPU ids to hand to shards, in shard order.

    Priority: ``--gpus``, then this process's ``CUDA_VISIBLE_DEVICES``, then
    ``0..N-1``, where N is ``--num_shards`` or the default shard count. Blank
    entries in a comma-separated list are ignored.
    """
    spec = args.gpus or os.environ.get("CUDA_VISIBLE_DEVICES")
    if spec:
        return [g.strip() for g in spec.split(",") if g.strip()]
    count = args.num_shards if args.num_shards is not None else _DEFAULT_NUM_SHARDS
    return [str(i) for i in range(count)]


def _terminate_all(procs: List[subprocess.Popen], timeout: float = 30.0) -> None:
    """Best-effort shutdown of still-running shard processes.

    Calls ``terminate()`` on each process that is still running. Any process
    that has not exited within ``timeout`` seconds is then ``kill()``-ed.
    """
    for proc in procs:
        if proc.poll() is None:
            proc.terminate()
    for proc in procs:
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()


def _run_shards(
    args: argparse.Namespace, gpus: List[str], shard_root: Path, shard_outputs: List[Path]
) -> None:
    """Write per-shard prompt files, run the runner once per non-empty shard, and wait.

    Shard ``i`` runs with ``CUDA_VISIBLE_DEVICES=gpus[i]`` and receives
    ``--prompts_file``/``--output_path`` followed by ``args.runner_args``.
    Empty shards (fewer prompts than shards) are not launched. Instead, their
    output file is written empty, so the merge finds it and any stale output
    from an earlier run is cleared. If launching or waiting is interrupted
    (e.g. Ctrl-C), the shard processes already started are terminated before
    the exception propagates.

    Raises:
        RuntimeError: if any shard exits non-zero. This is raised after all
            shards have finished.
    """
    prompts = _load_prompts(Path(args.prompts_file))
    # Apply --max_prompts to the whole run. Forwarding it alone would let every
    # shard take up to max_prompts prompts, i.e. num_shards * max_prompts in total.
    # NOTE(review): assumes the runner's --max_prompts caps how many prompts it
    # processes. The flag is still forwarded and is then a no-op per shard.
    if args.max_prompts is not None and args.max_prompts > 0:
        prompts = prompts[: args.max_prompts]
    shards = _split_round_robin(prompts, args.num_shards)

    shard_prompts = []
    for i, shard in enumerate(shards):
        shard_path = shard_root / f"prompts_shard_{i}.json"
        _atomic_write_json(shard_path, shard)
        shard_prompts.append(shard_path)

    running = {}  # shard index -> Popen
    try:
        for i, shard in enumerate(shards):
            if not shard:
                shard_outputs[i].write_text("", encoding="utf-8")
                continue
            env = os.environ.copy()
            env["CUDA_VISIBLE_DEVICES"] = gpus[i]
            cmd = [
                sys.executable,
                args.runner,
                "--prompts_file",
                str(shard_prompts[i]),
                "--output_path",
                str(shard_outputs[i]),
            ] + args.runner_args
            running[i] = subprocess.Popen(cmd, env=env)
        exit_codes = {i: proc.wait() for i, proc in running.items()}
    except BaseException:
        # Don't leave orphaned GPU jobs behind when we are interrupted or a launch fails.
        _terminate_all(list(running.values()))
        raise

    failed = {i: code for i, code in exit_codes.items() if code != 0}
    if failed:
        raise RuntimeError(f"Shard run failed (shard -> exit code): {failed}")


def main() -> None:
    """Shard a Stage-2 prompt file across GPUs, run the runner per shard, then merge.

    Steps:
    1. Parse the CLI and forward top-level runner options into the runner's
       argument list.
    2. Resolve the GPUs and the shard count.
    3. Unless ``--merge_only`` is set, split the prompts round-robin and run
       one runner subprocess per shard.
    4. Merge the per-shard JSONL outputs into ``--output_path``.
    5. Run the summary script over the merged file, if that script exists.

    Raises:
        ValueError: if ``--prompts_file``/``--output_path`` appear in
            runner_args, the shard count is not positive, or there are fewer
            GPUs than shards.
        FileExistsError: if ``--output_path`` exists and ``--overwrite`` is not set.
        RuntimeError: if any shard run exits non-zero.
        subprocess.CalledProcessError: if the summary script fails.
    """
    args = _build_arg_parser().parse_args()

    if any(_option_given(args.runner_args, flag) for flag in ("--prompts_file", "--output_path")):
        raise ValueError("Do not pass --prompts_file/--output_path in runner_args; use top-level args.")
    if args.runner_args and args.runner_args[0] == "--":
        args.runner_args = args.runner_args[1:]
    _forward_runner_flags(args)

    gpus = _resolve_gpus(args)
    if args.num_shards is None:
        # Not given on the command line: use one shard per resolved GPU.
        args.num_shards = len(gpus) if gpus else _DEFAULT_NUM_SHARDS
    if args.num_shards <= 0:
        raise ValueError(f"--num_shards must be positive (got {args.num_shards})")
    # Validate before touching the filesystem.
    if len(gpus) < args.num_shards:
        raise ValueError(f"Need {args.num_shards} GPUs, got {len(gpus)} (gpus={gpus})")

    output_path = Path(args.output_path)
    if output_path.exists() and not args.overwrite:
        raise FileExistsError(f"output_path exists: {output_path} (use --overwrite)")

    shard_root = output_path.parent / f"{output_path.stem}_shards"
    shard_root.mkdir(parents=True, exist_ok=True)
    shard_outputs = [shard_root / f"output_shard_{i}.jsonl" for i in range(args.num_shards)]

    if not args.merge_only:
        _run_shards(args, gpus, shard_root, shard_outputs)

    # Shards are concatenated in index order. Because prompts were dealt
    # round-robin, the merged file is shard-major rather than in input prompt
    # order. NOTE(review): this assumes each runner writes records in its own
    # prompt order.
    _merge_jsonl(shard_outputs, output_path)

    summary_script = Path(args.summary_script) if args.summary_script else None
    if summary_script and summary_script.exists():
        summary_out = output_path.with_suffix(output_path.suffix + ".summary.json")
        subprocess.check_call(
            [sys.executable, str(summary_script), "--input_jsonl", str(output_path), "--output_json", str(summary_out)]
        )

    print(f"merged_output={output_path}")


if __name__ == "__main__":
    # Script entry point. main() returns None, so a normal run exits with status 0;
    # uncaught exceptions propagate with their traceback as usual.
    raise SystemExit(main())
