from __future__ import annotations

import argparse
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import Optional
from typing import Sequence
from typing import Tuple

import pandas as pd
import wandb


# Default value of the --input-json option (see parse_args).
DEFAULT_INPUT_JSON = Path(
    "results/baseline-length-analysis_results.json"
)

# W&B metric key -> field name in the results JSON, for non-reward metrics.
# main() uses the keys of this map (plus the experiment's reward metric) as
# the W&B history columns to fetch.
DEFAULT_OUTPUT_FIELD_MAP = {
    "actor/entropy": "entropy",
}

# Per-experiment settings, keyed by canonical experiment name.
#   aliases       - alternative names for a top-level JSON key; compared after
#                   normalize_name() in resolve_json_section().
#   reward_metric - W&B history key that holds the reward value.
#   reward_field  - NOTE: not read by build_experiment_configs(), which uses
#                   the field name passed in by its caller instead.
#   runs          - W&B run paths handed to api.run(); the histories of all
#                   listed runs are combined by merge_run_histories().
DEFAULT_EXPERIMENTS = {
    "baseline": {
        "aliases": ["baseline"],
        "reward_metric": "critic/score/mean",
        "reward_field": "acc_reward",
        "runs": [
            "astrid_tuning_llm/verl-qwen3-4b-oct/1z81wc9e",
            "astrid_tuning_llm/verl-qwen3-4b-oct/u9m3tt9b",
        ],
    },
    "gspo_length": {
        "aliases": ["gspo_length", "gspo length", "length", "ours"],
        "reward_metric": "critic/acc/mean",
        "reward_field": "acc_reward",
        "runs": [
            "astrid_tuning_llm/verl-qwen3-4b-oct/7101lp3l",
            "astrid_tuning_llm/verl-qwen3-4b-oct/clk5pe0g",
            "astrid_tuning_llm/verl-qwen3-4b-oct/nyda3ljt"
        ],
    },
}


@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of one experiment whose metrics are pulled from W&B."""

    # Key of the experiment in DEFAULT_EXPERIMENTS.
    canonical_name: str
    # Alternative section names accepted for this experiment.
    aliases: Tuple[str, ...]
    # W&B history key that holds the reward value.
    reward_metric: str
    # Field name under which the reward is written into each JSON entry.
    reward_field: str
    # W&B run paths whose histories are combined for this experiment.
    runs: Tuple[str, ...]


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Options are registered from small tables, grouped by value type, in the
    order they appear in ``--help``.

    Returns:
        Namespace with ``input_json``, ``output_json``, ``step_column``,
        ``write_reward_field``, ``entropy_field``, ``backup``, ``dry_run``
        and ``verbose``.
    """
    cli = argparse.ArgumentParser(
        description=(
            "从 W&B online 拉取 baseline / gspo_length 的训练指标，"
            "按 step 补充到 baseline-length-analysis_results.json。"
        )
    )

    # (flag, default, help) for options parsed into pathlib.Path.
    path_options = (
        ("--input-json", DEFAULT_INPUT_JSON, "待补全的输入 JSON 路径。"),
        ("--output-json", None, "输出 JSON 路径。默认原地覆盖输入文件。"),
    )
    for flag, default, help_text in path_options:
        cli.add_argument(flag, type=Path, default=default, help=help_text)

    # (flag, default, help) for plain string options.
    string_options = (
        (
            "--step-column",
            "_step",
            "W&B history 中用于对齐 JSON step 的列名。",
        ),
        (
            "--write-reward-field",
            "acc_reward",
            (
                "reward 在结果 JSON 中写入的字段名。"
                "默认统一写成 acc_reward。"
            ),
        ),
        (
            "--entropy-field",
            "entropy",
            "entropy 在结果 JSON 中写入的字段名。",
        ),
    )
    for flag, default, help_text in string_options:
        cli.add_argument(flag, type=str, default=default, help=help_text)

    # (flag, help) for boolean switches that default to False.
    switches = (
        ("--backup", "原地覆盖前先写一个 .bak 备份。"),
        ("--dry-run", "只打印补全情况，不写文件。"),
        ("--verbose", "打印更多拉取和对齐细节。"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli.parse_args()


def normalize_name(name: str) -> str:
    """Return *name* with non-alphanumeric characters removed and the rest lower-cased.

    Used to compare section names while ignoring case, spaces and punctuation.
    """
    alphanumeric_only = filter(str.isalnum, name)
    return "".join(map(str.lower, alphanumeric_only))


def build_experiment_configs(
    reward_field_name: str,
) -> Dict[str, ExperimentConfig]:
    """Turn ``DEFAULT_EXPERIMENTS`` into ``ExperimentConfig`` objects.

    Args:
        reward_field_name: JSON field name used as ``reward_field`` for every
            experiment. The ``"reward_field"`` entries in
            ``DEFAULT_EXPERIMENTS`` are not consulted.

    Returns:
        Mapping from canonical experiment name to its configuration.
    """
    return {
        name: ExperimentConfig(
            canonical_name=name,
            aliases=tuple(settings["aliases"]),
            reward_metric=settings["reward_metric"],
            reward_field=reward_field_name,
            runs=tuple(settings["runs"]),
        )
        for name, settings in DEFAULT_EXPERIMENTS.items()
    }


def resolve_json_section(
    json_key: str,
    experiment_configs: Mapping[str, ExperimentConfig],
) -> Optional[ExperimentConfig]:
    """Find the experiment a top-level JSON key refers to.

    The key is compared, after ``normalize_name``, against each experiment's
    canonical name and aliases.

    Args:
        json_key: Top-level key from the results JSON.
        experiment_configs: Known experiments, searched in mapping order.

    Returns:
        The first matching configuration, or ``None`` when nothing matches.
    """
    wanted = normalize_name(json_key)
    for experiment in experiment_configs.values():
        known_names = [experiment.canonical_name]
        known_names.extend(experiment.aliases)
        if any(normalize_name(known) == wanted for known in known_names):
            return experiment
    return None


def fetch_run_history(
    api: wandb.Api,
    run_path: str,
    requested_columns: Sequence[str],
    verbose: bool = False,
) -> pd.DataFrame:
    """Download the logged history of one W&B run into a DataFrame.

    Rows come from ``run.scan_history(page_size=1000)``. ``None`` rows and
    rows without any keys are dropped.

    Args:
        api: W&B API client used to look up the run.
        run_path: Run path passed to ``api.run``.
        requested_columns: Keys to keep from each row. Keys a row does not
            contain are simply left out of that row. When empty, every key is
            kept.
        verbose: Print the number of fetched rows when true.

    Returns:
        One DataFrame row per retained history row.
    """
    run = api.run(run_path)

    def _select(record: dict) -> dict:
        # Restrict a history row to the requested keys it actually has.
        if not requested_columns:
            return record
        return {name: record[name] for name in requested_columns if name in record}

    records = []
    for raw in run.scan_history(page_size=1000):
        record = dict(raw) if raw is not None else {}
        if record:
            records.append(_select(record))

    history = pd.DataFrame(records)
    if verbose:
        print(
            f"[wandb] fetched {len(history):>4} rows from {run_path} ({run.name})")
    return history


def prepare_run_dataframe(
    df: pd.DataFrame,
    step_column: str,
    metric_columns: Iterable[str],
) -> pd.DataFrame:
    """Clean one run's raw history so it can be aligned by step.

    Keeps only the step column and the requested metric columns that exist,
    coerces them to numbers, drops rows without a usable step and leaves one
    row per step (the last one logged).

    Args:
        df: Raw history as returned by ``fetch_run_history``.
        step_column: Name of the column holding the step number.
        metric_columns: Names of the metric columns to keep. Any iterable is
            accepted, including one-shot generators.

    Returns:
        DataFrame sorted by ``step_column`` with unique integer steps. For an
        empty input, an empty frame with the step and metric columns.

    Raises:
        ValueError: If a non-empty ``df`` has no ``step_column``.
    """
    # Materialize once: the names are needed several times below, and a
    # generator would be exhausted after the first pass.
    metric_names = list(metric_columns)

    if df.empty:
        return pd.DataFrame(columns=[step_column, *metric_names])

    if step_column not in df.columns:
        raise ValueError(f"W&B history 中找不到 step 列: {step_column}")

    # dict.fromkeys de-duplicates while keeping order, so a name listed twice
    # does not produce duplicate columns.
    wanted = dict.fromkeys([step_column, *metric_names])
    work = df[[column for column in wanted if column in df.columns]].copy()

    work[step_column] = pd.to_numeric(work[step_column], errors="coerce")
    work = work.dropna(subset=[step_column]).copy()
    work[step_column] = work[step_column].astype(int)

    for column in metric_names:
        if column in work.columns:
            work[column] = pd.to_numeric(work[column], errors="coerce")

    # A stable sort keeps rows with equal steps in logging order, so
    # keep="last" reliably selects the most recently logged row.
    work = work.sort_values(step_column, kind="mergesort").drop_duplicates(
        subset=[step_column], keep="last")
    return work


def merge_run_histories(
    api: wandb.Api,
    run_paths: Sequence[str],
    step_column: str,
    metric_columns: Sequence[str],
    verbose: bool = False,
) -> pd.DataFrame:
    """Fetch several W&B runs and combine them into one step-indexed frame.

    When the same step appears in more than one run, the row from the run
    listed later in ``run_paths`` is kept.

    Args:
        api: W&B API client.
        run_paths: Run paths to fetch, in increasing order of precedence.
        step_column: History column used as the step.
        metric_columns: Metric columns to fetch and keep.
        verbose: Print per-run step ranges when true.

    Returns:
        DataFrame indexed by ``step_column`` and sorted by it. When no run
        yields any rows, an empty frame with the same index name and the
        metric columns.
    """
    metric_names = list(metric_columns)

    frames: List[pd.DataFrame] = []
    for run_path in run_paths:
        fetched = fetch_run_history(
            api=api,
            run_path=run_path,
            requested_columns=[step_column, *metric_names],
            verbose=verbose,
        )
        prepared = prepare_run_dataframe(
            df=fetched,
            step_column=step_column,
            metric_columns=metric_names,
        )
        # Skip runs without usable rows: concatenating empty frames adds
        # nothing and can change the resulting column dtypes.
        if prepared.empty:
            continue
        if verbose:
            print(
                "[wandb] "
                f"{run_path} step range: {prepared[step_column].min()}..{prepared[step_column].max()}"
            )
        frames.append(prepared)

    if not frames:
        # Same layout as the normal return value: indexed by step.
        return pd.DataFrame(
            columns=[step_column, *metric_names]).set_index(step_column)

    merged = pd.concat(frames, ignore_index=True, sort=False)
    # A stable sort preserves concatenation order among equal steps, so
    # keep="last" picks the row from the run listed last.
    merged = merged.sort_values(step_column, kind="mergesort")
    merged = merged.drop_duplicates(subset=[step_column], keep="last")
    return merged.set_index(step_column).sort_index()


def to_builtin(value: object) -> object:
    """Convert a pandas/NumPy scalar into a JSON-friendly Python value.

    Args:
        value: Any object, typically a cell taken from a DataFrame row.

    Returns:
        ``None`` for missing values (``None``, NaN, ``pd.NA``, ...) and for
        infinite floats. ``int``/``str``/``bool`` values unchanged. Finite
        floats as plain ``float``. For objects exposing ``.item()``, the
        converted result of that call. Anything else (including containers)
        is returned unchanged.
    """
    # pd.isna() returns an array for list-likes, whose truth value is
    # ambiguous, so only ask it about scalars.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    if isinstance(value, (int, str, bool)):
        return value
    if isinstance(value, float):
        # float(...) also normalizes float subclasses such as numpy.float64.
        return float(value) if math.isfinite(value) else None

    item = getattr(value, "item", None)
    if callable(item):
        try:
            scalar = item()
        except Exception:
            # Best effort: e.g. a multi-element array has no single item.
            return value
        return to_builtin(scalar)
    return value


def update_step_entry(
    entry: MutableMapping[str, object],
    metrics_by_step: pd.DataFrame,
    source_to_target_fields: Mapping[str, str],
) -> bool:
    """Copy the metrics recorded for an entry's step into the entry.

    Args:
        entry: JSON entry to update in place; its ``"step"`` value selects
            the metrics row.
        metrics_by_step: Metrics indexed by integer step.
        source_to_target_fields: Metric column name -> entry field name.

    Returns:
        ``True`` if at least one field was written. ``False`` when the step is
        absent or not convertible to ``int``, has no metrics row, or every
        requested metric is missing.
    """
    try:
        # int(None) raises TypeError, so an absent step is handled here too.
        step_key = int(entry.get("step"))
    except (TypeError, ValueError):
        return False

    if step_key not in metrics_by_step.index:
        return False

    matched = metrics_by_step.loc[step_key]
    if isinstance(matched, pd.DataFrame):
        # Several rows share this step; use the last of them.
        matched = matched.iloc[-1]

    wrote_any = False
    for source_name, target_name in source_to_target_fields.items():
        if source_name in matched.index:
            converted = to_builtin(matched[source_name])
            if converted is not None:
                entry[target_name] = converted
                wrote_any = True
    return wrote_any


def load_json(path: Path) -> MutableMapping[str, object]:
    """Read a UTF-8 JSON file whose top level must be an object.

    Args:
        path: File to read.

    Returns:
        The decoded top-level dictionary.

    Raises:
        ValueError: If the top-level JSON value is not an object.
    """
    parsed = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(
        f"输入 JSON 顶层必须是 object/dict，实际是 {type(parsed).__name__}")


def save_json(path: Path, data: Mapping[str, object]) -> None:
    """Write *data* to *path* as indented UTF-8 JSON with a trailing newline.

    The document is serialized before any file is touched, then written to a
    sibling ``<name>.tmp`` file and moved over *path*. An existing file at
    *path* is therefore left intact if serialization or writing fails.
    Missing parent directories are created.

    Args:
        path: Destination file.
        data: JSON-serializable mapping.

    Raises:
        TypeError: If *data* contains a value ``json`` cannot serialize.
        OSError: If the file cannot be written or moved into place.
    """
    # Serialize first: a non-serializable value must not truncate the target.
    serialized = json.dumps(data, indent=2, ensure_ascii=False) + "\n"

    path.parent.mkdir(parents=True, exist_ok=True)
    # The staging file sits next to the target so the final move stays on one
    # filesystem.
    staging_path = path.with_name(path.name + ".tmp")
    try:
        staging_path.write_text(serialized, encoding="utf-8")
        staging_path.replace(path)
    except BaseException:
        # Do not leave a partial staging file behind.
        try:
            staging_path.unlink()
        except FileNotFoundError:
            pass
        raise


def main() -> None:
    """Fill W&B metrics into the results JSON, section by section.

    For every top-level JSON section that maps to a known experiment and is a
    list, the experiment's run histories are fetched and merged, and each
    entry's entropy and reward fields are updated from the row matching its
    ``step``. A per-section summary is printed. Unless ``--dry-run`` is given
    the result is written to the output path, optionally after backing up an
    input file that is about to be overwritten.

    Raises:
        FileNotFoundError: If the input JSON does not exist.
        ValueError: If the input JSON's top level is not an object.
    """
    args = parse_args()
    input_json = args.input_json.expanduser()
    output_json = args.output_json.expanduser() if args.output_json else input_json

    if not input_json.exists():
        raise FileNotFoundError(f"输入 JSON 不存在: {input_json}")

    experiment_configs = build_experiment_configs(
        reward_field_name=args.write_reward_field)
    data = load_json(input_json)

    api = wandb.Api(timeout=30)

    # Several sections may be aliases of one experiment; fetch each
    # experiment's history only once.
    history_cache: Dict[str, pd.DataFrame] = {}

    for json_key, payload in data.items():
        config = resolve_json_section(json_key, experiment_configs)
        if config is None:
            if args.verbose:
                print(f"[skip] section '{json_key}' 不在预设映射中，跳过。")
            continue

        if not isinstance(payload, list):
            print(f"[skip] section '{json_key}' 不是 list，跳过。")
            continue

        # One mapping drives both the columns fetched and the fields written,
        # so a fetched metric is never silently left unwritten.
        source_to_target_fields = dict(DEFAULT_OUTPUT_FIELD_MAP)
        source_to_target_fields["actor/entropy"] = args.entropy_field
        source_to_target_fields[config.reward_metric] = config.reward_field

        merged_metrics = history_cache.get(config.canonical_name)
        if merged_metrics is None:
            merged_metrics = merge_run_histories(
                api=api,
                run_paths=config.runs,
                step_column=args.step_column,
                metric_columns=list(source_to_target_fields),
                verbose=args.verbose,
            )
            history_cache[config.canonical_name] = merged_metrics

        found_steps: List[int] = []
        missing_steps: List[int] = []
        updated_rows = 0

        for item in payload:
            if not isinstance(item, dict):
                continue
            try:
                step_int = int(item.get("step"))
            except (TypeError, ValueError):
                continue

            if step_int in merged_metrics.index:
                found_steps.append(step_int)
            else:
                missing_steps.append(step_int)

            if update_step_entry(
                entry=item,
                metrics_by_step=merged_metrics,
                source_to_target_fields=source_to_target_fields,
            ):
                updated_rows += 1

        print(
            f"[done] section='{json_key}' "
            f"rows={len(payload)} updated={updated_rows} "
            f"found_steps={len(found_steps)} missing_steps={len(missing_steps)}"
        )
        if missing_steps:
            print(f"       missing step(s): {missing_steps}")

    if args.dry_run:
        print("[dry-run] 未写入文件。")
        return

    # Compare resolved paths so the same file spelled differently (relative
    # vs. absolute, "./x" vs. "x") still counts as an in-place overwrite.
    overwrites_input = output_json.resolve() == input_json.resolve()
    if args.backup and overwrites_input:
        backup_path = input_json.with_suffix(input_json.suffix + ".bak")
        # Copy the raw bytes so the backup is an exact copy of the original
        # file rather than a re-serialized version of it.
        backup_path.write_bytes(input_json.read_bytes())
        print(f"[backup] 已写入备份: {backup_path}")

    save_json(output_json, data)
    print(f"[write] 已写入: {output_json}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
