#!/usr/bin/env python3
"""Prepare OpenAI RFT JSONL data from cached KataGo PV analysis."""

from __future__ import annotations

import argparse
import json
import random
import re
from pathlib import Path
from typing import Any

# Board column letters, left to right; there is no "I", matching the
# [A-HJ-T] pattern used when parsing board_state items.
GO_COLUMNS = "ABCDEFGHJKLMNOPQRST"
# Instruction text placed at the top of every user prompt; the generated
# examples carry only a user message, no separate system message.
SYSTEM_PROMPT = "You are an expert Go analysis model. Analyze the position and output strict JSON only."
# Candidate moves with fewer visits are left out of the reference "top_moves".
MIN_TOP_MOVE_VISITS = 10
# Positions whose first-ranked variation has fewer visits are skipped entirely.
MIN_RANK1_VISITS = 100


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", default="openairft/katago_pv_1k_2kvisits.jsonl")
    parser.add_argument("--source-reviews", default="openairft/dan_reviews_1d_plus.jsonl")
    parser.add_argument("--out-dir", default="openairft/rft_katago")
    parser.add_argument("--train-size", type=int, default=900)
    parser.add_argument("--validation-size", type=int, default=100)
    parser.add_argument("--seed", type=int, default=20260503)
    return parser.parse_args()


def parse_board_state(board_state: str) -> tuple[list[list[int]], int, dict[str, list[str]]]:
    """Parse a serialized board into a 19x19 matrix, a stone count and stone lists.

    ``board_state`` is a comma-separated list of items shaped like ``B12:D4``
    (colour ``B``/``W``, a number, ``:``, then a column letter and a row
    1-19). An optional leading ``Move: ...|`` header is discarded. Items that
    do not match that shape, or that fall off the board, are ignored.
    ``None`` or empty input yields an empty board.

    If a point is listed more than once, the last item wins. The matrix and
    the stone lists are both built from that final state, so they always
    agree and a point never appears twice (or under both colours).

    Returns:
        ``(matrix, occupied_count, stones)`` where:

        - ``matrix[r][c]`` is 1 for black, -1 for white and 0 for empty.
          ``r = 0`` is board row 19 and ``c`` indexes GO_COLUMNS.
        - ``occupied_count`` is the number of distinct occupied points.
        - ``stones`` maps ``"black"``/``"white"`` to coordinates such as
          ``"D4"``, sorted as plain strings (so ``"A10"`` precedes ``"A2"``).
    """
    matrix = [[0] * 19 for _ in range(19)]
    text = str(board_state or "")
    if "|" in text and text.lower().startswith("move:"):
        text = text.split("|", 1)[1]
    coord_to_col = {col: idx for idx, col in enumerate(GO_COLUMNS)}
    # (row, col) -> (colour, coordinate label). Re-assigning a key keeps only
    # the latest item for that point.
    final_state: dict[tuple[int, int], tuple[str, str]] = {}
    for item in text.split(","):
        match = re.fullmatch(r"([BW])\d+:([A-HJ-T])(\d{1,2})", item.strip())
        if not match:
            continue
        color, col_text, row_text = match.groups()
        row_num = int(row_text)
        # The regex also admits rows such as 0 or 20-99; keep real board rows only.
        if col_text not in coord_to_col or not 1 <= row_num <= 19:
            continue
        final_state[(19 - row_num, coord_to_col[col_text])] = (color, f"{col_text}{row_num}")

    stones: dict[str, list[str]] = {"black": [], "white": []}
    for (row, col), (color, label) in final_state.items():
        matrix[row][col] = 1 if color == "B" else -1
        stones["black" if color == "B" else "white"].append(label)
    stones["black"].sort()
    stones["white"].sort()
    return matrix, len(final_state), stones


def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load a UTF-8 JSONL file, returning one parsed value per non-blank line."""
    records: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as stream:
        for raw_line in stream:
            # Whitespace-only lines (including a trailing newline) are skipped.
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records


def reference_from_row(row: dict[str, Any], stones: dict[str, list[str]]) -> dict[str, Any] | None:
    pvs = row.get("principal_variations") or []
    if not pvs:
        return None
    top1 = pvs[0]
    top1_visits = int(top1.get("visits") or 0)
    if top1_visits < MIN_RANK1_VISITS:
        return None
    filtered_pvs = [pv for pv in pvs if int(pv.get("visits") or 0) >= MIN_TOP_MOVE_VISITS]
    top_moves = [str(pv.get("move", "")).upper() for pv in filtered_pvs if pv.get("move")]
    pv_top1 = [str(move).upper() for move in (top1.get("pv") or [])][:12]
    if not top_moves or not pv_top1:
        return None
    winrate = top1.get("winrate", row.get("top_move_winrate", row.get("root_winrate")))
    score = top1.get("scoreLead", row.get("top_move_scoreLead", row.get("root_scoreLead")))
    if winrate is None or score is None:
        return None
    return {
        "top_moves": top_moves[:10],
        "pv_top1": pv_top1,
        "winrate_black": round(float(winrate) * 100.0, 4),
        "score_lead_black": round(float(score), 4),
        "rank1_visits": top1_visits,
        "min_top_move_visits": MIN_TOP_MOVE_VISITS,
        "occupied_points": sorted(stones["black"] + stones["white"]),
        "black_points": list(stones["black"]),
        "white_points": list(stones["white"]),
    }


def build_user_prompt(
    row: dict[str, Any],
    board_matrix: list[list[int]],
    stones: dict[str, list[str]],
) -> str:
    """Render the user-turn prompt text for one position.

    The prompt contains, in order:

    - SYSTEM_PROMPT;
    - the move number and the side to play, both taken from ``row``;
    - fixed rules and coordinate notes;
    - both stone lists ("none" when a colour has no stones);
    - the board matrix as compact JSON;
    - the list of requested output fields.
    """
    to_play = "Black" if row.get("initial_player") == "B" else "White"

    def listing(points: list[str]) -> str:
        return ", ".join(points) if points else "none"

    lines = [
        SYSTEM_PROMPT,
        "",
        f"Move number: {row.get('move_number')}",
        f"To play: {to_play}",
        "Rules: Japanese, Komi 6.5",
        "",
        "Board coordinates: rows are indexed 1-19 from board row 19 down to "
        "board row 1. Columns correspond to A, B, C, D, "
        "E, F, G, H, J, K, L, M, N, O, P, Q, R, S, T.",
        "Each array is one board row. The first array is row 19, and the first "
        "element in row 19 is A19.",
        "Board values: 1 = black, -1 = white, 0 = neutral/empty.",
        "",
        f"Black stones: {listing(stones['black'])}",
        f"White stones: {listing(stones['white'])}",
        "",
        json.dumps(board_matrix, separators=(",", ":")),
        "",
        "Analyze this position. Output:",
        "- A brief explanation (max 150 words) of the key features of the position and why the best move is correct",
        "- The best move",
        "- The top principal variation (up to 12 moves)",
        "- Black win rate as a percentage (0-100)",
        "- Black score lead in points (positive = Black ahead)",
        "",
        "Do not output moves that are already occupied in the initial position.",
    ]
    # Empty entries become the blank separator lines; no trailing newline.
    return "\n".join(lines)


def build_rft_item(row: dict[str, Any], source_by_id: dict[int, dict[str, Any]]) -> dict[str, Any] | None:
    source = source_by_id.get(int(row["id"]))
    if not source or not source.get("board_state"):
        return None
    matrix, stone_count, stones = parse_board_state(source["board_state"])
    if stone_count == 0:
        return None
    ref = reference_from_row(row, stones)
    if ref is None:
        return None
    return {
        "messages": [
            {"role": "user", "content": build_user_prompt(row, matrix, stones)},
        ],
        "reference": ref,
        "metadata": {
            "id": row.get("id"),
            "game_file": row.get("game_file"),
            "move_number": row.get("move_number"),
            "initial_player": row.get("initial_player"),
            "reviewer_rank": row.get("reviewer_rank"),
            "katago_visits": row.get("katago_visits"),
            "parsed_stone_count": row.get("parsed_stone_count"),
        },
    }


def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Overwrite *path* with one JSON document per line (UTF-8, non-ASCII kept literal)."""
    # Lazily serialized, so each row is encoded just before it is written.
    encoded_lines = (json.dumps(record, ensure_ascii=False) + "\n" for record in rows)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(encoded_lines)


def main() -> int:
    """Build shuffled train/validation RFT splits plus a JSON manifest.

    Reads the source reviews (``--source-reviews``) and the KataGo analysis
    (``--input``), converts every joinable row with build_rft_item, shuffles
    the survivors with ``--seed`` and writes ``katago_rft_train.jsonl``,
    ``katago_rft_validation.jsonl`` and ``katago_rft_manifest.json`` into
    ``--out-dir``.

    Returns:
        0 on success (used as the process exit status).

    Raises:
        RuntimeError: if filtering leaves too few examples to fill both splits.
    """
    args = parse_args()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Reviews without an id cannot be joined, so they are left out of the index.
    source_by_id = {
        int(review["id"]): review
        for review in read_jsonl(Path(args.source_reviews))
        if review.get("id") is not None
    }

    usable: list[dict[str, Any]] = []
    for katago_row in read_jsonl(Path(args.input)):
        example = build_rft_item(katago_row, source_by_id)
        if example is not None:
            usable.append(example)
    random.Random(args.seed).shuffle(usable)

    # Validation takes at most --validation-size and roughly 10% of the pool
    # (never fewer than one slot). Training takes up to --train-size of the rest.
    n_usable = len(usable)
    n_validation = min(args.validation_size, max(1, n_usable // 10))
    n_train = min(args.train_size, n_usable - n_validation)
    if n_train <= 0 or n_validation <= 0:
        raise RuntimeError(f"Only {n_usable} usable examples after filtering")

    train_rows = usable[:n_train]
    validation_rows = usable[n_train : n_train + n_validation]
    train_path = out_dir / "katago_rft_train.jsonl"
    validation_path = out_dir / "katago_rft_validation.jsonl"
    metadata_path = out_dir / "katago_rft_manifest.json"

    write_jsonl(train_path, train_rows)
    write_jsonl(validation_path, validation_rows)

    manifest = {
        "input": args.input,
        "source_reviews": args.source_reviews,
        "train_rows": len(train_rows),
        "validation_rows": len(validation_rows),
        "seed": args.seed,
        "train_path": str(train_path),
        "validation_path": str(validation_path),
        "base_model": "o4-mini-2025-04-16",
        # Only recorded here; nothing in this script consumes these weights.
        "reward_weights": {
            "move": 0.35,
            "pv": 0.30,
            "winrate": 0.15,
            "score": 0.15,
            "format": 0.05,
        },
        "filters": {
            "min_top_move_visits": MIN_TOP_MOVE_VISITS,
            "min_rank1_visits": MIN_RANK1_VISITS,
        },
    }
    metadata_path.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8")

    print(f"Wrote {len(train_rows)} train rows to {train_path}")
    print(f"Wrote {len(validation_rows)} validation rows to {validation_path}")
    print(f"Wrote manifest to {metadata_path}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
