#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import json
import os
import re
import threading
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple

from openai import OpenAI


# File-name pattern for per-step training dumps; group 1 captures the step number.
TRAIN_STEP_RE = re.compile(r"step_(\d+)_traindata\.jsonl$")
# File-name pattern for numbered jsonl files, optionally followed by "_<suffix>";
# group 1 captures the leading number, which is used as the step.
NUMBERED_JSONL_RE = re.compile(r"(\d+)(?:_[^.]+)?\.jsonl$")
# Marker after which the reasoning ("analysis" channel) text starts in a response.
CHATML_ANALYSIS = "<|channel|>analysis<|message|>"
# Marker that closes the analysis section and opens the "final" channel.
CHATML_FINAL = "<|end|><|start|>assistant<|channel|>final<|message|>"
# End-of-text token that normalize_text() removes from every string.
END_OF_TEXT = "<|endoftext|>"
# The four behaviors named explicitly in the annotation prompt; their counts are
# summed into the "core_behavior_total_count" summary column.
CORE_BEHAVIORS = ["Backtracking", "Verification",
                  "Subgoal Setting", "Enumeration"]


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    The options cover: the action to run (generate / summarize / all), the
    input run directories (positional, zero or more), output file locations,
    LLM request settings (workers, delay, retries, model, API key, base URL),
    step/item sampling limits, the record field names to read, and a flag
    forcing regeneration.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="用 LLM 抽取 train/valid rollouts 的 reasoning behaviors，并按 step/run 聚合统计。"
    )
    parser.add_argument(
        "--action",
        type=str,
        default="all",
        choices=["generate", "summarize", "all"],
        help="generate: 只生成标注; summarize: 只汇总已有标注; all: 先生成再汇总。",
    )
    # Positional run specs; each is later split into (label, directory) by parse_run_spec().
    parser.add_argument(
        "runs",
        nargs="*",
        help=(
            "数据目录列表，支持 training_data 或 valid_* 目录；"
            "格式支持 /path/to/dir 或 label=/path/to/dir"
        ),
    )
    # Default output directory lives next to this script file.
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path(__file__).resolve().parent /
        "results" / "reasoning_behaviors",
        help="输出目录。",
    )
    # The four file options default to None; ensure_output_paths() fills in
    # names under the output directory when they are not given.
    parser.add_argument(
        "--details-file",
        type=Path,
        default=None,
        help="rollout 级别标注结果 jsonl。默认: output-dir/behavior_details.jsonl",
    )
    parser.add_argument(
        "--run-summary-file",
        type=Path,
        default=None,
        help="run 级别汇总 CSV。默认: output-dir/run_summary.csv",
    )
    parser.add_argument(
        "--step-summary-file",
        type=Path,
        default=None,
        help="step 级别汇总 CSV。默认: output-dir/step_summary.csv",
    )
    parser.add_argument(
        "--summary-json-file",
        type=Path,
        default=None,
        help="汇总 JSON。默认: output-dir/summary.json",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=5,
        help="并行 worker 数量。",
    )
    parser.add_argument(
        "--delay",
        type=float,
        default=2.0,
        help="请求之间的基础延迟秒数。",
    )
    parser.add_argument(
        "--max-retries",
        type=int,
        default=5,
        help="单条请求最大重试次数。",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="claude-3-7-sonnet-20250219-thinking",
        help="用于行为抽取的模型名。",
    )
    # API credentials fall back to environment variables read at parser-build time.
    parser.add_argument(
        "--api-key",
        type=str,
        default=os.environ.get("OPENAI_API_KEY", ""),
        help="API key，默认读取环境变量 OPENAI_API_KEY。",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=os.environ.get("OPENAI_BASE_URL", ""),
        help="API base url，默认读取环境变量 OPENAI_BASE_URL。",
    )
    parser.add_argument(
        "--start-step",
        type=int,
        default=None,
        help="可选：只分析 >= start-step 的文件。",
    )
    parser.add_argument(
        "--end-step",
        type=int,
        default=None,
        help="可选：只分析 <= end-step 的文件。",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=None,
        help="可选：每个 run 最多分析多少个 step 文件。",
    )
    parser.add_argument(
        "--step-stride",
        type=int,
        default=1,
        help="step 采样间隔。比如 50 表示每隔 50 个 step 取一次，默认 1。",
    )
    parser.add_argument(
        "--max-items-per-step",
        type=int,
        default=None,
        help="每个 train step 最多分析多少条 item。比如 32 表示每个 step 只取前 32 条。",
    )
    parser.add_argument(
        "--max-items-per-valid-step",
        type=int,
        default=100,
        help="每个 valid step 最多分析多少条 item。默认按顺序取前 100 条。",
    )
    parser.add_argument(
        "--response-key",
        type=str,
        default="responses",
        help="response 列表字段名，默认 responses。",
    )
    parser.add_argument(
        "--accuracy-key",
        type=str,
        default="accuracies",
        help="accuracy 列表字段名，默认 accuracies。",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="强制重跑所有 rollout 标注。",
    )
    return parser.parse_args()


def ensure_output_paths(args: argparse.Namespace) -> Dict[str, Path]:
    """Create the output directory and resolve every output file location.

    Args:
        args: Parsed CLI namespace carrying ``output_dir`` plus the optional
            ``details_file``, ``run_summary_file``, ``step_summary_file`` and
            ``summary_json_file`` overrides.

    Returns:
        A mapping with ``output_dir`` and the four file paths; any override
        left unset falls back to a fixed file name inside ``output_dir``.
    """
    base_dir = args.output_dir
    base_dir.mkdir(parents=True, exist_ok=True)

    fallback_names = (
        ("details_file", "behavior_details.jsonl"),
        ("run_summary_file", "run_summary.csv"),
        ("step_summary_file", "step_summary.csv"),
        ("summary_json_file", "summary.json"),
    )

    resolved: Dict[str, Path] = {"output_dir": base_dir}
    for attr_name, file_name in fallback_names:
        override = getattr(args, attr_name)
        resolved[attr_name] = override if override else base_dir / file_name
    return resolved


def parse_run_spec(spec: str) -> Tuple[str, Path]:
    """Split a run spec into ``(label, directory)``.

    Accepted forms are ``label=path``, ``label:path`` and a bare ``path``.
    The colon form is only honoured when the part before the colon contains
    no "/" and something follows the colon. When no label is given it is
    inferred from the directory via ``infer_label_from_path``.

    Args:
        spec: One positional run argument from the command line.

    Returns:
        The whitespace-stripped label and the user-expanded directory path.
    """
    label, path_str = "", spec

    if "=" in spec:
        label, _, path_str = spec.partition("=")
    elif ":" in spec:
        head, _, tail = spec.partition(":")
        # A head containing "/" is part of a path, not a label.
        if tail and "/" not in head:
            label, path_str = head, tail

    data_dir = Path(path_str).expanduser()
    chosen_label = label if label else infer_label_from_path(data_dir)
    return chosen_label.strip(), data_dir


def infer_label_from_path(path: Path) -> str:
    """Derive a run label from a data directory path.

    Rules, in order:
      * ``<parent>/training_data`` -> ``<parent>``
      * ``<parent>/valid*`` (case-insensitive prefix) -> ``<parent>_<name>``
      * otherwise the directory's own name, or the full path string when the
        path has no final component.
    """
    own_name = path.name
    parent_name = path.parent.name

    if parent_name:
        if own_name == "training_data":
            return parent_name
        if own_name.lower().startswith("valid"):
            return f"{parent_name}_{own_name}"

    return own_name if own_name else str(path)


def iter_step_files(training_data_dir: Path) -> Tuple[str, List[Tuple[int, Path]]]:
    step_files: List[Tuple[int, Path]] = []
    for path in training_data_dir.glob("step_*_traindata.jsonl"):
        match = TRAIN_STEP_RE.search(path.name)
        if match:
            step_files.append((int(match.group(1)), path))
    if step_files:
        step_files.sort(key=lambda x: x[0])
        return "train", step_files

    for path in training_data_dir.glob("*.jsonl"):
        match = NUMBERED_JSONL_RE.fullmatch(path.name)
        if match:
            step_files.append((int(match.group(1)), path))
    step_files.sort(key=lambda x: x[0])
    if step_files:
        return "valid", step_files
    return "unknown", []


def normalize_text(value: Any) -> str:
    """Convert a value to clean text.

    ``None`` becomes the empty string. Anything else is stringified, has the
    end-of-text token removed, has CRLF line endings turned into LF, and is
    stripped of surrounding whitespace.
    """
    if value is None:
        return ""

    text = str(value)
    text = text.replace(END_OF_TEXT, "")
    text = text.replace("\r\n", "\n")
    return text.strip()


def to_bool(value: Any) -> bool | None:
    """Interpret a loosely-typed correctness value.

    Booleans are returned as-is, numbers are converted by truthiness, and the
    strings true/1/yes/y and false/0/no/n (case-insensitive, surrounding
    whitespace ignored) map to ``True`` / ``False``. Everything else,
    including ``None``, yields ``None``.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return bool(value)
    if not isinstance(value, str):
        return None

    token = value.strip().lower()
    if token in ("true", "1", "yes", "y"):
        return True
    if token in ("false", "0", "no", "n"):
        return False
    return None


def normalize_accuracy_list(values: Any, expected_length: int) -> List[bool | None]:
    """Coerce accuracy data to a list of exactly ``expected_length`` flags.

    A list is converted element-wise with ``to_bool``; ``None`` becomes an
    empty list; any other scalar becomes a one-element list. The result is
    padded with ``None`` or truncated to ``expected_length``.
    """
    if values is None:
        flags: List[bool | None] = []
    elif isinstance(values, list):
        flags = [to_bool(entry) for entry in values]
    else:
        flags = [to_bool(values)]

    shortfall = expected_length - len(flags)
    if shortfall > 0:
        flags += [None] * shortfall
    return flags[:expected_length]


def extract_accuracies(
    item: Dict[str, Any],
    accuracy_key: str,
    expected_length: int,
) -> List[bool | None]:
    """Read per-rollout correctness flags from a data record.

    The value under ``accuracy_key`` is used when it is present and not
    ``None``. Otherwise the first non-``None`` value among "score",
    "accuracy", "correct" and "is_correct" is used. The value is then
    normalized to ``expected_length`` entries by ``normalize_accuracy_list``.
    """
    raw = item.get(accuracy_key)
    if raw is None:
        fallback_keys = ("score", "accuracy", "correct", "is_correct")
        raw = next(
            (item[key] for key in fallback_keys if item.get(key) is not None),
            None,
        )
    return normalize_accuracy_list(raw, expected_length)


def extract_responses(item: Dict[str, Any], response_key: str) -> List[str]:
    """Return the normalized response texts of a data record.

    The value under ``response_key`` wins when it is not ``None``: a list is
    normalized element-wise, any other value becomes a single-element list.
    Otherwise the first non-``None`` value under "response" or "output" is
    used. An empty list means the record carries no response.
    """
    primary = item.get(response_key)
    if primary is not None:
        if isinstance(primary, list):
            return [normalize_text(entry) for entry in primary]
        return [normalize_text(primary)]

    for alt_key in ("response", "output"):
        alt_value = item.get(alt_key)
        if alt_value is not None:
            return [normalize_text(alt_value)]
    return []


def extract_prompt(item: Dict[str, Any]) -> str:
    """Return the normalized prompt text of a record.

    The first non-``None`` value among "input", "prompt" and "question" is
    used; the empty string is returned when none is set.
    """
    candidate_keys = ("input", "prompt", "question")
    return next(
        (normalize_text(item[key])
         for key in candidate_keys if item.get(key) is not None),
        "",
    )


def extract_answer(item: Dict[str, Any]) -> str:
    """Return the normalized reference answer of a record.

    The first non-``None`` value among "ground_truth", "answer" and "target"
    is used; the empty string is returned when none is set.
    """
    candidate_keys = ("ground_truth", "answer", "target")
    return next(
        (normalize_text(item[key])
         for key in candidate_keys if item.get(key) is not None),
        "",
    )


def extract_reasoning_text(response: Any) -> str:
    """Isolate the reasoning portion of a model response.

    Strategies are tried in this order:
      1. Text after the analysis-channel marker, cut at the final-channel
         marker when that follows.
      2. All closed ``<think>...</think>`` spans, joined by blank lines.
      3. A lone ``</think>``: text before it, minus anything up to and
         including an opening ``<think>``.
      4. Text before the first ``<answer>`` tag, or the whole text when no
         such tag exists.

    Returns:
        The stripped reasoning text; empty when the response is empty.
    """
    text = normalize_text(response)
    if not text:
        return ""

    # 1. Channel-style output: partition leaves ``marker`` empty when absent.
    _, marker, tail = text.partition(CHATML_ANALYSIS)
    if marker:
        return tail.partition(CHATML_FINAL)[0].strip()

    tag_flags = re.IGNORECASE | re.DOTALL

    # 2. One or more properly closed think blocks.
    closed_spans = re.findall(r"<think>(.*?)</think>", text, flags=tag_flags)
    if closed_spans:
        trimmed = [span.strip() for span in closed_spans]
        return "\n\n".join(span for span in trimmed if span)

    lowered = text.lower()

    # 3. Only a closing tag is present.
    if "</think>" in lowered:
        head = re.split(r"</think>", text, maxsplit=1,
                        flags=re.IGNORECASE)[0]
        head = re.sub(r"^.*?<think>", "", head, flags=tag_flags)
        return head.strip()

    # 4. Drop the answer section if one is marked.
    if "<answer>" in lowered:
        text = re.split(r"<answer>", text, maxsplit=1,
                        flags=re.IGNORECASE)[0]

    return text.strip()


def format_reasoning_behavior_prompt(cot_content: str) -> str:
    """Build the annotation prompt sent to the LLM for one reasoning chain.

    The prompt embeds ``cot_content`` between ``<start_of_reasoning>`` and
    ``<end_of_reasoning>`` tags, describes the four core behaviors
    (Backtracking, Verification, Subgoal Setting, Enumeration), and asks for
    the answer as a fenced JSON object or array whose entries have the keys
    "behaviour" and "example".

    Args:
        cot_content: The reasoning text to be evaluated.

    Returns:
        The complete prompt string.
    """
    # Doubled braces below are f-string escapes that render as literal "{" / "}".
    return f"""Below is a chain-of-reasoning generated by a Language Model when attempting to solve a math problem. Evaluate this chain-of-reasoning to determine whether it demonstrates beneficial problem-solving behaviors that deviate from typical linear, monotonic reasoning patterns commonly observed in language models.

<start_of_reasoning>
{cot_content}
<end_of_reasoning>

Please identify and emphasize the following beneficial behaviors:

**(1) Backtracking:** Explicitly revising approaches upon identifying errors or dead ends (e.g., "This approach won't work because...").

**(2) Verification:** Systematically checking intermediate results or reasoning steps (e.g., "Let's verify this result by...").

**(3) Subgoal Setting:** Breaking down complex problems into smaller, manageable steps (e.g., "To solve this, we first need to...").

**(4) Enumeration:** Solving problems by exhaustively considering multiple cases or possibilities.

You are also encouraged to identify other beneficial behaviors not explicitly listed, such as creative analogies, abstraction to simpler cases, or insightful generalizations.

Important:
- Clearly specify each beneficial behavior identified.
- Provide explicit examples from the reasoning chain.
- If no beneficial behaviors are observed, explicitly return an empty list.

Provide your evaluation clearly, formatted as follows:
```json
{{
"behaviour": "",
"example": ""
}}
```

If multiple behaviors are found, return a JSON array:
```json
[
  {{
    "behaviour": "",
    "example": ""
  }},
  {{
    "behaviour": "",
    "example": ""
  }}
]
```"""


def extract_json_from_text(text: str) -> Any:
    """Extract the first usable JSON object or array from model output.

    Strategy:
      1. Try every fenced json code block in order and return the first one
         that parses.
      2. Otherwise scan the text for "{" / "[" in order of appearance and
         decode a JSON value starting there. Decoding is done by the JSON
         parser itself, so braces or brackets inside string values do not
         confuse it, and an array is returned whole instead of being cut
         down to its first object.
      3. Among the decoded candidates, an object or a list made only of
         objects is preferred (this skips things like a "[1]" citation that
         precedes the real payload). If no such candidate exists, the first
         value that decoded at all is returned.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON value (dict or list), or ``None`` when nothing in
        the text can be decoded.
    """
    for fenced in re.findall(r"```json\s*\n(.*?)\n```", text, re.DOTALL):
        try:
            return json.loads(fenced)
        except json.JSONDecodeError:
            continue

    def _looks_like_payload(value: Any) -> bool:
        # The prompt asks for an object or an array of objects.
        if isinstance(value, dict):
            return True
        return isinstance(value, list) and all(
            isinstance(entry, dict) for entry in value
        )

    decoder = json.JSONDecoder()
    first_decoded = None
    for opener in re.finditer(r"[\[{]", text):
        try:
            value, _ = decoder.raw_decode(text, opener.start())
        except json.JSONDecodeError:
            continue
        if _looks_like_payload(value):
            return value
        # Values decoded here are always lists, so None safely means "unset".
        if first_decoded is None:
            first_decoded = value
    return first_decoded


def canonicalize_behavior_name(name: Any) -> str:
    """Map a behavior label to its canonical spelling.

    The label is normalized to text; runs of whitespace, underscores and
    hyphens are treated as a single space and case is ignored when matching
    against the core behaviors ("subgoal" alone also maps to
    "Subgoal Setting"). Labels that are not core behaviors are returned with
    whitespace runs collapsed to one space. Empty input yields "".
    """
    cleaned = normalize_text(name)
    if not cleaned:
        return ""

    core_aliases = {
        "backtracking": "Backtracking",
        "verification": "Verification",
        "subgoal setting": "Subgoal Setting",
        "subgoal": "Subgoal Setting",
        "enumeration": "Enumeration",
    }
    lookup_key = re.sub(r"[\s_\-]+", " ", cleaned).strip().lower()
    canonical = core_aliases.get(lookup_key)
    if canonical is not None:
        return canonical
    return re.sub(r"\s+", " ", cleaned)


def normalize_behaviors(payload: Any) -> List[Dict[str, str]]:
    """Turn a decoded annotation payload into a clean list of behaviors.

    A single dict is treated as a one-element list; anything that is neither
    a dict nor a list gives an empty result. Entries that are not dicts, or
    whose "behaviour" label is empty after canonicalization, are dropped.

    Returns:
        A list of ``{"behaviour": <canonical name>, "example": <text>}``.
    """
    if isinstance(payload, dict):
        entries = [payload]
    elif isinstance(payload, list):
        entries = payload
    else:
        return []

    cleaned: List[Dict[str, str]] = []
    for entry in entries:
        if isinstance(entry, dict):
            label = canonicalize_behavior_name(entry.get("behaviour", ""))
            if label:
                cleaned.append(
                    {
                        "behaviour": label,
                        "example": normalize_text(entry.get("example", "")),
                    }
                )
    return cleaned


def extract_behavior_names(record: Dict[str, Any]) -> List[str]:
    """List the distinct canonical behavior names of an annotation record.

    Reads ``record["behaviors"]``; non-list values and non-dict entries are
    ignored. Names keep the order in which they first appear.
    """
    entries = record.get("behaviors", [])
    if not isinstance(entries, list):
        return []

    # A dict is used as an insertion-ordered set.
    ordered_unique: Dict[str, None] = {}
    for entry in entries:
        if isinstance(entry, dict):
            label = canonicalize_behavior_name(entry.get("behaviour", ""))
            if label:
                ordered_unique.setdefault(label, None)
    return list(ordered_unique)


def build_rollout_tasks(
    runs: List[Tuple[str, Path]],
    start_step: int | None,
    end_step: int | None,
    max_steps: int | None,
    step_stride: int,
    max_items_per_step: int | None,
    max_items_per_valid_step: int | None,
    response_key: str,
    accuracy_key: str,
) -> List[Dict[str, Any]]:
    """Scan run directories and build one annotation task per rollout.

    For every run the step files are discovered, filtered by the step range,
    thinned by ``step_stride`` (relative to the first remaining step) and
    capped at ``max_steps``. Each selected file is read line by line; every
    record with at least one response contributes one task per response.

    Lines that are not valid JSON, or that hold a JSON value other than an
    object, are skipped and counted; the count is printed per file.

    Args:
        runs: ``(label, directory)`` pairs.
        start_step: Inclusive lower step bound, or ``None``.
        end_step: Inclusive upper step bound, or ``None``.
        max_steps: Maximum number of step files per run (ignored unless > 0).
        step_stride: Keep only steps whose offset from the first kept step is
            a multiple of this value (ignored unless > 1).
        max_items_per_step: Item cap per file for "train" sources.
        max_items_per_valid_step: Item cap per file for "valid" sources.
        response_key: Record field holding the responses.
        accuracy_key: Record field holding the correctness flags.

    Returns:
        A list of task dicts (rollout id, run/step metadata, correctness,
        prompt, answer, response and reasoning text).

    Raises:
        FileNotFoundError: A run directory does not exist, or no step file
            remains after filtering.
        NotADirectoryError: A run path is not a directory.
    """
    tasks: List[Dict[str, Any]] = []

    for run_label, training_data_dir in runs:
        if not training_data_dir.exists():
            raise FileNotFoundError(f"目录不存在: {training_data_dir}")
        if not training_data_dir.is_dir():
            raise NotADirectoryError(f"不是目录: {training_data_dir}")

        data_source, step_files = iter_step_files(training_data_dir)
        if start_step is not None or end_step is not None:
            step_files = [
                (step, file_path)
                for step, file_path in step_files
                if (start_step is None or step >= start_step)
                and (end_step is None or step <= end_step)
            ]
        if step_stride > 1 and step_files:
            # The stride is anchored at the first step that survived the range filter.
            anchor_step = step_files[0][0]
            step_files = [
                (step, file_path)
                for step, file_path in step_files
                if (step - anchor_step) % step_stride == 0
            ]
        if max_steps is not None and max_steps > 0:
            step_files = step_files[:max_steps]

        if not step_files:
            raise FileNotFoundError(
                f"在目录 {training_data_dir} 中没有找到可分析的 "
                "step_*_traindata.jsonl 或 {step}_*.jsonl 文件。"
            )

        print(f"\n{'=' * 80}")
        print(f"扫描 run: {run_label}")
        print(f"data_source: {data_source}")
        print(f"data_dir: {training_data_dir}")
        print(f"step 文件数: {len(step_files)}")

        # Validation sources use their own per-file item cap.
        item_limit = (
            max_items_per_valid_step if data_source == "valid"
            else max_items_per_step
        )

        for step, file_path in step_files:
            print(f"  - 读取 step {step}: {file_path.name}")
            selected_item_count = 0
            selected_rollout_count = 0
            skipped_line_count = 0
            with file_path.open("r", encoding="utf-8") as f:
                for line_number, line in enumerate(f, start=1):
                    stripped = line.strip()
                    if not stripped:
                        continue
                    try:
                        item = json.loads(stripped)
                    except json.JSONDecodeError:
                        skipped_line_count += 1
                        continue
                    # Valid JSON that is not an object has no fields to read
                    # and would otherwise crash on ``item.get``.
                    if not isinstance(item, dict):
                        skipped_line_count += 1
                        continue

                    responses = extract_responses(
                        item, response_key=response_key)
                    if not responses:
                        continue

                    # Records without responses do not count toward the cap.
                    if item_limit is not None and selected_item_count >= item_limit:
                        break

                    selected_item_count += 1
                    accuracies = extract_accuracies(
                        item=item,
                        accuracy_key=accuracy_key,
                        expected_length=len(responses),
                    )
                    uid = normalize_text(item.get("uid", ""))
                    prompt = extract_prompt(item)
                    answer = extract_answer(item)
                    # The uid identifies the item when present, else the line number does.
                    item_key = f"uid_{uid}" if uid else f"line_{line_number}"

                    for rollout_index, response in enumerate(responses):
                        reasoning_text = extract_reasoning_text(response)
                        tasks.append(
                            {
                                "rollout_id": (
                                    f"{run_label}__step_{step}__{item_key}"
                                    f"__rollout_{rollout_index}"
                                ),
                                "run": run_label,
                                "training_data_dir": str(training_data_dir),
                                "data_source": data_source,
                                "step": step,
                                "step_file": file_path.name,
                                "line_number": line_number,
                                "uid": uid,
                                "rollout_index": rollout_index,
                                "correctness": accuracies[rollout_index],
                                "prompt": prompt,
                                "answer": answer,
                                "response": response,
                                # Fall back to the full response when no reasoning was isolated.
                                "reasoning_text": reasoning_text if reasoning_text else response,
                            }
                        )
                        selected_rollout_count += 1
            print(
                f"    选中 item 数: {selected_item_count}, "
                f"实际 rollout 数: {selected_rollout_count}"
            )
            if skipped_line_count:
                print(f"    跳过无法解析的行数: {skipped_line_count}")
    return tasks


def load_existing_annotations(details_file: Path) -> Dict[str, Dict[str, Any]]:
    """Load previously written annotations keyed by rollout id.

    Blank lines, lines that are not valid JSON and records without a
    ``rollout_id`` are ignored. When an id occurs more than once the last
    record wins. A missing file yields an empty mapping.
    """
    if not details_file.exists():
        return {}

    by_rollout_id: Dict[str, Dict[str, Any]] = {}
    with details_file.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if payload:
                try:
                    entry = json.loads(payload)
                except json.JSONDecodeError:
                    continue
                key = normalize_text(entry.get("rollout_id", ""))
                if key:
                    by_rollout_id[key] = entry
    return by_rollout_id


def create_client(api_key: str, base_url: str) -> OpenAI:
    """Create the OpenAI client used for annotation requests.

    Args:
        api_key: API key; must be non-empty.
        base_url: Optional endpoint override; only passed on when non-empty.

    Raises:
        ValueError: ``api_key`` is empty.
    """
    if not api_key:
        raise ValueError(
            "缺少 API key。请传 --api-key 或设置环境变量 OPENAI_API_KEY。"
        )

    client_kwargs = {"api_key": api_key}
    if base_url:
        client_kwargs["base_url"] = base_url
    return OpenAI(**client_kwargs)


def chat_once(
    client: OpenAI,
    model: str,
    messages: List[Dict[str, str]],
    max_retries: int,
) -> str | None:
    """Send one chat completion request, retrying with exponential backoff.

    Any exception raised by the request counts as a failed attempt. After a
    failure the function waits ``2 ** attempt`` seconds before the next try,
    except after the last one.

    Returns:
        The content of the first choice ("" when the content is empty or
        ``None``), or ``None`` when every attempt failed or ``max_retries``
        is not positive.
    """
    final_attempt = max_retries - 1

    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
            )
            content = response.choices[0].message.content
            return content if content else ""
        except Exception as e:
            print(f"❌ API 调用失败 (第{attempt + 1}/{max_retries}次): {e}")
            if attempt >= final_attempt:
                print("💥 达到最大重试次数，跳过该 rollout")
                return None
            backoff_seconds = 2 ** attempt
            print(f"⏳ 等待 {backoff_seconds} 秒后重试...")
            time.sleep(backoff_seconds)

    return None


def append_jsonl(path: Path, record: Dict[str, Any]) -> None:
    """Append one record as a JSON line, creating parent directories first.

    Non-ASCII characters are written as-is (UTF-8), not escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as handle:
        serialized = json.dumps(record, ensure_ascii=False)
        handle.write(f"{serialized}\n")


def rewrite_jsonl(path: Path, records: List[Dict[str, Any]]) -> None:
    """Replace the file's content with one JSON line per record.

    Parent directories are created first; non-ASCII characters are written
    as-is (UTF-8), not escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(entry, ensure_ascii=False) + "\n" for entry in records
        )


def should_regenerate(existing_record: Dict[str, Any] | None, force: bool) -> bool:
    """Decide whether a rollout needs a fresh annotation.

    A rollout is regenerated when ``force`` is set, when no record exists,
    when the stored record has ``parse_ok`` equal to ``False``, or when it
    has ``raw_output_only`` equal to ``True``.
    """
    if force:
        return True
    if existing_record is None:
        return True

    parse_failed = existing_record.get("parse_ok") is False
    only_raw_output = existing_record.get("raw_output_only") is True
    return parse_failed or only_raw_output


def generate_annotations(
    tasks: List[Dict[str, Any]],
    details_file: Path,
    client: OpenAI,
    model: str,
    max_workers: int,
    delay_between_requests: float,
    max_retries: int,
    force: bool,
) -> List[Dict[str, Any]]:
    """Annotate pending rollouts with the LLM and persist the results.

    Existing annotations in ``details_file`` are reused unless
    ``should_regenerate`` says otherwise. With ``force`` the details file is
    deleted first. Each finished result is appended to the file immediately,
    so progress survives an interruption; at the end the file is rewritten
    with one record per rollout id, sorted by id.

    Request pacing: task ``i`` is not started earlier than
    ``i * delay_between_requests / workers`` seconds after the pool starts.
    The wait is measured against a fixed start time, so it never grows beyond
    what is needed to keep that rate; once requests take longer than the
    spacing, no extra sleeping happens.

    Args:
        tasks: Task dicts as produced by ``build_rollout_tasks``.
        details_file: JSONL file holding rollout-level annotations.
        client: Chat client passed to ``chat_once``.
        model: Model name sent with each request.
        max_workers: Number of worker threads (values below 1 are treated as 1).
        delay_between_requests: Base delay in seconds used for pacing.
        max_retries: Maximum attempts per request.
        force: Discard existing annotations and redo every task.

    Returns:
        All known annotation records (reused and new), sorted by rollout id.
    """
    if force and details_file.exists():
        details_file.unlink()

    existing = load_existing_annotations(details_file)
    pending_tasks = [
        task for task in tasks
        if should_regenerate(existing.get(task["rollout_id"]), force=force)
    ]

    print(f"\n{'=' * 80}")
    print("生成 rollout behaviors")
    print(f"总 rollout 数: {len(tasks)}")
    print(f"已存在可复用结果: {len(tasks) - len(pending_tasks)}")
    print(f"待生成: {len(pending_tasks)}")

    if not pending_tasks:
        print("✅ 没有需要重新生成的 rollout")
        merged = list(existing.values())
        merged.sort(key=lambda x: x.get("rollout_id", ""))
        return merged

    # ThreadPoolExecutor rejects a non-positive worker count.
    worker_count = max(max_workers, 1)
    request_spacing = max(delay_between_requests, 0) / worker_count
    schedule_origin = time.monotonic()

    def process_single_task(idx: int, task: Dict[str, Any]) -> Dict[str, Any]:
        # Hold the task back until its scheduled start time. Sleeping for
        # ``spacing * idx`` from *now* would make waits grow with the task
        # index and stall large jobs.
        if request_spacing > 0:
            remaining = schedule_origin + request_spacing * idx - time.monotonic()
            if remaining > 0:
                time.sleep(remaining)

        prompt = format_reasoning_behavior_prompt(task["reasoning_text"])
        messages = [{"role": "user", "content": prompt}]
        output = chat_once(
            client=client,
            model=model,
            messages=messages,
            max_retries=max_retries,
        )

        result = {
            **task,
            "model": model,
            "annotation_prompt": prompt,
            "annotation_output": output,
            "parse_ok": False,
            "raw_output_only": False,
            "behaviors": [],
        }

        # No output at all: every request attempt failed.
        if output is None:
            result["raw_output_only"] = True
            return result

        parsed = extract_json_from_text(output)
        if parsed is None:
            # Output exists but holds no decodable JSON; keep it for a later retry.
            result["raw_output_only"] = True
        else:
            result["parse_ok"] = True
            result["behaviors"] = normalize_behaviors(parsed)
        return result

    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        futures = [
            executor.submit(process_single_task, idx, task)
            for idx, task in enumerate(pending_tasks)
        ]

        # Results are consumed in this thread only, so no lock is required.
        for completed, future in enumerate(as_completed(futures), start=1):
            result = future.result()
            append_jsonl(details_file, result)
            existing[result["rollout_id"]] = result
            print(
                f"✅ {result['rollout_id']} "
                f"({completed}/{len(pending_tasks)}) "
                f"parse_ok={result['parse_ok']}"
            )

    merged = list(existing.values())
    merged.sort(key=lambda x: x.get("rollout_id", ""))
    # Compact the file: appended retries are replaced by one record per id.
    rewrite_jsonl(details_file, merged)
    return merged


def aggregate_records(records: List[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Aggregate annotation records per run and per (run, step).

    For each group the function counts rollouts by correctness, parse
    successes and failures, rollouts showing at least one behavior, and how
    often each distinct behavior occurs (overall, in correct rollouts and in
    incorrect rollouts).

    Returns:
        ``(behavior_names, run_rows, step_rows)``: the sorted list of all
        behavior names seen, one summary row per run (sorted by run) and one
        per (run, step) (sorted by run, then step). Run rows carry empty
        strings for "step" and "step_file".
    """
    seen_behaviors = set()
    per_run: Dict[str, Dict[str, Any]] = {}
    per_step: Dict[Tuple[str, int], Dict[str, Any]] = {}

    def _new_bucket(first_record: Dict[str, Any], include_step: bool) -> Dict[str, Any]:
        # Fresh accumulator seeded with metadata from the group's first record.
        return {
            "run": first_record["run"],
            "training_data_dir": first_record["training_data_dir"],
            "data_source": first_record.get("data_source", ""),
            "step": first_record["step"] if include_step else "",
            "step_file": first_record["step_file"] if include_step else "",
            "rollout_count": 0,
            "correct_rollout_count": 0,
            "incorrect_rollout_count": 0,
            "unknown_accuracy_rollout_count": 0,
            "parse_ok_count": 0,
            "parse_fail_count": 0,
            "rollout_with_any_behavior_count": 0,
            "behavior_total_count": 0,
            "behavior_counts": Counter(),
            "behavior_counts_correct": Counter(),
            "behavior_counts_incorrect": Counter(),
        }

    def _accumulate(bucket: Dict[str, Any], record: Dict[str, Any]) -> None:
        # Fold one record into a bucket.
        bucket["rollout_count"] += 1

        verdict = record.get("correctness")
        if verdict is True:
            bucket["correct_rollout_count"] += 1
            outcome_counter = bucket["behavior_counts_correct"]
        elif verdict is False:
            bucket["incorrect_rollout_count"] += 1
            outcome_counter = bucket["behavior_counts_incorrect"]
        else:
            bucket["unknown_accuracy_rollout_count"] += 1
            outcome_counter = None

        parse_field = "parse_ok_count" if record.get(
            "parse_ok") else "parse_fail_count"
        bucket[parse_field] += 1

        # Names are already de-duplicated per record, so each counts once.
        names = extract_behavior_names(record)
        if names:
            bucket["rollout_with_any_behavior_count"] += 1
        bucket["behavior_total_count"] += len(names)

        seen_behaviors.update(names)
        bucket["behavior_counts"].update(names)
        if outcome_counter is not None:
            outcome_counter.update(names)

    for record in records:
        run_label = record["run"]
        run_bucket = per_run.get(run_label)
        if run_bucket is None:
            run_bucket = _new_bucket(record, include_step=False)
            per_run[run_label] = run_bucket
        _accumulate(run_bucket, record)

        step_id = (run_label, int(record["step"]))
        step_bucket = per_step.get(step_id)
        if step_bucket is None:
            step_bucket = _new_bucket(record, include_step=True)
            per_step[step_id] = step_bucket
        _accumulate(step_bucket, record)

    ordered_behaviors = sorted(seen_behaviors)

    def _rate(numerator: int, denominator: int) -> float:
        # Six-decimal ratio; 0.0 for an empty group.
        return round(numerator / denominator, 6) if denominator else 0.0

    def _to_row(bucket: Dict[str, Any]) -> Dict[str, Any]:
        # Flatten a bucket into an output row; key order is the column order.
        total = bucket["rollout_count"]
        overall = bucket["behavior_counts"]
        in_correct = bucket["behavior_counts_correct"]
        in_incorrect = bucket["behavior_counts_incorrect"]

        row = {
            "run": bucket["run"],
            "training_data_dir": bucket["training_data_dir"],
            "data_source": bucket.get("data_source", ""),
            "step": bucket["step"],
            "step_file": bucket["step_file"],
            "rollout_count": total,
            "correct_rollout_count": bucket["correct_rollout_count"],
            "incorrect_rollout_count": bucket["incorrect_rollout_count"],
            "unknown_accuracy_rollout_count": bucket["unknown_accuracy_rollout_count"],
            "parse_ok_count": bucket["parse_ok_count"],
            "parse_fail_count": bucket["parse_fail_count"],
            "rollout_with_any_behavior_count": bucket["rollout_with_any_behavior_count"],
            "rollout_with_any_behavior_rate": _rate(
                bucket["rollout_with_any_behavior_count"], total
            ),
            "behavior_total_count": bucket["behavior_total_count"],
            "avg_behavior_count_per_rollout": _rate(
                bucket["behavior_total_count"], total
            ),
            "unique_behavior_types": sum(
                1 for name in ordered_behaviors if overall.get(name, 0) > 0
            ),
            "core_behavior_total_count": sum(
                overall.get(name, 0) for name in CORE_BEHAVIORS
            ),
        }

        for behavior in ordered_behaviors:
            column = behavior_to_column_name(behavior)
            hits = overall.get(behavior, 0)
            row[f"{column}_count"] = hits
            row[f"{column}_rate"] = _rate(hits, total)
            row[f"{column}_correct_count"] = in_correct.get(behavior, 0)
            row[f"{column}_incorrect_count"] = in_incorrect.get(behavior, 0)
        return row

    run_rows = [_to_row(per_run[run_label]) for run_label in sorted(per_run)]
    step_rows = [
        _to_row(per_step[step_id])
        for step_id in sorted(per_step, key=lambda pair: (pair[0], pair[1]))
    ]
    return ordered_behaviors, run_rows, step_rows


def behavior_to_column_name(name: str) -> str:
    """Turn a behavior name into a column-name prefix.

    The name is stripped and lower-cased, every run of characters outside
    ``a-z``, ``A-Z`` and ``0-9`` becomes one underscore, and leading/trailing
    underscores are removed. An empty result becomes "unknown_behavior".
    """
    slug = re.sub(r"[^a-zA-Z0-9]+", "_", name.strip().lower())
    slug = slug.strip("_")
    if not slug:
        return "unknown_behavior"
    return slug


def write_csv(path: Path, rows: List[Dict[str, Any]], fieldnames: List[str]) -> None:
    """Write rows to a UTF-8 CSV file with a header line.

    Parent directories are created first and an existing file is overwritten.
    Columns appear in the order given by ``fieldnames``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8", newline="") as handle:
        table = csv.DictWriter(handle, fieldnames=fieldnames)
        table.writeheader()
        table.writerows(rows)


def summarize_annotations(
    records: List[Dict[str, Any]],
    paths: Dict[str, Path],
    args: argparse.Namespace,
) -> None:
    """Aggregate annotation records and write the summary files.

    Writes the step-level CSV, the run-level CSV (without the step columns)
    and a JSON document holding both summaries plus the arguments used, then
    prints a short run table and the output locations.

    Args:
        records: Rollout-level annotation records.
        paths: Output locations as returned by ``ensure_output_paths``.
        args: Parsed CLI namespace; a subset is recorded in the JSON file.

    Raises:
        ValueError: ``records`` is empty.
    """
    if not records:
        raise ValueError("没有可用于汇总的标注结果。")

    behavior_list, run_rows, step_rows = aggregate_records(records)

    common_fields = [
        "run",
        "training_data_dir",
        "data_source",
        "step",
        "step_file",
        "rollout_count",
        "correct_rollout_count",
        "incorrect_rollout_count",
        "unknown_accuracy_rollout_count",
        "parse_ok_count",
        "parse_fail_count",
        "rollout_with_any_behavior_count",
        "rollout_with_any_behavior_rate",
        "behavior_total_count",
        "avg_behavior_count_per_rollout",
        "unique_behavior_types",
        "core_behavior_total_count",
    ]
    # Columns that only make sense for step-level rows.
    step_only_fields = ("step", "step_file")

    behavior_fields: List[str] = []
    for behavior in behavior_list:
        key = behavior_to_column_name(behavior)
        behavior_fields.extend(
            [
                f"{key}_count",
                f"{key}_rate",
                f"{key}_correct_count",
                f"{key}_incorrect_count",
            ]
        )

    step_fieldnames = common_fields + behavior_fields
    run_fieldnames = [
        field for field in common_fields if field not in step_only_fields
    ] + behavior_fields

    # DictWriter rejects keys that are not in its field list, so the
    # step-only keys are removed from the run rows before writing.
    run_rows_for_csv = [
        {key: value for key, value in row.items()
         if key not in step_only_fields}
        for row in run_rows
    ]

    write_csv(paths["step_summary_file"], step_rows, step_fieldnames)
    write_csv(paths["run_summary_file"], run_rows_for_csv, run_fieldnames)

    summary_json_file = paths["summary_json_file"]
    # The JSON file may be redirected outside the output directory, so make
    # sure its parent exists, as the CSV and JSONL writers do.
    summary_json_file.parent.mkdir(parents=True, exist_ok=True)
    with summary_json_file.open("w", encoding="utf-8") as f:
        json.dump(
            {
                "generated_at": datetime.now().isoformat(),
                "action": args.action,
                "behavior_names": behavior_list,
                "core_behaviors": CORE_BEHAVIORS,
                "args": {
                    "runs": args.runs,
                    "output_dir": str(paths["output_dir"]),
                    "details_file": str(paths["details_file"]),
                    "model": args.model,
                    "start_step": args.start_step,
                    "end_step": args.end_step,
                    "max_steps": args.max_steps,
                    "step_stride": args.step_stride,
                    "max_items_per_step": args.max_items_per_step,
                    "max_items_per_valid_step": args.max_items_per_valid_step,
                    "response_key": args.response_key,
                    "accuracy_key": args.accuracy_key,
                },
                "run_summary": run_rows,
                "step_summary": step_rows,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    print(f"\n{'=' * 80}")
    print("Run Summary")
    print(
        f"{'run':<24} {'rollouts':>12} {'any':>10} {'parse_ok':>10} {'behaviors':>10}"
    )
    for row in run_rows:
        print(
            f"{row['run']:<24} {row['rollout_count']:>12} "
            f"{row['rollout_with_any_behavior_count']:>10} "
            f"{row['parse_ok_count']:>10} "
            f"{row['behavior_total_count']:>10}"
        )

    print(f"\n结果已保存到: {paths['output_dir']}")
    print(f"  - {paths['details_file'].name}")
    print(f"  - {paths['step_summary_file'].name}")
    print(f"  - {paths['run_summary_file'].name}")
    print(f"  - {paths['summary_json_file'].name}")


def main() -> None:
    """Run the CLI: generate annotations, summarize them, or both.

    Raises:
        ValueError: A generating action was requested without run
            directories, or two runs resolve to the same label.
        FileNotFoundError: Summarize-only mode was requested but the details
            file does not exist.
    """
    args = parse_args()
    paths = ensure_output_paths(args)

    wants_generation = args.action in {"generate", "all"}
    wants_summary = args.action in {"summarize", "all"}

    if wants_generation and not args.runs:
        raise ValueError(
            "generate/all 模式下必须提供至少一个数据目录（training_data 或 valid_*）。")

    runs = [parse_run_spec(spec) for spec in args.runs]
    labels = [run_label for run_label, _ in runs]
    # Labels are part of every rollout id, so they must be unique.
    if len(labels) != len(set(labels)):
        raise ValueError(f"run label 重复，请为每个目录设置唯一 label: {labels}")

    records: List[Dict[str, Any]]

    if wants_generation:
        tasks = build_rollout_tasks(
            runs=runs,
            start_step=args.start_step,
            end_step=args.end_step,
            max_steps=args.max_steps,
            step_stride=args.step_stride,
            max_items_per_step=args.max_items_per_step,
            max_items_per_valid_step=args.max_items_per_valid_step,
            response_key=args.response_key,
            accuracy_key=args.accuracy_key,
        )
        client = create_client(api_key=args.api_key, base_url=args.base_url)
        records = generate_annotations(
            tasks=tasks,
            details_file=paths["details_file"],
            client=client,
            model=args.model,
            max_workers=args.workers,
            delay_between_requests=args.delay,
            max_retries=args.max_retries,
            force=args.force,
        )
    else:
        details_file = paths["details_file"]
        if not details_file.exists():
            raise FileNotFoundError(
                f"找不到 details 文件: {paths['details_file']}，请先运行 generate/all。"
            )
        stored = load_existing_annotations(details_file)
        records = sorted(
            stored.values(), key=lambda entry: entry.get("rollout_id", ""))

    if wants_summary:
        summarize_annotations(records=records, paths=paths, args=args)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
