#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import re
import csv
import argparse
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Set
from collections import defaultdict

# Matches the first line of each 3-line group: a result-file path shaped like
#   <path>/<model>B/<digits>/replace_count_<n>/<8 digits>_<6 digits>.json
# Named groups:
#   path    - everything before the model directory (matched lazily)
#   model   - integer or decimal size followed by "B", e.g. "14B" or "1.5B"
#   replace - the digits following "replace_count_"
#   stamp   - the JSON file stem (presumably a date_time stamp; confirm
#             against whatever produces these files)
RE_FIRST = re.compile(
    r"""
    (?P<path>.+?)
    /(?P<model>\d+(?:\.\d+)?B)/
    \d+/
    replace_count_(?P<replace>\d+)/
    (?P<stamp>\d{8}_\d{6})\.json
    """,
    re.VERBOSE
)

# Matches the second line of each group: "jailbreak success rate: <number>%".
# Case-insensitive and tolerant of extra whitespace; the integer or decimal
# percentage is captured as the named group `rate`.
RE_RATE = re.compile(
    r"""jailbreak\s*success\s*rate\s*:\s*(?P<rate>\d+(?:\.\d+)?)\s*%""",
    re.IGNORECASE
)

# Matches the third line of each group: "success: <int> all: <int>".
# Case-insensitive; the two integers are captured as the named groups
# `succ` and `all`.
RE_SUCC_ALL = re.compile(
    r"""success\s*:\s*(?P<succ>\d+)\s+all\s*:\s*(?P<all>\d+)""",
    re.IGNORECASE
)

def model_to_float(model: str) -> float:
    """Convert a size label such as '14B' or '1.5B' to 14.0 / 1.5 for sorting.

    Trailing 'B' characters are removed first, then surrounding whitespace is
    stripped before the remainder is parsed as a float.

    Raises:
        ValueError: if what remains is not a valid float literal.
    """
    without_suffix = model.rstrip("B")
    numeric_text = without_suffix.strip()
    return float(numeric_text)

def parse_group(lines3: Tuple[str, str, str]) -> Optional[dict]:
    """Parse one 3-line group into a flat record dict.

    Args:
        lines3: the path line, the "jailbreak success rate" line and the
            "success / all" line, in that order. Each is stripped before
            matching.

    Returns:
        A dict with the keys ``model_size``, ``replace_count``,
        ``jailbreak_success_rate_%``, ``success``, ``all``, ``source_path``,
        ``timestamp`` and ``calc_rate_%``; or ``None`` when the path line does
        not match ``RE_FIRST``, or when neither a rate nor the success/all
        counts could be extracted.
    """
    path_line, rate_line, counts_line = [text.strip() for text in lines3]

    path_match = RE_FIRST.search(path_line)
    if path_match is None:
        return None

    model_label = path_match.group("model")
    replace_count = int(path_match.group("replace"))
    source_path = path_match.group("path")
    timestamp = path_match.group("stamp")

    # Rate as printed on the second line, if that line matched.
    rate_match = RE_RATE.search(rate_line)
    reported_rate: Optional[float] = (
        float(rate_match.group("rate")) if rate_match is not None else None
    )

    # Raw counts from the third line, if that line matched.
    counts_match = RE_SUCC_ALL.search(counts_line)
    if counts_match is None:
        n_success = n_total = None
    else:
        n_success = int(counts_match.group("succ"))
        n_total = int(counts_match.group("all"))

    # Rate recomputed from the counts; unavailable when the counts are
    # missing or the total is zero (avoids dividing by zero).
    derived_rate: Optional[float] = None
    if n_total:
        derived_rate = round(100.0 * n_success / n_total, 3)

    # The reported rate wins; the derived one is only a fallback.
    final_rate = derived_rate if reported_rate is None else reported_rate

    # Nothing usable on either the rate line or the counts line.
    if final_rate is None and counts_match is None:
        return None

    record = {
        "model_size": model_label,
        "replace_count": replace_count,
        "jailbreak_success_rate_%": final_rate,
        "success": n_success,
        "all": n_total,
        "source_path": source_path,
        "timestamp": timestamp,
        "calc_rate_%": derived_rate,
    }
    return record

def read_lines_from_stdin_or_file(infile: Optional[str]) -> List[str]:
    """Return the input split into lines, without line terminators.

    Args:
        infile: path of a UTF-8 text file to read. When ``None`` or empty,
            the whole of standard input is read instead.

    Returns:
        The lines of the input, as produced by ``str.splitlines()``.

    Raises:
        SystemExit: if ``infile`` does not exist, or exists but is not a
            regular file (for example a directory).
    """
    if not infile:
        return sys.stdin.read().splitlines()

    p = Path(infile)
    if not p.exists():
        sys.exit(f"输入文件不存在: {infile}")
    # A directory passes exists() but would make read_text() raise an
    # uncaught OSError; exit with a readable message instead.
    if not p.is_file():
        sys.exit(f"输入路径不是文件: {infile}")
    # Undecodable bytes are dropped rather than aborting the whole parse.
    return p.read_text(encoding="utf-8", errors="ignore").splitlines()

def chunk_triplets(lines: List[str]) -> List[Tuple[str, str, str]]:
    """Group consecutive lines into 3-tuples.

    Blank (whitespace-only) lines are skipped only while no group is in
    progress; once a group has started, every line, blank or not, counts
    towards it. A trailing group with fewer than three lines is discarded.
    """
    groups: List[Tuple[str, str, str]] = []
    pending: List[str] = []
    for raw in lines:
        between_groups = not pending
        if between_groups and not raw.strip():
            continue
        pending.append(raw)
        if len(pending) < 3:
            continue
        first, second, third = pending
        groups.append((first, second, third))
        pending = []
    return groups

def _fmt_optional(value: object, width: int = 8, spec: str = "") -> str:
    """Right-align *value* in *width* columns; render ``None`` as blanks.

    Unlike ``str(value or '')``, legitimate zero values (0, 0.0) are shown.
    """
    if value is None:
        return " " * width
    return f"{format(value, spec):>{width}}"


def main():
    """Command-line entry point.

    Reads the input (a file or stdin), parses it in groups of three lines,
    sorts the records by numeric model size and then by replace count, writes
    them to a CSV file, prints a per-model table, and finally reports which
    replace values in the expected inclusive range are missing for each model.

    Raises:
        SystemExit: on an invalid replace range or when no record could be
            parsed from the input.
    """
    parser = argparse.ArgumentParser(
        description="解析每3行一组的结果，并按 model_size 分组、按 replace_count 升序展示与导出，同时检查缺失的 replace。"
    )
    parser.add_argument("-i", "--input", type=str, default=None, help="输入文件（不填则从标准输入读取）")
    parser.add_argument("-o", "--output", type=str, default="parsed_results.csv", help="输出 CSV 路径")
    parser.add_argument("--no-header", action="store_true", help="CSV 不写入表头")
    parser.add_argument("--replace-min", type=int, default=0, help="期望 replace 最小值（含），默认 0")
    parser.add_argument("--replace-max", type=int, default=25, help="期望 replace 最大值（含），默认 25")
    args = parser.parse_args()

    if args.replace_max < args.replace_min:
        sys.exit("--replace-max 不可小于 --replace-min")

    lines = read_lines_from_stdin_or_file(args.input)
    # Groups that fail to parse yield None and are dropped.
    rows = [row for row in map(parse_group, chunk_triplets(lines)) if row]

    if not rows:
        sys.exit("未解析到任何有效数据，请检查输入格式。")

    # Sort by numeric model size first, then by replace_count ascending.
    rows_sorted = sorted(rows, key=lambda r: (model_to_float(r["model_size"]), r["replace_count"]))

    # Export a flat (but already sorted) CSV; None values become empty cells.
    fieldnames = [
        "model_size",
        "replace_count",
        "jailbreak_success_rate_%",
        "success",
        "all",
        "calc_rate_%",
        "source_path",
        "timestamp",
    ]
    with open(args.output, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if not args.no_header:
            writer.writeheader()
        writer.writerows(rows_sorted)

    # Terminal view: one section per model size; rows keep the sorted order.
    grouped: Dict[str, List[dict]] = defaultdict(list)
    for r in rows_sorted:
        grouped[r["model_size"]].append(r)
    models_in_order = sorted(grouped, key=model_to_float)

    print("=" * 80)
    print("分组展示（相同 model size 聚集，按 replace_count 升序）")
    print("=" * 80)

    for model in models_in_order:
        print(f"\n[Model: {model}]")
        print(f"{'replace':>8}  {'rate%':>8}  {'success':>8}  {'all':>8}  {'calc%':>8}  timestamp")
        for r in grouped[model]:
            rep = r["replace_count"]
            # The rate may be None (counts present but total is 0 and no rate
            # line); formatting it with ':8.3f' directly would raise TypeError.
            rate = _fmt_optional(r["jailbreak_success_rate_%"], spec=".3f")
            succ = _fmt_optional(r.get("success"))
            all_ = _fmt_optional(r.get("all"))
            calc = _fmt_optional(r.get("calc_rate_%"))
            ts = r.get("timestamp")
            print(f"{rep:8d}  {rate}  {succ}  {all_}  {calc}  {ts}")

    # Report replace values missing from the expected inclusive range.
    print("\n" + "=" * 80)
    print(f"缺失 replace 检查（默认期望范围：[{args.replace_min}, {args.replace_max}]，可通过参数调整）")
    print("=" * 80)

    expected: Set[int] = set(range(args.replace_min, args.replace_max + 1))
    any_missing = False
    for model in models_in_order:
        present = {r["replace_count"] for r in grouped[model]}
        missing_sorted = sorted(expected - present)
        if missing_sorted:
            any_missing = True
            # One line per model; several missing values are comma-separated.
            missing_str = ", ".join(str(x) for x in missing_sorted)
            print(f"{model} 的大小的模型缺乏 {missing_str} 这个 replace")
    if not any_missing:
        print("未发现缺失的 replace。")

    print(f"\n解析完成：{len(rows_sorted)} 条记录 -> {args.output}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
