#!/usr/bin/env python3
"""
eval_ckpts.py

Supports three modes:
1. Local mode: Assumes merged HF checkpoints already exist locally.
2. HuggingFace mode: Downloads checkpoints from HuggingFace, merges them, and evaluates.
3. Single-checkpoint mode: Evaluates one checkpoint passed via --single_ckpt (skips discovery).

For each checkpoint:
  1) Download from HuggingFace if needed and merge into HF format
  2) Start vLLM OpenAI server for that ckpt
  3) Run run_generator_fixer_flow.py to evaluate on pregenerated bugs
     using multi-dataset mode: --val_datasets <spec1> <spec2> ...
  4) Stop server
  5) Write per-step logs/results + a summary.json

Example (local mode):
python scripts/eval_ckpts.py \
  --parent_ckpt_dir ~/rllm/checkpoints/...-merged \
  --ckpt_subdir actor_merged \
  --outdir ./eval_outputs/pregen_eval \
  --run_generator_fixer_flow_py examples/bugs_refactor/run_generator_fixer_flow.py \
  --val_datasets "bugbench_human:test_small bigcodebench:test_small" \
  --n_tasks 256 --n_parallel 32 \
  --tp 1 --max_model_len 16384 --gpu_mem_util 0.90 \
  --fixer_attempts_val 1 \
  --temperature 0.0 --top_p 1.0 \
  --extra_flow_args "--include_failed_test_output"

Example (HuggingFace mode):
python examples/bugs_refactor/eval_ckpts.py \
  --hf_repo anonymous/selfplay_large_ckpts \
  --base_model Qwen/Qwen2.5-Coder-7B-Instruct \
  --ckpt_cache_dir /data/user/selfplay_large_ckpts \
  --outdir ./eval_outputs/selfplay_large \
  --run_generator_fixer_flow_py examples/bugs_refactor/run_generator_fixer_flow.py \
  --val_datasets "bugbench_gpt-oss-20b_sampled:test_small" \
  --n_tasks 2000 --n_parallel 32 \
  --tp 1 --max_model_len 16384 --gpu_mem_util 0.90 --fixer_attempts_val 1 --temperature 0.6 --top_p 0.95 \
  --extra_flow_args "--include_failed_test_output --eval_pregenerated_only"

python examples/bugs_refactor/eval_ckpts.py \
  --hf_repo anonymous/fixer_large_ckpts \
  --base_model Qwen/Qwen2.5-Coder-7B-Instruct \
  --ckpt_cache_dir /data/user/fixer_large_ckpts \
  --outdir ./eval_outputs/fixer_large \
  --run_generator_fixer_flow_py examples/bugs_refactor/run_generator_fixer_flow.py \
  --val_datasets "bugbench:test_small" \
  --n_tasks 2000 --n_parallel 32 \
  --tp 1 --max_model_len 16384 --gpu_mem_util 0.90 --fixer_attempts_val 1 --temperature 0.6 --top_p 0.95 \
  --extra_flow_args "--include_failed_test_output --eval_pregenerated_only"

Example (single checkpoint mode):
python examples/bugs_refactor/eval_ckpts.py \
  --single_ckpt /data/user/selfplay_large_ckpts/repo_raw/global_step_20/actor_merged \
  --outdir ./eval_outputs/single_eval \
  --run_generator_fixer_flow_py examples/bugs_refactor/run_generator_fixer_flow.py \
  --val_datasets "bugbench_gpt-oss-20b_sampled:test_small" \
  --n_tasks 2000 --n_parallel 32
"""

from __future__ import annotations

import argparse
import glob
import json
import os
import re
import shlex
import signal
import socket
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set


# -----------------------------
# Utilities
# -----------------------------

def extract_step(path: str) -> int:
    """Return the step number embedded in *path*, or 0 when there is none.

    The leftmost occurrence of ``global_step`` or ``step`` followed by an
    optional ``-``/``_`` separator and digits is used, e.g.
    ``global_step_20`` -> 20.
    """
    match = re.search(r"(global_step|step)[\-_]?(\d+)", path)
    if match is None:
        return 0
    return int(match.group(2))


def find_ckpts(parent_ckpt_dir: str, ckpt_subdir: str) -> List[str]:
    """Locate checkpoint directories under *parent_ckpt_dir*.

    ``<parent>/global_step_*/<ckpt_subdir>`` is tried first; if nothing
    matches, ``<parent>/*/<ckpt_subdir>`` is tried.

    Returns:
        Matching paths ordered by step number (ties ordered by path).

    Raises:
        FileNotFoundError: if neither pattern matches anything.
    """
    root = os.path.abspath(parent_ckpt_dir)
    patterns = (
        os.path.join(root, "global_step_*", ckpt_subdir),
        os.path.join(root, "*", ckpt_subdir),
    )

    for pattern in patterns:
        matches = glob.glob(pattern)
        if matches:
            # Same order as a lexicographic sort followed by a stable sort on step.
            return sorted(matches, key=lambda path: (extract_step(path), path))

    raise FileNotFoundError(f"No checkpoints matched patterns:\n  {patterns[0]}\n  {patterns[1]}")


def is_port_open(host: str, port: int, timeout_s: float = 0.25) -> bool:
    """Return True if a TCP connection to ``host:port`` succeeds within *timeout_s*.

    Any ``OSError`` (refused, unreachable, timed out, ...) counts as "not open".
    """
    try:
        with socket.create_connection((host, port), timeout=timeout_s):
            reachable = True
    except OSError:
        reachable = False
    return reachable


def kill_process_on_port(port: int) -> bool:
    """Best-effort kill of whatever is serving TCP *port*.

    Tries ``fuser -k`` first (Linux) and falls back to ``lsof`` + ``SIGKILL``
    (macOS/Linux). After a successful kill the function sleeps for two seconds
    so the port has time to be released.

    Args:
        port: TCP port whose owning process(es) should be killed.

    Returns:
        True if at least one process was killed, False otherwise (including
        when neither tool is available or nothing is using the port).
    """
    # Attempt 1: fuser (Linux). Any failure to launch it (missing binary,
    # permission problem, timeout) just means we move on to the fallback.
    try:
        result = subprocess.run(
            ["fuser", "-k", f"{port}/tcp"],
            capture_output=True,
            timeout=10,
        )
    except (OSError, subprocess.TimeoutExpired):
        result = None
    if result is not None and result.returncode == 0:
        print(f"  [kill] Killed process on port {port} using fuser")
        time.sleep(2)  # Give it time to clean up
        return True

    # Attempt 2: lsof + kill. Restrict to LISTEN sockets so that clients which
    # merely have a connection involving this port number are not killed.
    try:
        result = subprocess.run(
            ["lsof", "-t", f"-iTCP:{port}", "-sTCP:LISTEN"],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    if result.returncode != 0:
        return False

    own_pid = os.getpid()
    killed_any = False
    for token in result.stdout.split():
        try:
            pid = int(token)
        except ValueError:
            continue
        if pid == own_pid:
            # Never kill the evaluation driver itself.
            continue
        try:
            os.kill(pid, signal.SIGKILL)
        except OSError:
            # Already gone, or not ours to kill; try the remaining PIDs.
            continue
        print(f"  [kill] Killed PID {pid} on port {port}")
        killed_any = True

    if killed_any:
        time.sleep(2)
    return killed_any


def wait_port(host: str, port: int, timeout_s: float) -> bool:
    """Poll ``host:port`` every 0.25 s until it accepts connections.

    Returns True as soon as the port is open, or False once *timeout_s*
    seconds have elapsed without a successful connection.
    """
    started = time.time()
    while True:
        elapsed = time.time() - started
        if not elapsed < timeout_s:
            return False
        if is_port_open(host, port):
            return True
        time.sleep(0.25)


def ensure_dir(p: str) -> None:
    """Create directory *p* and any missing parents; a no-op if it already exists."""
    os.makedirs(name=p, exist_ok=True)


def kill_process_tree(p: subprocess.Popen, grace_s: float = 10.0) -> None:
    """Stop *p* and its process group, escalating SIGINT -> SIGTERM -> SIGKILL.

    Each signal is sent to the whole process group; if that fails, it is sent
    to the process itself instead. The function waits up to *grace_s* seconds
    after SIGINT and up to 5 seconds after SIGTERM, polling every 0.25 s, and
    returns as soon as the process has exited. It does not wait after SIGKILL.
    """

    def _signal_group(sig_name: str, fallback) -> None:
        # Signal lookup stays inside the try so that any failure (missing
        # signal, missing killpg, dead process) falls through to the fallback.
        try:
            os.killpg(os.getpgid(p.pid), getattr(signal, sig_name))
        except Exception:
            try:
                fallback()
            except Exception:
                pass

    def _exited_within(window_s: float) -> bool:
        started = time.time()
        while time.time() - started < window_s:
            if p.poll() is not None:
                return True
            time.sleep(0.25)
        return False

    if p.poll() is not None:
        return

    _signal_group("SIGINT", lambda: p.send_signal(signal.SIGINT))
    if _exited_within(grace_s):
        return

    _signal_group("SIGTERM", p.terminate)
    if _exited_within(5.0):
        return

    _signal_group("SIGKILL", p.kill)

def find_repo_root(start: str) -> str:
    """
    Return the nearest directory at or above *start* that contains 'examples/'.

    At most 10 directory levels (including *start* itself) are inspected.

    Raises:
        RuntimeError: if no such directory is found within that range.
    """
    candidate = os.path.abspath(start)
    levels_left = 10
    while levels_left > 0:
        levels_left -= 1
        if os.path.isdir(os.path.join(candidate, "examples")):
            return candidate
        above = os.path.dirname(candidate)
        if above == candidate:
            # Reached the filesystem root.
            break
        candidate = above

    raise RuntimeError(
        "Could not find repo root (a directory containing 'examples/'). "
        "Run from inside the repo or adjust find_repo_root()."
    )


def env_with_repo_on_pythonpath(repo_root: str) -> dict:
    """Return a copy of ``os.environ`` with the repo prepended to PYTHONPATH.

    *repo_root* is always added; ``<repo_root>/verl`` is added too when that
    directory exists. Any existing PYTHONPATH value is kept after them.
    """
    env = os.environ.copy()

    entries = [repo_root]
    verl_pkg = os.path.join(repo_root, "verl")
    if os.path.isdir(verl_pkg):
        entries.append(verl_pkg)

    existing = env.get("PYTHONPATH", "")
    if existing:
        entries.append(existing)

    env["PYTHONPATH"] = os.pathsep.join(entries)
    return env


# -----------------------------
# HuggingFace checkpoint handling
# -----------------------------

def list_hf_revisions(repo_id: str, local_cache_dir: Optional[str] = None) -> List[Dict]:
    """
    List the checkpoints (revisions) available in a HuggingFace model repository.

    Three repository layouts are tried in order; the first that yields
    checkpoints wins:
    1. Branches/tags with step numbers in their names
    2. Single repo with step subdirectories (e.g., global_step_10/, global_step_20/)
    3. Commits with step numbers in their titles

    Args:
        repo_id: HuggingFace model repository ID.
        local_cache_dir: Currently unused; kept for interface compatibility.

    Returns:
        Revision dicts sorted by step. Each has 'commit_id', 'ref_name', 'step'
        and 'source'; subdirectory revisions also carry 'subdir', and commit
        revisions carry 'created_at'. 'commit_id' is always a string usable as
        a download revision.
    """
    from huggingface_hub import HfApi

    api = HfApi()
    revisions: List[Dict] = []

    # Layout 1: branches and tags - these often have step info in their names.
    try:
        refs = api.list_repo_refs(repo_id=repo_id, repo_type="model")
        for source, attr in (("branch", "branches"), ("tag", "tags")):
            for ref in getattr(refs, attr):
                revisions.append({
                    "commit_id": ref.target_commit,
                    "ref_name": ref.name,
                    "step": extract_step(ref.name),
                    "source": source,
                })
    except Exception as e:
        print(f"  [warn] Could not list refs: {e}")

    refs_with_steps = sorted(
        (r for r in revisions if r["step"] > 0),
        key=lambda r: r["step"],
    )
    if refs_with_steps:
        # Keep the first ref seen for each step (branches were listed before tags).
        unique_by_step: Dict[int, Dict] = {}
        for r in refs_with_steps:
            unique_by_step.setdefault(r["step"], r)
        return list(unique_by_step.values())

    # Layout 2: one revision whose top-level directories are the steps.
    try:
        files = api.list_repo_files(repo_id=repo_id, repo_type="model")
        subdirs = set()
        for f in files:
            parts = f.split("/")
            if len(parts) > 1:
                step = extract_step(parts[0])
                if step > 0:
                    subdirs.add((parts[0], step))

        if subdirs:
            print(f"  [info] Found {len(subdirs)} step subdirectories in repo")
            # Prefer the commit behind the "main" branch, then any known ref.
            main_commit = None
            for rev in revisions:
                if rev.get("ref_name") == "main":
                    main_commit = rev["commit_id"]
                    break
            if not main_commit and revisions:
                main_commit = revisions[0]["commit_id"]
            if not main_commit:
                # Ref listing failed or returned nothing: fall back to the
                # branch name so callers always receive a usable revision
                # string (they slice it for log output).
                main_commit = "main"

            # Sort on (step, name) so the order is deterministic even when two
            # directory names map to the same step.
            return [
                {
                    "commit_id": main_commit,
                    "ref_name": subdir,
                    "step": step,
                    "source": "subdir",
                    "subdir": subdir,
                }
                for subdir, step in sorted(subdirs, key=lambda x: (x[1], x[0]))
            ]
    except Exception as e:
        print(f"  [warn] Could not list repo files: {e}")

    # Layout 3: fall back to commits if no refs or subdirectories carry steps.
    print("  [info] No refs with step info found, falling back to commits...")
    commits = api.list_repo_commits(repo_id=repo_id, repo_type="model")

    revisions = []
    for i, commit in enumerate(reversed(list(commits))):  # Oldest first
        title = commit.title or ""
        step = extract_step(title)
        # If no step in title, use the commit index as the step
        # (the oldest commit therefore keeps step 0).
        if step == 0 and i > 0:
            step = i
        revisions.append({
            "commit_id": commit.commit_id,
            "ref_name": title or commit.commit_id[:8],
            "step": step,
            "source": "commit",
            "created_at": str(commit.created_at) if commit.created_at else None,
        })

    revisions.sort(key=lambda x: x["step"])
    return revisions


def check_existing_checkpoint(cache_dir: str, step: int, subdir: Optional[str] = None) -> Optional[str]:
    """
    Look in *cache_dir* for an already usable checkpoint for *step*.

    A directory counts as usable when it contains ``config.json``. Locations
    are probed in this order and the first hit is returned:

    1. ``global_step_<step:06d>/actor_merged``
    2. ``global_step_<step:06d>/raw``
    3. ``global_step_<step:06d>/raw/actor``
    4. ``repo_raw/<name>`` then ``repo_raw/<name>/actor_merged`` for *subdir*
       (when given) followed by the common step directory spellings.

    Returns:
        Path of the usable checkpoint, or None if none was found.
    """

    def _is_hf_dir(path: str) -> bool:
        return os.path.exists(path) and os.path.exists(os.path.join(path, "config.json"))

    step_root = os.path.join(cache_dir, f"global_step_{step:06d}")
    raw = os.path.join(step_root, "raw")
    candidates = [
        os.path.join(step_root, "actor_merged"),  # merged output is preferred
        raw,
        os.path.join(raw, "actor"),
    ]

    # Repos downloaded once into repo_raw/ with one directory per step.
    shared_names = [subdir] if subdir else []
    shared_names += [f"global_step_{step}", f"step_{step}", f"global_step{step}", f"step{step}"]
    for name in shared_names:
        shared = os.path.join(cache_dir, "repo_raw", name)
        candidates.append(shared)
        candidates.append(os.path.join(shared, "actor_merged"))

    for candidate in candidates:
        if _is_hf_dir(candidate):
            return candidate
    return None


def download_hf_checkpoint(
    repo_id: str,
    revision: str,
    local_dir: str,
    step: int,
    subdir: Optional[str] = None,
) -> str:
    """
    Fetch one checkpoint from HuggingFace into *local_dir* and return its path.

    Without *subdir*, the revision is downloaded into
    ``<local_dir>/global_step_<step:06d>/raw``. With *subdir*, the repo holds
    several steps side by side: it is downloaded once into
    ``<local_dir>/repo_raw`` and the path of *subdir* inside it is returned.
    In both cases a non-empty target directory is treated as already
    downloaded and the download is skipped.

    Raises:
        FileNotFoundError: if *subdir* is given but missing from the download.
    """
    from huggingface_hub import snapshot_download

    if not subdir:
        # One download per revision.
        target = os.path.join(local_dir, f"global_step_{step:06d}", "raw")
        if os.path.exists(target) and os.listdir(target):
            print(f"  [skip download] Using cached checkpoint at {target}")
            return target

        ensure_dir(target)
        print(f"  [download] Downloading revision {revision[:8]}... to {target}")
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            local_dir=target,
            local_dir_use_symlinks=False,
        )
        return target

    # All steps live in one repo snapshot shared between calls.
    shared = os.path.join(local_dir, "repo_raw")
    if os.path.exists(shared) and os.listdir(shared):
        print(f"  [skip download] Using cached repo at {shared}")
    else:
        ensure_dir(shared)
        print(f"  [download] Downloading full repo to {shared}...")
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            local_dir=shared,
            local_dir_use_symlinks=False,
        )

    step_path = os.path.join(shared, subdir)
    if not os.path.exists(step_path):
        raise FileNotFoundError(f"Subdirectory {subdir} not found in {shared}")
    return step_path


def merge_fsdp_checkpoint(
    raw_ckpt_dir: str,
    merged_dir: str,
    base_model_path: str,
    trust_remote_code: bool = True,
) -> str:
    """
    Return a HuggingFace-format model directory for the checkpoint in *raw_ckpt_dir*.

    Existing HF-format directories are reused (checked in this order:
    *raw_ckpt_dir*, *merged_dir*, ``actor_merged/``). Otherwise an FSDP
    checkpoint found at the root or under ``actor/`` is merged into
    *merged_dir*. As a last resort an HF-format ``actor/`` directory is used.

    Raises:
        ValueError: if none of the known layouts is present.
    """

    def _contains(directory: str, filename: str) -> bool:
        return os.path.exists(directory) and os.path.exists(os.path.join(directory, filename))

    # --- Already usable without merging ---
    if os.path.exists(os.path.join(raw_ckpt_dir, "config.json")):
        print(f"  [skip merge] Checkpoint already in HF format at {raw_ckpt_dir}")
        return raw_ckpt_dir

    if _contains(merged_dir, "config.json"):
        print(f"  [skip merge] Using cached merged model at {merged_dir}")
        return merged_dir

    premerged = os.path.join(raw_ckpt_dir, "actor_merged")
    if _contains(premerged, "config.json"):
        print(f"  [skip merge] Found actor_merged subdirectory in HF format at {premerged}")
        return premerged

    # --- FSDP shards that need merging (marked by fsdp_config.json) ---
    if os.path.exists(os.path.join(raw_ckpt_dir, "fsdp_config.json")):
        print(f"  [merge] Merging FSDP checkpoint from {raw_ckpt_dir}...")
        ensure_dir(merged_dir)
        _merge_fsdp_to_hf(raw_ckpt_dir, merged_dir, base_model_path, trust_remote_code)
        return merged_dir

    actor = os.path.join(raw_ckpt_dir, "actor")
    if _contains(actor, "fsdp_config.json"):
        print("  [merge] Found actor subdirectory, merging FSDP checkpoint...")
        ensure_dir(merged_dir)
        _merge_fsdp_to_hf(actor, merged_dir, base_model_path, trust_remote_code)
        return merged_dir

    if _contains(actor, "config.json"):
        print(f"  [skip merge] Actor checkpoint already in HF format at {actor}")
        return actor

    # Nothing recognised: include (at most 20) directory entries to aid debugging.
    listing = os.listdir(raw_ckpt_dir) if os.path.exists(raw_ckpt_dir) else []
    raise ValueError(
        f"Unknown checkpoint format at {raw_ckpt_dir}. "
        f"Expected either fsdp_config.json or config.json (checked root, actor/, and actor_merged/ subdirectories). "
        f"Directory contains: {listing[:20]}"
    )


def _merge_fsdp_to_hf(
    local_dir: str,
    target_dir: str,
    base_model_path: str,
    trust_remote_code: bool = True,
):
    """
    Merge the FSDP checkpoint shards in *local_dir* into an HF model at *target_dir*.

    The model config is read from ``<local_dir>/huggingface`` when that path
    exists, otherwise from *base_model_path*.
    """
    # Make the repo's verl directory importable before importing the merger.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    verl_path = os.path.join(find_repo_root(script_dir), "verl")
    if verl_path not in sys.path:
        sys.path.insert(0, verl_path)

    from verl.model_merger.base_model_merger import ModelMergerConfig
    from verl.model_merger.fsdp_model_merger import FSDPModelMerger

    bundled_config = os.path.join(local_dir, "huggingface")
    config_source = bundled_config if os.path.exists(bundled_config) else base_model_path

    merger_config = ModelMergerConfig(
        operation="merge",
        backend="fsdp",
        local_dir=local_dir,
        target_dir=target_dir,
        hf_model_config_path=config_source,
        trust_remote_code=trust_remote_code,
    )
    FSDPModelMerger(merger_config).merge_and_save()

    print(f"  [merge] Saved merged model to {target_dir}")


def _revision_label(rev: Dict) -> str:
    """Return a short display name for a revision dict (used in log output only)."""
    if "ref_name" in rev:
        return rev["ref_name"]
    # commit_id can be absent or None for subdirectory-style revisions.
    return str(rev.get("commit_id") or "unknown")[:8]


def get_hf_checkpoints(
    repo_id: str,
    cache_dir: str,
    base_model_path: str,
    steps: Optional[Set[int]] = None,
    min_step: int = -1,
    max_step: int = -1,
    trust_remote_code: bool = True,
) -> List[str]:
    """
    Download and merge checkpoints from HuggingFace, return list of merged checkpoint paths.

    Args:
        repo_id: HuggingFace model repository ID.
        cache_dir: Directory where downloads and merged models are cached.
        base_model_path: Base model used as config source when merging FSDP shards.
        steps: If given, only these step numbers are processed.
        min_step: Lower step bound (inclusive); negative disables the bound.
        max_step: Upper step bound (inclusive); negative disables the bound.
        trust_remote_code: Forwarded to the checkpoint merger.

    Returns:
        Usable (HF-format) checkpoint directories, ordered by step.

    Raises:
        RuntimeError: if no revision passes the step filters.
    """
    print(f"\n=== Fetching checkpoints from HuggingFace: {repo_id} ===")

    # Migrate a full-repo download left at the old per-step location, if any.
    old_raw_dir = os.path.join(cache_dir, "global_step_000000", "raw")
    new_raw_dir = os.path.join(cache_dir, "repo_raw")
    if os.path.isdir(old_raw_dir) and not os.path.exists(new_raw_dir):
        # Only migrate when the old location really holds step subdirectories.
        if any(extract_step(entry) > 0 for entry in os.listdir(old_raw_dir)):
            print(f"  [migrate] Moving previously downloaded repo from {old_raw_dir} to {new_raw_dir}")
            import shutil
            shutil.move(old_raw_dir, new_raw_dir)

    revisions = list_hf_revisions(repo_id)
    print(f"Found {len(revisions)} revisions in repository")
    for rev in revisions:
        print(f"  - step {rev['step']}: {_revision_label(rev)} ({rev.get('source', 'unknown')})")

    # Filter revisions based on step criteria.
    # NOTE(review): revisions without step info (step 0) are kept unless one of
    # these filters excludes them - confirm that is the intended behaviour.
    filtered_revisions = [
        rev
        for rev in revisions
        if (steps is None or rev["step"] in steps)
        and (min_step < 0 or rev["step"] >= min_step)
        and (max_step < 0 or rev["step"] <= max_step)
    ]

    if not filtered_revisions:
        raise RuntimeError(
            f"No revisions matched filters. "
            f"Available steps: {[r['step'] for r in revisions]}"
        )

    print(f"Will process {len(filtered_revisions)} revisions: {[r['step'] for r in filtered_revisions]}")

    checkpoint_paths = []
    for rev in filtered_revisions:
        step = rev["step"]
        commit_id = rev["commit_id"]
        subdir = rev.get("subdir")  # For repos with step subdirectories

        print(f"\n--- Processing step {step} ({_revision_label(rev)}) ---")

        # Reuse an already downloaded/merged checkpoint when one exists.
        existing_ckpt = check_existing_checkpoint(cache_dir, step, subdir=subdir)
        if existing_ckpt:
            print(f"  [cached] Using existing checkpoint at {existing_ckpt}")
            checkpoint_paths.append(existing_ckpt)
            continue

        # Download and merge
        step_dir = os.path.join(cache_dir, f"global_step_{step:06d}")
        raw_dir = download_hf_checkpoint(repo_id, commit_id, cache_dir, step, subdir=subdir)
        merged_dir = os.path.join(step_dir, "actor_merged")

        usable_path = merge_fsdp_checkpoint(
            raw_ckpt_dir=raw_dir,
            merged_dir=merged_dir,
            base_model_path=base_model_path,
            trust_remote_code=trust_remote_code,
        )

        checkpoint_paths.append(usable_path)

    return sorted(checkpoint_paths, key=extract_step)


# -----------------------------
# vLLM server management
# -----------------------------

@dataclass
class VLLMServerConfig:
    """Launch settings for the vLLM OpenAI-compatible API server.

    ``start_vllm_server`` turns each field into a command-line flag.
    """

    host: str = "127.0.0.1"  # passed as --host
    port: int = 30000  # passed as --port
    tp: int = 1  # passed as --tensor-parallel-size
    max_model_len: int = 16384  # passed as --max-model-len; needs to be > max_prompt_length + max_response_length
    gpu_mem_util: float = 0.90  # passed as --gpu-memory-utilization
    trust_remote_code: bool = True  # adds --trust-remote-code when true
    dtype: str = "auto"  # passed as --dtype
    enforce_eager: bool = False  # adds --enforce-eager when true
    extra_args: str = ""  # raw string, shlex-split and appended to the vLLM server cmd


def start_vllm_server(model_path: str, cfg: "VLLMServerConfig", log_path: str) -> subprocess.Popen:
    """Launch a vLLM OpenAI-compatible API server for *model_path*.

    The server runs in its own session (and therefore its own process group)
    so the whole tree can be signalled later. Its stdout and stderr are
    written to *log_path*, which is truncated first.

    Args:
        model_path: Model directory or name passed to ``--model``.
        cfg: Server settings (host, port, parallelism, limits, extra args).
        log_path: File receiving the server output; its parent directory is
            created if needed.

    Returns:
        The ``subprocess.Popen`` handle of the started server.
    """
    cmd = [
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", model_path,
        "--host", cfg.host,
        "--port", str(cfg.port),
        "--tensor-parallel-size", str(cfg.tp),
        "--max-model-len", str(cfg.max_model_len),
        "--gpu-memory-utilization", str(cfg.gpu_mem_util),
        "--dtype", cfg.dtype,
    ]
    if cfg.trust_remote_code:
        cmd.append("--trust-remote-code")
    if cfg.enforce_eager:
        cmd.append("--enforce-eager")
    if cfg.extra_args.strip():
        cmd.extend(shlex.split(cfg.extra_args.strip()))

    # A bare filename has an empty dirname; there is nothing to create then.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        ensure_dir(log_dir)

    # The child gets its own duplicate of the log descriptor, so the parent's
    # handle can be closed as soon as the process has been spawned.
    with open(log_path, "w") as log_file:
        return subprocess.Popen(
            cmd,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            env=os.environ.copy(),
            # Equivalent to preexec_fn=os.setsid, without preexec_fn's
            # thread-safety caveats.
            start_new_session=True,
        )


# -----------------------------
# run_generator_fixer_flow runner
# -----------------------------

def parse_val_datasets(val_datasets_str: str) -> List[str]:
    """
    Split a dataset spec string into individual specs.

    Example:
      "bugbench_human:test_small bigcodebench:test_small"
    becomes:
      ["bugbench_human:test_small", "bigcodebench:test_small"]

    Splitting follows shell rules, so quoted items may contain spaces.
    An empty, blank or None input yields an empty list.
    """
    if not val_datasets_str:
        return []
    stripped = val_datasets_str.strip()
    return shlex.split(stripped) if stripped else []


def build_flow_cmd(
    *,
    flow_py: str,
    model_path: str,
    base_url: str,
    out_dir: str,
    val_datasets: List[str],
    dataset: Optional[str],
    split: str,
    n_tasks: int,
    n_parallel: int,
    n_repeats: int,
    temperature: float,
    top_p: float,
    fixer_attempts_val: int,
    include_failed_test_output: bool,
    extra_flow_args: str,
    save_results: bool,
    print_samples: int,
    eval_pregenerated_only: bool,
) -> List[str]:
    """Assemble the argv used to run the generator/fixer flow script.

    Dataset selection uses ``--val_datasets`` when *val_datasets* is non-empty,
    otherwise ``--dataset``/``--split``. *extra_flow_args* is shell-split and
    appended last.

    Raises:
        ValueError: if neither *val_datasets* nor *dataset* is provided.
    """
    if val_datasets:
        dataset_args = ["--val_datasets", *val_datasets]  # nargs="+"
    elif dataset:
        dataset_args = ["--dataset", dataset, "--split", split]
    else:
        raise ValueError("Either provide --val_datasets or provide --dataset.")

    argv: List[str] = [sys.executable, flow_py]
    argv += ["--n_tasks", str(n_tasks)]
    argv += ["--n_repeats", str(n_repeats)]
    argv += ["--n_parallel", str(n_parallel)]
    argv += ["--model", model_path]
    argv += ["--base_url", base_url]
    argv += ["--temperature", str(temperature)]
    argv += ["--top_p", str(top_p)]
    argv += ["--fixer_attempts_val", str(fixer_attempts_val)]
    argv += ["--output_dir", out_dir]
    argv += ["--print_samples", str(print_samples)]

    if eval_pregenerated_only:
        argv.append("--eval_pregenerated_only")

    argv += dataset_args

    if save_results:
        argv.append("--save_results")

    if include_failed_test_output:
        argv.append("--include_failed_test_output")
    else:
        argv.append("--no_failed_test_output")

    extra = extra_flow_args.strip()
    if extra:
        argv += shlex.split(extra)

    return argv


# -----------------------------
# Main driver
# -----------------------------

def main() -> None:
    """
    Command-line entry point: evaluate one or more checkpoints end to end.

    Flow:
      1) Parse arguments and resolve the checkpoint list from one source:
         --single_ckpt (used as-is), --hf_repo (via get_hf_checkpoints), or
         --parent_ckpt_dir (via find_ckpts, with the step filters applied here).
      2) For each checkpoint: start a vLLM server, run the flow script against
         it with stdout/stderr captured to flow_stdout.log, stop the server,
         and write meta.json under <outdir>/step_<step>.
      3) Write <outdir>/summary.json with one row per evaluated checkpoint.

    The vLLM server is always stopped in a ``finally`` block, so a failure
    while waiting for the port or while running the flow script (including
    Ctrl-C) does not leave a server process holding the port.

    Raises:
        RuntimeError: no checkpoints remain after filtering, the server port
            is already in use, or the server does not open its port in time.
        ValueError: neither --val_datasets nor --dataset was provided.
    """
    ap = argparse.ArgumentParser()

    # Checkpoint source (choose one)
    source_group = ap.add_mutually_exclusive_group()
    source_group.add_argument(
        "--parent_ckpt_dir",
        type=str,
        default=None,
        help="Local directory containing checkpoints (local mode)",
    )
    source_group.add_argument(
        "--hf_repo",
        type=str,
        default=None,
        help="HuggingFace repository ID (e.g., anonymous/selfplay_large_ckpts)",
    )
    source_group.add_argument(
        "--single_ckpt",
        type=str,
        default=None,
        help="Path to a single checkpoint to evaluate (skips discovery)",
    )

    # HuggingFace mode options
    ap.add_argument(
        "--base_model",
        type=str,
        default="Qwen/Qwen2.5-Coder-7B-Instruct",
        help="Base model path for merging FSDP checkpoints (HF mode)",
    )
    ap.add_argument(
        "--ckpt_cache_dir",
        type=str,
        default="/data/user/hf_ckpts",
        help="Directory to cache downloaded and merged checkpoints",
    )

    ap.add_argument("--ckpt_subdir", type=str, default="actor_merged")

    ap.add_argument("--outdir", type=str, required=True)
    ap.add_argument("--run_generator_fixer_flow_py", type=str, required=True)

    # vLLM server
    ap.add_argument("--host", type=str, default="127.0.0.1")
    ap.add_argument("--port", type=int, default=31000)
    ap.add_argument("--port_wait_timeout_s", type=float, default=600.0)
    ap.add_argument("--tp", type=int, default=1)
    ap.add_argument("--max_model_len", type=int, default=16384)
    ap.add_argument("--gpu_mem_util", type=float, default=0.90)
    ap.add_argument("--dtype", type=str, default="auto")
    ap.add_argument("--enforce_eager", action="store_true")
    ap.add_argument("--vllm_extra_args", type=str, default="")

    # Flow eval settings
    ap.add_argument(
        "--val_datasets",
        type=str,
        default="",
        help='Space-separated list, e.g. "bugbench_human:test_small bigcodebench:test_small". '
             "If provided, uses multi-dataset mode in run_generator_fixer_flow.py.",
    )
    ap.add_argument("--dataset", type=str, default=None, help="Single-dataset fallback if --val_datasets is empty.")
    ap.add_argument("--split", type=str, default="test_small")

    ap.add_argument("--n_tasks", type=int, default=256)
    ap.add_argument("--n_parallel", type=int, default=32)
    ap.add_argument("--n_repeats", type=int, default=1)
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--top_p", type=float, default=1.0)
    ap.add_argument("--fixer_attempts_val", type=int, default=1)
    # Paired on/off flags sharing one dest; the "on" value is the default.
    ap.add_argument("--include_failed_test_output", action="store_true", default=True)
    ap.add_argument("--no_failed_test_output", action="store_false", dest="include_failed_test_output")
    ap.add_argument("--extra_flow_args", type=str, default="")
    ap.add_argument("--save_results", action="store_true", default=True)
    ap.add_argument("--no_save_results", action="store_false", dest="save_results")
    ap.add_argument("--print_samples", type=int, default=0)

    ap.add_argument("--eval_pregenerated_only", action="store_true", default=True)
    ap.add_argument("--no_eval_pregenerated_only", action="store_false", dest="eval_pregenerated_only")

    # Filtering / selection (negative min/max means "no bound")
    ap.add_argument("--steps", type=str, default="")
    ap.add_argument("--min_step", type=int, default=-1)
    ap.add_argument("--max_step", type=int, default=-1)

    # Skip options
    ap.add_argument("--download_only", action="store_true", help="Only download and merge, skip evaluation")
    # NOTE(review): --skip_merge is parsed but never read in main(); confirm
    # whether it should be forwarded to get_hf_checkpoints.
    ap.add_argument("--skip_merge", action="store_true", help="Skip merging, assume merged checkpoints exist")
    ap.add_argument("--kill_existing", action="store_true", help="Kill any existing process using the port before starting")

    args = ap.parse_args()

    # Validate args
    if not args.parent_ckpt_dir and not args.hf_repo and not args.single_ckpt:
        ap.error("One of --parent_ckpt_dir, --hf_repo, or --single_ckpt is required")

    ensure_dir(args.outdir)

    repo_root = find_repo_root(os.path.dirname(os.path.abspath(__file__)))
    subproc_env = env_with_repo_on_pythonpath(repo_root)

    # Parse step filters
    wanted: Optional[Set[int]] = None
    if args.steps.strip():
        wanted = {int(x.strip()) for x in args.steps.split(",") if x.strip()}

    # Get checkpoint paths
    if args.single_ckpt:
        # Single checkpoint mode: just use the provided path (no step filtering)
        if not os.path.exists(args.single_ckpt):
            ap.error(f"Checkpoint path does not exist: {args.single_ckpt}")
        ckpts = [args.single_ckpt]
        print(f"\n=== Single checkpoint mode: {args.single_ckpt} ===")
    elif args.hf_repo:
        # HuggingFace mode: download and merge checkpoints; the step filters
        # are handed to get_hf_checkpoints rather than applied here.
        ckpts = get_hf_checkpoints(
            repo_id=args.hf_repo,
            cache_dir=args.ckpt_cache_dir,
            base_model_path=args.base_model,
            steps=wanted,
            min_step=args.min_step,
            max_step=args.max_step,
            trust_remote_code=True,
        )

        if args.download_only:
            print(f"\n=== Download complete. Merged checkpoints at: {args.ckpt_cache_dir} ===")
            return
    else:
        # Local mode: find existing checkpoints
        ckpts = find_ckpts(args.parent_ckpt_dir, args.ckpt_subdir)

        # Apply filters
        filtered: List[str] = []
        for c in ckpts:
            s = extract_step(c)
            if wanted is not None and s not in wanted:
                continue
            if args.min_step >= 0 and s < args.min_step:
                continue
            if args.max_step >= 0 and s > args.max_step:
                continue
            filtered.append(c)

        ckpts = filtered

    if not ckpts:
        raise RuntimeError("No checkpoints left after filtering.")

    val_datasets_list = parse_val_datasets(args.val_datasets)

    if not val_datasets_list and not args.dataset:
        raise ValueError('Provide either --val_datasets "ds1:split ds2:split" or --dataset for single-dataset mode.')

    summary_rows: List[Dict[str, object]] = []

    for ckpt in ckpts:
        step = extract_step(ckpt)
        step_dir = os.path.join(args.outdir, f"step_{step:06d}")
        ensure_dir(step_dir)

        server_log = os.path.join(step_dir, "vllm_server.log")
        flow_log = os.path.join(step_dir, "flow_stdout.log")
        meta_log = os.path.join(step_dir, "meta.json")

        cfg = VLLMServerConfig(
            host=args.host,
            port=args.port,
            tp=args.tp,
            max_model_len=args.max_model_len,
            gpu_mem_util=args.gpu_mem_util,
            trust_remote_code=True,
            dtype=args.dtype,
            enforce_eager=bool(args.enforce_eager),
            extra_args=args.vllm_extra_args,
        )

        # The same host/port is reused for every checkpoint, so it must be free
        # before a new server is launched.
        if is_port_open(cfg.host, cfg.port):
            if args.kill_existing:
                print(f"  [warn] Port {cfg.port} in use, attempting to kill existing process...")
                kill_process_on_port(cfg.port)
                time.sleep(2)  # Wait for port to be released
                if is_port_open(cfg.host, cfg.port):
                    raise RuntimeError(
                        f"Port {cfg.port} still in use after kill attempt. "
                        f"Please manually stop the existing server."
                    )
            else:
                raise RuntimeError(
                    f"Port {cfg.port} already in use on {cfg.host}. "
                    f"Pick a different --port, stop the existing server, or use --kill_existing."
                )

        print(f"\n=== step={step} ckpt={ckpt} ===")
        print(f"[1/3] Starting vLLM server on http://{cfg.host}:{cfg.port} ...")
        p = start_vllm_server(ckpt, cfg, server_log)

        base_url = f"http://{cfg.host}:{cfg.port}/v1"

        # From here on the server process exists; whatever happens (startup
        # timeout, failure to launch the flow script, Ctrl-C), the finally
        # clause below must stop it so it does not keep holding the port.
        try:
            ok = wait_port(cfg.host, cfg.port, timeout_s=float(args.port_wait_timeout_s))
            if not ok:
                raise RuntimeError(
                    f"vLLM server did not open port {cfg.port} within {args.port_wait_timeout_s}s. "
                    f"See log: {server_log}"
                )

            flow_cmd = build_flow_cmd(
                flow_py=args.run_generator_fixer_flow_py,
                model_path=ckpt,
                base_url=base_url,
                out_dir=step_dir,
                val_datasets=val_datasets_list,
                dataset=args.dataset,
                split=args.split,
                n_tasks=args.n_tasks,
                n_parallel=args.n_parallel,
                n_repeats=args.n_repeats,
                temperature=args.temperature,
                top_p=args.top_p,
                fixer_attempts_val=args.fixer_attempts_val,
                include_failed_test_output=args.include_failed_test_output,
                extra_flow_args=args.extra_flow_args,
                save_results=args.save_results,
                print_samples=args.print_samples,
                eval_pregenerated_only=args.eval_pregenerated_only,
            )
            # Shell-quoted copy of the command, recorded in meta.json/summary.json.
            flow_cmd_str = " ".join(shlex.quote(x) for x in flow_cmd)

            print("[2/3] Running run_generator_fixer_flow.py ...")
            t0 = time.time()
            with open(flow_log, "w") as lf:
                fproc = subprocess.Popen(
                    flow_cmd,
                    stdout=lf,
                    stderr=subprocess.STDOUT,
                    env=subproc_env,
                    cwd=repo_root,   # <-- critical
                )
                rc = fproc.wait()

            dt = time.time() - t0
        finally:
            print("[3/3] Stopping vLLM server ...")
            kill_process_tree(p)

        # Wait for port to be released before continuing to next checkpoint
        wait_start = time.time()
        while is_port_open(cfg.host, cfg.port) and (time.time() - wait_start) < 30:
            time.sleep(0.5)
        if is_port_open(cfg.host, cfg.port):
            print(f"  [warn] Port {cfg.port} still in use after 30s, waiting longer...")
            time.sleep(5)

        row = {
            "step": step,
            "ckpt": ckpt,
            "base_url": base_url,
            "returncode": int(rc),
            "wall_time_sec": float(dt),
            "server_log": server_log,
            "flow_log": flow_log,
            "flow_cmd": flow_cmd_str,
            "val_datasets": val_datasets_list,
            "single_dataset_fallback": {"dataset": args.dataset, "split": args.split},
            "eval_pregenerated_only": bool(args.eval_pregenerated_only),
        }

        # Add HF repo info if applicable
        if args.hf_repo:
            row["hf_repo"] = args.hf_repo
            row["base_model"] = args.base_model

        summary_rows.append(row)

        with open(meta_log, "w") as f:
            json.dump(row, f, indent=2)

        # A failing flow run is reported but does not abort the remaining checkpoints.
        if rc != 0:
            print(f"[WARN] step={step} flow failed (rc={rc}). See {flow_log}")
        else:
            print(f"[OK] step={step} done in {dt/60:.1f} min. Outputs in {step_dir}")

    summary_path = os.path.join(args.outdir, "summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary_rows, f, indent=2)
    print(f"\nSaved summary -> {summary_path}")


# Run the evaluation driver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
