#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import math
import re
from collections import Counter
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple


# Directory that contains this script; all default paths below hang off it.
SCRIPT_DIR = Path(__file__).resolve().parent

# Results JSON that receives the per-step behavior statistics.
DEFAULT_INPUT_JSON = SCRIPT_DIR.joinpath(
    "results", "baseline-length-analysis_results.json"
)
# Rollout-level behavior annotations (JSON Lines); preferred aggregation source.
DEFAULT_DETAILS_FILE = SCRIPT_DIR.joinpath(
    "results", "reasoning_behaviors_valid", "behavior_details.jsonl"
)
# Pre-aggregated summary; only consulted when the details file is absent.
DEFAULT_SUMMARY_JSON_FILE = SCRIPT_DIR.joinpath(
    "results", "reasoning_behaviors_valid", "summary.json"
)
# Fixed mapping from category name to its description and member tags.
DEFAULT_CATEGORY_MAP_FILE = SCRIPT_DIR.joinpath("results", "cluster_behaviors.json")
# Destination of the per-step summary aggregated by category.
DEFAULT_CATEGORIZED_STEP_JSON = SCRIPT_DIR.joinpath(
    "results", "reasoning_behaviors_valid", "categorized_step_summary.json"
)

# Canonical experiment name -> spec. The "aliases" list holds the alternative
# names accepted when matching input JSON sections and run names.
DEFAULT_EXPERIMENTS = {
    "baseline": {"aliases": ["baseline"]},
    "gspo_length": {"aliases": ["gspo_length"]},
}


@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of one experiment: a canonical name plus aliases."""

    # Primary experiment name; also used as the key in the config mapping.
    canonical_name: str
    # Alternative names that are matched in addition to ``canonical_name``.
    aliases: Tuple[str, ...]


@dataclass(frozen=True)
class BehaviorCategoryAssignment:
    """Immutable record linking one raw behavior label to a category.

    ``source`` names where the assignment came from and ``rationale`` carries
    a free-text explanation of the mapping.
    """

    raw_behavior: str
    category_name: str
    source: str
    rationale: str

    def to_dict(self) -> Dict[str, str]:
        """Return the four fields as a plain dict, in declaration order."""
        field_names = ("raw_behavior", "category_name", "source", "rationale")
        return {field_name: getattr(self, field_name) for field_name in field_names}


@dataclass(frozen=True)
class CategoryDefinition:
    """Immutable taxonomy entry: a category name and its description."""

    category_name: str
    description: str

    def to_dict(self) -> Dict[str, str]:
        """Return both fields as a plain dict (name first, then description)."""
        return dict(
            category_name=self.category_name,
            description=self.description,
        )


def parse_args() -> argparse.Namespace:
    """Declare the command-line options and parse the process arguments.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``. The
        ``--cluster-map-file`` and ``--clustered-step-json`` spellings are
        accepted as synonyms and stored under ``category_map_file`` and
        ``categorized_step_json`` respectively.
    """
    cli = argparse.ArgumentParser(
        description=(
            "按固定 reasoning 类别映射统计每个 step 的行为种类数与数量，"
            "并回填到 baseline-length-analysis_results.json。"
        )
    )
    add = cli.add_argument

    # Result file to enrich, and where to write it.
    add(
        "--input-json",
        type=Path,
        default=DEFAULT_INPUT_JSON,
        help="待补全的 baseline-length-analysis_results.json 路径。",
    )
    add(
        "--output-json",
        type=Path,
        default=None,
        help="输出 JSON 路径。默认原地覆盖输入文件。",
    )

    # Sources of behavior statistics.
    add(
        "--details-file",
        type=Path,
        default=DEFAULT_DETAILS_FILE,
        help=(
            "rollout 级别 reasoning behavior 标注 jsonl。"
            "若存在，将优先用它精确聚合。"
        ),
    )
    add(
        "--summary-json-file",
        type=Path,
        default=DEFAULT_SUMMARY_JSON_FILE,
        help=(
            "summary.json 路径。仅在 details 文件不存在时使用，"
            "此时 behavior count 可能因为缺少 rollout 内去重而略高。"
        ),
    )

    # Category mapping input and categorized summary output.
    add(
        "--category-map-file",
        "--cluster-map-file",
        dest="category_map_file",
        type=Path,
        default=DEFAULT_CATEGORY_MAP_FILE,
        help="固定 reasoning 类别映射 JSON 路径。",
    )
    add(
        "--categorized-step-json",
        "--clustered-step-json",
        dest="categorized_step_json",
        type=Path,
        default=DEFAULT_CATEGORIZED_STEP_JSON,
        help="按固定类别聚合后的 step 汇总 JSON 路径。",
    )

    # Names of the fields written back into each step entry.
    add(
        "--type-field",
        type=str,
        default="reasoning_behavior_type_count",
        help="写回结果 JSON 的“reasoning behavior 种类数”字段名。",
    )
    add(
        "--count-field",
        type=str,
        default="reasoning_behavior_count",
        help="写回结果 JSON 的“reasoning behavior count”字段名。",
    )

    # Boolean switches, all off by default.
    switches = (
        ("--strict-category-map", "若发现 raw behavior 未在类别映射中定义，则直接报错。"),
        ("--backup", "原地覆盖前先写一个 .bak 备份。"),
        ("--dry-run", "只打印匹配情况，不写结果文件。"),
        ("--verbose", "打印更多类别映射与回填细节。"),
    )
    for flag, help_text in switches:
        add(flag, action="store_true", help=help_text)

    return cli.parse_args()


def normalize_name(name: str) -> str:
    """Return *name* reduced to its alphanumeric characters, lowercased.

    Lowercasing is applied one character at a time, after every
    non-alphanumeric character (spaces, ``_``, ``-``, ...) has been dropped.
    """
    alphanumeric_chars = filter(str.isalnum, name)
    return "".join(map(str.lower, alphanumeric_chars))


def normalize_text(value: Any) -> str:
    """Convert *value* to a stripped string with ``\\r\\n`` turned into ``\\n``.

    ``None`` becomes the empty string; any other value goes through ``str()``.
    """
    if value is not None:
        unified_newlines = str(value).replace("\r\n", "\n")
        return unified_newlines.strip()
    return ""


def normalize_category_name(name: Any) -> str:
    """Normalize *name* as text, then collapse each whitespace run to one space."""
    cleaned = normalize_text(name)
    return re.sub(r"\s+", " ", cleaned)


def to_builtin(value: object) -> object:
    """Reduce *value* to something JSON-friendly where possible.

    * ``None``, ``bool``, ``int`` and ``str`` are returned unchanged.
    * ``float`` is returned unchanged unless it is NaN or infinite, in which
      case ``None`` is returned.
    * Objects exposing ``.item()`` are unwrapped and the result is converted
      recursively; if that fails, the original object is returned.
    * Anything else is returned unchanged.
    """
    if value is None or isinstance(value, (bool, int, str)):
        return value

    if isinstance(value, float):
        return value if math.isfinite(value) else None

    if not hasattr(value, "item"):
        return value

    try:
        unwrapped = value.item()
        return to_builtin(unwrapped)
    except Exception:
        # Best effort: keep the original object when unwrapping fails.
        return value


def build_experiment_configs() -> Dict[str, ExperimentConfig]:
    """Turn ``DEFAULT_EXPERIMENTS`` into ``ExperimentConfig`` objects.

    Returns:
        A dict keyed by canonical experiment name, in the same order as
        ``DEFAULT_EXPERIMENTS``; each alias list is frozen into a tuple.
    """
    return {
        experiment_name: ExperimentConfig(
            canonical_name=experiment_name,
            aliases=tuple(settings["aliases"]),
        )
        for experiment_name, settings in DEFAULT_EXPERIMENTS.items()
    }


def resolve_json_section(
    json_key: str,
    experiment_configs: Mapping[str, ExperimentConfig],
) -> Optional[ExperimentConfig]:
    """Find the experiment whose name or alias matches *json_key*.

    Both sides are compared after ``normalize_name`` (alphanumerics only,
    lowercased). The first matching config in mapping order wins.

    Returns:
        The matching ``ExperimentConfig``, or ``None`` when nothing matches.
    """
    wanted = normalize_name(json_key)
    for candidate_config in experiment_configs.values():
        labels = (candidate_config.canonical_name,) + tuple(candidate_config.aliases)
        normalized_labels = [normalize_name(label) for label in labels]
        if wanted in normalized_labels:
            return candidate_config
    return None


def iter_candidate_run_names(config: ExperimentConfig) -> List[str]:
    """List the run names under which *config*'s statistics may be stored.

    The result starts with the canonical name and aliases (text-normalized,
    de-duplicated, original order). After those come the same names with
    ``_valid``, ``-valid`` and `` valid`` appended; a suffix is skipped for a
    name that already ends with it, and duplicates are dropped.
    """
    ordered: List[str] = []
    known = set()

    def remember(raw_name: str) -> None:
        # Append the normalized name unless it is empty or already listed.
        cleaned = normalize_text(raw_name)
        if cleaned and cleaned not in known:
            known.add(cleaned)
            ordered.append(cleaned)

    for base_name in (config.canonical_name, *config.aliases):
        remember(base_name)

    # Snapshot: suffixed variants are derived from the base names only.
    base_names = tuple(ordered)
    for base_name in base_names:
        for suffix in ("_valid", "-valid", " valid"):
            if not base_name.endswith(suffix):
                remember(f"{base_name}{suffix}")

    return ordered


def load_json(path: Path) -> MutableMapping[str, object]:
    """Read a UTF-8 JSON file whose top-level value must be an object.

    Raises:
        ValueError: if the parsed top-level value is not a ``dict``.
    """
    parsed = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(
        f"输入 JSON 顶层必须是 object/dict，实际是 {type(parsed).__name__}"
    )


def save_json(path: Path, data: Mapping[str, object]) -> None:
    """Write *data* to *path* as indented UTF-8 JSON with a trailing newline.

    Missing parent directories are created. Non-ASCII characters are written
    as-is (``ensure_ascii=False``).

    Raises:
        TypeError / ValueError: propagated from ``json.dumps`` when *data*
            cannot be serialized. In that case *path* is left untouched.
    """
    # Serialize fully in memory BEFORE opening the file. Opening with "w"
    # truncates immediately, so streaming with json.dump would destroy an
    # existing file if serialization failed half-way. This script overwrites
    # its own input by default, which makes that failure mode costly.
    text = json.dumps(data, indent=2, ensure_ascii=False) + "\n"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")


def canonicalize_behavior_name(name: Any) -> str:
    """Return the canonical spelling of a behavior label.

    The label is text-normalized first; an empty result yields ``""``. A
    handful of core behaviors are recognized regardless of case and of
    whether words are separated by whitespace, ``_`` or ``-``, and are mapped
    to fixed spellings ("subgoal" is treated as "Subgoal Setting"). Any other
    label is returned with whitespace runs collapsed to single spaces.
    """
    cleaned = normalize_text(name)
    if not cleaned:
        return ""

    # Separator- and case-insensitive key used only for the core lookup.
    lookup_key = re.sub(r"[\s_\-]+", " ", cleaned).strip().lower()
    core_spelling = {
        "backtracking": "Backtracking",
        "verification": "Verification",
        "subgoal setting": "Subgoal Setting",
        "subgoal": "Subgoal Setting",
        "enumeration": "Enumeration",
    }.get(lookup_key)
    if core_spelling is not None:
        return core_spelling

    return re.sub(r"\s+", " ", cleaned)


def extract_behavior_names(record: Mapping[str, Any]) -> List[str]:
    """Return the distinct canonical behavior names found in one record.

    Reads ``record["behaviors"]``; when it is a list, each dict item
    contributes the canonicalized value of its ``"behaviour"`` key (British
    spelling). Non-dict items and empty names are ignored. Order of first
    appearance is kept. A missing or non-list value yields an empty list.
    """
    raw_items = record.get("behaviors", [])
    if not isinstance(raw_items, list):
        return []

    # Dict keys double as an insertion-ordered set of names.
    ordered_names: Dict[str, None] = {}
    for entry in raw_items:
        if isinstance(entry, dict):
            label = canonicalize_behavior_name(entry.get("behaviour", ""))
            if label:
                ordered_names.setdefault(label, None)
    return list(ordered_names)


def behavior_to_column_name(name: str) -> str:
    """Turn a behavior name into a lowercase, underscore-separated slug.

    Every run of characters outside ``[a-zA-Z0-9]`` becomes one ``_``;
    leading and trailing underscores are removed. If nothing is left,
    ``"unknown_behavior"`` is returned.
    """
    lowered = name.strip().lower()
    slug = re.sub(r"[^a-zA-Z0-9]+", "_", lowered).strip("_")
    if slug:
        return slug
    return "unknown_behavior"


def load_behavior_records(details_file: Path) -> List[Dict[str, Any]]:
    """Load rollout-level behavior records from a JSON Lines file.

    Args:
        details_file: Path of the ``.jsonl`` file, read as UTF-8.

    Returns:
        The parsed JSON objects in file order. Blank lines and lines whose
        JSON value is not an object are skipped. A missing file yields ``[]``.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The
            message names the file and the 1-based line number.
    """
    if not details_file.exists():
        return []

    records: List[Dict[str, Any]] = []
    with details_file.open("r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError as exc:
                # Same exception type, but say WHERE the bad line is; the
                # bare decoder error only gives a column within the line.
                raise json.JSONDecodeError(
                    f"{details_file} 第 {line_number} 行不是合法 JSON: {exc.msg}",
                    exc.doc,
                    exc.pos,
                ) from exc
            if isinstance(record, dict):
                records.append(record)
    return records


def collect_raw_behavior_names_from_records(
    records: Sequence[Mapping[str, Any]],
) -> List[str]:
    """Return the sorted set of canonical behavior names across *records*."""
    unique_names = set()
    for record in records:
        unique_names.update(
            label for label in extract_behavior_names(record) if label
        )
    return sorted(unique_names)


def collect_raw_behavior_names_from_summary(
    summary_data: Mapping[str, Any],
) -> List[str]:
    """Return the sorted canonical names listed in ``summary_data["behavior_names"]``.

    A missing or non-list value yields an empty list. Entries that
    canonicalize to the empty string are dropped, as are duplicates.
    """
    listed = summary_data.get("behavior_names", [])
    if not isinstance(listed, list):
        return []

    canonical_names = {canonicalize_behavior_name(entry) for entry in listed}
    canonical_names.discard("")
    return sorted(canonical_names)


def prune_taxonomy(
    taxonomy: Sequence[CategoryDefinition],
    assignments: Mapping[str, BehaviorCategoryAssignment],
) -> List[CategoryDefinition]:
    """Restrict *taxonomy* to the categories that *assignments* actually use.

    Entries keep their taxonomy order and only the first definition per
    category name survives. Categories that are used but not defined in
    *taxonomy* are appended afterwards, sorted by name, with an empty
    description.
    """
    in_use = {
        assignment.category_name
        for assignment in assignments.values()
        if assignment.category_name
    }

    kept: List[CategoryDefinition] = []
    emitted = set()
    for definition in taxonomy:
        label = definition.category_name
        if label not in in_use or label in emitted:
            continue
        emitted.add(label)
        kept.append(definition)

    kept.extend(
        CategoryDefinition(category_name=label, description="")
        for label in sorted(in_use - emitted)
    )
    return kept


def load_fixed_category_map(
    path: Path,
) -> Tuple[Dict[str, BehaviorCategoryAssignment], List[CategoryDefinition]]:
    """Load the fixed ``category -> tags`` mapping and invert it.

    The file must contain a JSON object whose values are objects with an
    optional ``"description"`` and a list under ``"tags"`` (missing ``tags``
    is treated as an empty list). Each category's own name is mapped to that
    category as well. When a behavior appears under several categories, the
    first category in file order wins and a warning is printed.

    Args:
        path: Location of the category map JSON file.

    Returns:
        ``(assignments, taxonomy)``. ``assignments`` maps each canonical raw
        behavior name to its assignment; ``taxonomy`` holds the category
        definitions that own at least one assignment.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the file is empty, a category value is not an object,
            or a category's ``tags`` is not a list.
    """
    if not path.exists():
        raise FileNotFoundError(f"类别映射文件不存在: {path}")

    payload = load_json(path)
    if not payload:
        raise ValueError("类别映射文件为空。")

    assignments: Dict[str, BehaviorCategoryAssignment] = {}
    taxonomy: List[CategoryDefinition] = []
    duplicate_conflicts: List[Tuple[str, str, str]] = []

    for category_key, spec in payload.items():
        if not isinstance(spec, dict):
            raise ValueError(
                "当前 cluster_behaviors 数据格式要求顶层 value 为 object，"
                f"但 category={category_key!r} 的 value 是 {type(spec).__name__}。"
            )

        category_name = normalize_category_name(category_key)
        description = normalize_text(spec.get("description", ""))
        members = spec.get("tags", [])

        # Categories with a blank name are skipped before tags are validated.
        if not category_name:
            continue
        if not isinstance(members, list):
            raise ValueError(
                "当前 cluster_behaviors 数据格式要求每个类别包含 list 类型的 tags，"
                f"但 category={category_key!r} 的 tags 是 {type(members).__name__}。"
            )

        taxonomy.append(
            CategoryDefinition(
                category_name=category_name,
                description=description,
            )
        )

        # The category name itself comes first, then its tags. Each tag is
        # canonicalized once; tags that canonicalize to "" are dropped.
        normalized_members: List[str] = [canonicalize_behavior_name(category_name)]
        for member in members:
            canonical_member = canonicalize_behavior_name(member)
            if canonical_member:
                normalized_members.append(canonical_member)

        for raw_behavior in normalized_members:
            existing = assignments.get(raw_behavior)
            if existing and existing.category_name != category_name:
                # First category in file order keeps the tag.
                duplicate_conflicts.append(
                    (raw_behavior, existing.category_name, category_name)
                )
                continue
            assignments[raw_behavior] = BehaviorCategoryAssignment(
                raw_behavior=raw_behavior,
                category_name=category_name,
                source="fixed_category_map",
                rationale=description or f"Mapped by fixed category '{category_name}'.",
            )

    if duplicate_conflicts:
        # Name the file that was actually loaded; the path is configurable.
        print(
            f"[warn] {path.name} 中存在重复 tag 归类，"
            "已按文件中的首个类别保留，后续重复项忽略。"
        )
        for raw_behavior, category_a, category_b in sorted(set(duplicate_conflicts)):
            print(
                f"       tag={raw_behavior!r} first={category_a!r} ignored={category_b!r}"
            )

    return assignments, prune_taxonomy(taxonomy, assignments)


def build_behavior_category_map(
    raw_behavior_names: Sequence[str],
    args: argparse.Namespace,
) -> Tuple[Dict[str, BehaviorCategoryAssignment], List[CategoryDefinition]]:
    """Assign a category to every requested raw behavior.

    Names are canonicalized and de-duplicated first; with nothing left, the
    category map file is not read and ``({}, [])`` is returned. Behaviors
    missing from the fixed map become their own category
    (``source="fallback_identity"``) unless ``args.strict_category_map`` is set.

    Args:
        raw_behavior_names: Behavior labels observed in the data.
        args: Parsed CLI options; ``category_map_file`` and
            ``strict_category_map`` are read.

    Returns:
        ``(assignments, taxonomy)`` limited to the requested behaviors.

    Raises:
        ValueError: in strict mode, when some behavior is not in the map.
    """
    wanted = sorted(
        canonical
        for canonical in {canonicalize_behavior_name(name) for name in raw_behavior_names}
        if canonical
    )
    if not wanted:
        return {}, []

    known, taxonomy = load_fixed_category_map(args.category_map_file.expanduser())

    unmapped = [name for name in wanted if name not in known]
    if unmapped and args.strict_category_map:
        raise ValueError(
            "以下 raw behaviors 未在类别映射中定义:\n"
            + "\n".join(f"- {name}" for name in unmapped)
        )

    selected: Dict[str, BehaviorCategoryAssignment] = {}
    for name in wanted:
        found = known.get(name)
        if found is not None:
            selected[name] = found
            continue
        # Unknown behavior: fall back to treating it as its own category.
        selected[name] = BehaviorCategoryAssignment(
            raw_behavior=name,
            category_name=name,
            source="fallback_identity",
            rationale="Not found in fixed category map; use raw behavior as its own category.",
        )

    return selected, prune_taxonomy(taxonomy, selected)


def aggregate_from_records(
    records: Sequence[Mapping[str, Any]],
    category_map: Mapping[str, BehaviorCategoryAssignment],
) -> Dict[Tuple[str, int], Dict[str, Any]]:
    """Aggregate rollout-level records into per-(run, step) category statistics.

    Records without a run name or without an ``int()``-convertible step are
    skipped. Within one record each category counts at most once, however
    many of its raw behaviors occur. Behaviors absent from *category_map*
    are counted under their own name.

    Returns:
        A dict keyed by ``(run, step)``. Each row holds ``run``, ``step``,
        ``rollout_count``, ``reasoning_behavior_count`` (sum over rollouts of
        distinct categories), ``reasoning_behavior_type_count`` (distinct
        categories in the step), ``category_counts`` (name-sorted plain dict)
        and ``category_names`` (sorted list).
    """
    per_step: Dict[Tuple[str, int], Dict[str, Any]] = {}

    for record in records:
        run_name = normalize_text(record.get("run", ""))
        if not run_name:
            continue
        try:
            step_index = int(record.get("step"))
        except (TypeError, ValueError):
            continue

        key = (run_name, step_index)
        bucket = per_step.get(key)
        if bucket is None:
            # Key order here is the order the row is later serialized in.
            bucket = per_step[key] = {
                "run": run_name,
                "step": step_index,
                "rollout_count": 0,
                "reasoning_behavior_count": 0,
                "reasoning_behavior_type_count": 0,
                "category_counts": Counter(),
                "category_names": [],
            }

        # Distinct categories seen in this single rollout.
        rollout_categories = set()
        for raw_label in extract_behavior_names(record):
            assigned = category_map.get(raw_label)
            label = assigned.category_name if assigned else raw_label
            if label:
                rollout_categories.add(label)

        bucket["rollout_count"] += 1
        bucket["reasoning_behavior_count"] += len(rollout_categories)
        bucket["category_counts"].update(rollout_categories)

    # Finalize: swap each Counter for a name-sorted plain dict of ints.
    for bucket in per_step.values():
        tallies = bucket["category_counts"]
        bucket["category_names"] = sorted(tallies)
        bucket["reasoning_behavior_type_count"] = len(tallies)
        bucket["category_counts"] = {
            label: int(tallies[label]) for label in sorted(tallies)
        }
    return per_step


def _coerce_int(value: Any, default: Optional[int] = None) -> Optional[int]:
    """Return ``int(value)``, or *default* when the conversion is impossible."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # int() raises ValueError for NaN and non-numeric strings,
        # OverflowError for +/-inf and TypeError for None / containers.
        return default


def aggregate_from_summary_json(
    summary_data: Mapping[str, Any],
    category_map: Mapping[str, BehaviorCategoryAssignment],
) -> Dict[Tuple[str, int], Dict[str, Any]]:
    """Approximate per-(run, step) category statistics from ``summary.json``.

    Each row of ``summary_data["step_summary"]`` is expected to carry one
    ``<behavior_slug>_count`` column per raw behavior. Those counts are
    re-bucketed by category. Rows that are not dicts, lack a run name, or
    have a step that cannot be converted to ``int`` are skipped. A later row
    with the same ``(run, step)`` replaces an earlier one.

    Args:
        summary_data: Parsed ``summary.json`` content.
        category_map: Canonical raw behavior name -> category assignment;
            behaviors missing from it are counted under their own name.

    Returns:
        A dict keyed by ``(run, step)``; every row is flagged with
        ``"approximate_from_summary": True``.

    Raises:
        ValueError: if ``step_summary`` is present but not a list.
    """
    behavior_names = collect_raw_behavior_names_from_summary(summary_data)
    step_rows = summary_data.get("step_summary", [])
    if not isinstance(step_rows, list):
        raise ValueError("summary.json 中 step_summary 不是 list。")

    # Resolve column names once, not once per row. Different raw labels can
    # collapse to the same column slug; keep only the first (names are
    # sorted) so one column is never added to the totals more than once.
    column_to_behavior: Dict[str, str] = {}
    for raw_behavior in behavior_names:
        column_name = f"{behavior_to_column_name(raw_behavior)}_count"
        column_to_behavior.setdefault(column_name, raw_behavior)

    step_stats: Dict[Tuple[str, int], Dict[str, Any]] = {}
    for row in step_rows:
        if not isinstance(row, dict):
            continue
        run = normalize_text(row.get("run", ""))
        if not run:
            continue
        step = _coerce_int(row.get("step"))
        if step is None:
            continue

        category_counts: Counter[str] = Counter()
        for column_name, raw_behavior in column_to_behavior.items():
            count = _coerce_int(row.get(column_name, 0))
            if count is None or count <= 0:
                continue
            assignment = category_map.get(raw_behavior)
            category_name = assignment.category_name if assignment else raw_behavior
            if category_name:
                category_counts[category_name] += count

        category_names = sorted(category_counts)
        step_stats[(run, step)] = {
            "run": run,
            "step": step,
            # A missing, empty, NaN or otherwise unusable value counts as 0
            # instead of aborting the whole aggregation.
            "rollout_count": _coerce_int(row.get("rollout_count"), default=0),
            "reasoning_behavior_count": int(sum(category_counts.values())),
            "reasoning_behavior_type_count": len(category_counts),
            "category_names": category_names,
            "category_counts": {
                name: int(category_counts[name]) for name in category_names
            },
            "approximate_from_summary": True,
        }
    return step_stats


def write_categorized_step_json(
    path: Path,
    step_stats: Mapping[Tuple[str, int], Mapping[str, Any]],
    category_map: Mapping[str, BehaviorCategoryAssignment],
    category_taxonomy: Sequence[CategoryDefinition],
    source_kind: str,
    category_map_file: Path,
) -> None:
    """Write the category-level step summary to *path* via ``save_json``.

    The payload contains a local timestamp, the data source kind, the
    taxonomy pruned to the categories in use, the category map file path,
    all assignments sorted by raw behavior name, and the step rows sorted
    by ``(run, step)``.
    """
    generated_at = datetime.now().isoformat()

    taxonomy_rows = [
        definition.to_dict()
        for definition in prune_taxonomy(category_taxonomy, category_map)
    ]
    assignment_rows = [
        category_map[raw_name].to_dict() for raw_name in sorted(category_map)
    ]
    # Keys are (run, step) tuples, so plain tuple ordering sorts by both.
    step_rows = [dict(step_stats[step_key]) for step_key in sorted(step_stats)]

    save_json(
        path,
        {
            "generated_at": generated_at,
            "source_kind": source_kind,
            "taxonomy": taxonomy_rows,
            "category_map_file": str(category_map_file),
            "category_assignments": assignment_rows,
            "step_summary": step_rows,
        },
    )


def update_step_entry(
    entry: MutableMapping[str, object],
    step_stats: Mapping[Tuple[str, int], Mapping[str, Any]],
    run_config: ExperimentConfig,
    type_field: str,
    count_field: str,
) -> bool:
    """Copy the behavior statistics for one step into *entry*.

    The step is read from ``entry["step"]``. Statistics are looked up under
    each candidate run name of *run_config*, in order; the first hit is used.

    Returns:
        ``True`` when *entry* was updated in place; ``False`` when the step is
        not ``int()``-convertible or no statistics exist for it.
    """
    try:
        step_index = int(entry.get("step"))
    except (TypeError, ValueError):
        return False

    lookups = (
        step_stats.get((run_name, step_index))
        for run_name in iter_candidate_run_names(run_config)
    )
    matched = next((stats for stats in lookups if stats is not None), None)
    if matched is None:
        return False

    entry[type_field] = to_builtin(matched.get("reasoning_behavior_type_count"))
    entry[count_field] = to_builtin(matched.get("reasoning_behavior_count"))
    return True


def build_step_stats(
    args: argparse.Namespace,
) -> Tuple[
    Dict[Tuple[str, int], Dict[str, Any]],
    Dict[str, BehaviorCategoryAssignment],
    List[CategoryDefinition],
    str,
    List[str],
]:
    """Build per-step statistics from the best available source.

    The rollout-level details file is used when it exists; otherwise the
    summary JSON is used as an approximation.

    Returns:
        ``(step_stats, category_map, category_taxonomy, source_kind,
        raw_behavior_names)`` where ``source_kind`` is ``"details_jsonl"`` or
        ``"summary_json_approx"``.

    Raises:
        ValueError: if the details file exists but yields no records.
        FileNotFoundError: if neither source file exists.
    """
    details_path = args.details_file.expanduser()
    summary_path = args.summary_json_file.expanduser()

    if details_path.exists():
        records = load_behavior_records(details_path)
        if not records:
            raise ValueError(f"details 文件为空: {details_path}")
        source_kind = "details_jsonl"
        raw_behavior_names = collect_raw_behavior_names_from_records(records)
        aggregate = lambda mapping: aggregate_from_records(records, mapping)
    elif summary_path.exists():
        summary_data = load_json(summary_path)
        source_kind = "summary_json_approx"
        raw_behavior_names = collect_raw_behavior_names_from_summary(summary_data)
        aggregate = lambda mapping: aggregate_from_summary_json(summary_data, mapping)
    else:
        raise FileNotFoundError("既找不到 details_file，也找不到 summary_json_file。")

    # Shared tail: resolve categories first, then aggregate with them.
    category_map, category_taxonomy = build_behavior_category_map(
        raw_behavior_names=raw_behavior_names,
        args=args,
    )
    return (
        aggregate(category_map),
        category_map,
        category_taxonomy,
        source_kind,
        raw_behavior_names,
    )


def main() -> None:
    """Run the CLI: aggregate behavior statistics and back-fill the results JSON.

    Steps:
      1. Load the input JSON and build the per-step statistics.
      2. Unless ``--dry-run`` is given, write the categorized step summary.
      3. For every top-level section that matches a known experiment and is
         a list, copy the type/count statistics into each step entry.
      4. Unless ``--dry-run`` is given, optionally back up the input and
         write the updated JSON.

    Raises:
        FileNotFoundError: if the input JSON does not exist (other errors
            propagate from the helpers).
    """
    args = parse_args()
    input_json = args.input_json.expanduser()
    output_json = args.output_json.expanduser() if args.output_json else input_json

    if not input_json.exists():
        raise FileNotFoundError(f"输入 JSON 不存在: {input_json}")

    data = load_json(input_json)
    experiment_configs = build_experiment_configs()

    (
        step_stats,
        category_map,
        category_taxonomy,
        source_kind,
        raw_behavior_names,
    ) = build_step_stats(args)
    if not args.dry_run:
        write_categorized_step_json(
            path=args.categorized_step_json.expanduser(),
            step_stats=step_stats,
            category_map=category_map,
            category_taxonomy=category_taxonomy,
            source_kind=source_kind,
            category_map_file=args.category_map_file.expanduser(),
        )

    if source_kind == "summary_json_approx":
        print(
            "[warn] 当前使用 summary.json 做近似聚合，"
            "reasoning_behavior_count 未做 rollout 内去重。"
        )

    # Behaviors that were not in the fixed map and became their own category.
    fallback_names = sorted(
        name
        for name, assignment in category_map.items()
        if assignment.source == "fallback_identity"
    )

    available_runs = sorted({run for run, _ in step_stats.keys()})
    print(f"[info] reasoning behavior source={source_kind}")
    print(f"[info] available runs from reasoning stats: {available_runs}")
    print(f"[info] unique raw behavior labels: {len(raw_behavior_names)}")
    print(
        "[info] unique behavior categories: "
        f"{len({assignment.category_name for assignment in category_map.values()})}"
    )
    if fallback_names:
        print(
            "[warn] 以下 behaviors 未命中固定类别映射，已按原名作为独立类别处理: "
            f"{fallback_names}"
        )

    for json_key, payload in data.items():
        config = resolve_json_section(json_key, experiment_configs)
        if config is None:
            if args.verbose:
                print(f"[skip] section '{json_key}' 不在预设映射中，跳过。")
            continue

        if not isinstance(payload, list):
            print(f"[skip] section '{json_key}' 不是 list，跳过。")
            continue

        updated_rows = 0
        missing_steps: List[int] = []

        for item in payload:
            if not isinstance(item, dict):
                continue
            updated = update_step_entry(
                entry=item,
                step_stats=step_stats,
                run_config=config,
                type_field=args.type_field,
                count_field=args.count_field,
            )
            if updated:
                updated_rows += 1
            else:
                # Remember the step for the report when it is a usable integer.
                step = item.get("step")
                try:
                    missing_steps.append(int(step))
                except (TypeError, ValueError):
                    pass

        print(
            f"[done] section='{json_key}' rows={len(payload)} "
            f"updated={updated_rows} missing_steps={len(missing_steps)}"
        )
        if args.verbose and missing_steps:
            print(f"       missing step(s): {sorted(set(missing_steps))}")

    if args.dry_run:
        print("[dry-run] 未写入 baseline-length-analysis_results.json。")
        return

    # Compare resolved paths: the same file may be spelled differently
    # (relative vs absolute, via a symlink), and the backup must still be
    # taken in that case because the input is about to be overwritten.
    if args.backup and output_json.resolve() == input_json.resolve():
        backup_path = input_json.with_suffix(input_json.suffix + ".bak")
        # Byte-for-byte copy so the backup is the original file, not a
        # re-serialized version of it.
        backup_path.write_bytes(input_json.read_bytes())
        print(f"[backup] 已写入备份: {backup_path}")

    save_json(output_json, data)
    print(f"[write] 已写入: {output_json}")
    print(f"[write] 固定类别映射: {args.category_map_file.expanduser()}")
    print(f"[write] step 类别汇总: {args.categorized_step_json.expanduser()}")


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
