#!/usr/bin/env python3
"""Run one RFT-style inference and ask the model to repair illegal occupied moves."""

from __future__ import annotations

import argparse
import json
import random
from pathlib import Path
from typing import Any

from openai import OpenAI

from rft_katago_grader import GRADER_SOURCE, response_format


def parse_args() -> argparse.Namespace:
    """Define the command-line options and parse the process arguments.

    Returns:
        The namespace produced by ``argparse`` with one attribute per option
        (dashes in option names become underscores).
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        # Which single position to run when sampling is disabled.
        ("--game-file", {"default": "5732-zrilak-ouigkmy-altea.sgf"}),
        ("--move-number", {"type": int, "default": 67}),
        # Sampling controls.
        ("--sample-n", {"type": int, "default": 0}),
        ("--seed", {"type": int, "default": 20260504}),
        ("--resume", {"action": "store_true"}),
        # Dataset inputs and result output.
        ("--train-path", {"default": "openairft/rft_katago/katago_rft_train.jsonl"}),
        ("--validation-path", {"default": "openairft/rft_katago/katago_rft_validation.jsonl"}),
        ("--output", {"default": "openairft/rft_katago/o4_mini_single_inference_5732_move67.json"}),
        # Model request settings.
        ("--model", {"default": "o4-mini-2025-04-16"}),
        ("--reasoning-effort", {"default": "medium"}),
        ("--max-completion-tokens", {"type": int, "default": 10000}),
        ("--max-repairs", {"type": int, "default": 2}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Decode a UTF-8 JSON Lines file.

    Args:
        path: File to read.

    Returns:
        One decoded JSON value per line, in file order. Lines that are empty
        or whitespace-only are ignored.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as stream:
        for raw_line in stream:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records


def find_item(args: argparse.Namespace) -> tuple[dict[str, Any], str]:
    """Locate the dataset row for the requested game file and move number.

    The training file is searched before the validation file, and the first
    row whose metadata matches both ``args.game_file`` and
    ``args.move_number`` wins.

    Returns:
        ``(row, path)`` where ``path`` is the dataset path the row came from.

    Raises:
        RuntimeError: If neither file contains a matching row.
    """
    for source in (args.train_path, args.validation_path):
        for record in read_jsonl(Path(source)):
            info = record.get("metadata", {})
            if info.get("game_file") != args.game_file:
                continue
            if info.get("move_number") != args.move_number:
                continue
            return record, source
    raise RuntimeError(f"Could not find {args.game_file} move {args.move_number}")


def sample_items(args: argparse.Namespace) -> list[tuple[dict[str, Any], str]]:
    """Draw a seeded random sample of rows from both dataset files.

    The pool is every validation row followed by every training row. The
    sample size is ``args.sample_n`` capped at the pool size, and the draw is
    made by a ``random.Random`` seeded with ``args.seed``.

    Returns:
        A list of ``(row, path)`` pairs, ``path`` being the row's source file.
    """
    pool = [
        (record, source)
        for source in (args.validation_path, args.train_path)
        for record in read_jsonl(Path(source))
    ]
    count = min(args.sample_n, len(pool))
    return random.Random(args.seed).sample(pool, count)


def item_key(row: dict[str, Any]) -> str:
    """Return the identity string for a row.

    The key is the metadata's ``game_file``, ``move_number`` and ``id``
    joined by ``::``; a missing field is rendered as ``None``.
    """
    meta = row.get("metadata", {})
    parts = (meta.get("game_file"), meta.get("move_number"), meta.get("id"))
    return "::".join(f"{part}" for part in parts)


def completed_keys(path: Path) -> set[str]:
    """Collect the item keys already present in a JSON Lines output file.

    Args:
        path: Output file to scan; it does not have to exist.

    Returns:
        The ``item_key`` of every line that decodes to a JSON object. Returns
        an empty set when the file is missing. Blank lines, lines that are
        not valid JSON, and JSON values that are not objects are skipped, so
        a damaged line never aborts the scan.
    """
    if not path.exists():
        return set()
    done: set[str] = set()
    with path.open(encoding="utf-8") as handle:
        for line in handle:
            if not line.strip():
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                # Presumably a partially written line from an interrupted
                # run; the item is simply not counted as completed.
                continue
            if not isinstance(row, dict):
                # item_key needs a mapping; null, lists or scalars cannot
                # identify an item.
                continue
            done.add(item_key(row))
    return done


def build_grader():
    """Execute ``GRADER_SOURCE`` and return the ``grade`` object it defines.

    Raises:
        KeyError: If the executed source does not define ``grade``.
    """
    # NOTE: exec runs the grader text with full interpreter privileges. The
    # text is imported from rft_katago_grader, so it must remain trusted code.
    grader_globals: dict[str, Any] = {}
    exec(GRADER_SOURCE, grader_globals)
    grade = grader_globals["grade"]
    return grade


def occupied_from_prompt(prompt: str) -> set[str]:
    """Extract the occupied coordinates listed in a prompt.

    For each of the labels ``Black stones:`` and ``White stones:``, the text
    between the first occurrence of the label and the end of that line is
    read as a comma-separated coordinate list. A list consisting of the word
    ``none`` (any case) contributes nothing.

    Args:
        prompt: Prompt text to scan.

    Returns:
        The upper-cased coordinates from both lists. A label that is absent,
        or that is followed by no text at all, contributes nothing.
    """
    occupied: set[str] = set()
    for label in ("Black stones:", "White stones:"):
        _, found, rest = prompt.partition(label)
        if not found:
            continue
        # splitlines() returns [] when the label is the last thing in the
        # prompt, so guard the first-line lookup instead of indexing blindly.
        lines = rest.splitlines()
        line = lines[0] if lines else ""
        if line.strip().lower() == "none":
            continue
        for piece in line.split(","):
            coord = piece.strip()
            if coord:
                occupied.add(coord.upper())
    return occupied


def normalize_move(move: Any) -> str:
    """Return ``move`` as text with surrounding whitespace removed, upper-cased.

    Non-string values are converted with ``str`` first.
    """
    as_text = str(move)
    trimmed = as_text.strip()
    return trimmed.upper()


# Column letters used in board coordinates, indexed left to right. The string
# has 19 entries and does not contain the letter "I".
COLS = "ABCDEFGHJKLMNOPQRST"


def coord_to_rc(move: str) -> tuple[int, int] | None:
    """Convert a coordinate such as ``D4`` to zero-based ``(row, col)``.

    The move is normalised first. The column is the index of the leading
    letter in ``COLS``; the row is ``19 - number``, so number 19 maps to
    row 0.

    Returns:
        ``(row, col)``, or ``None`` for ``PASS``, an unknown column letter, a
        non-numeric row, or a row number outside 1..19.
    """
    text = normalize_move(move)
    if text == "PASS" or len(text) < 2:
        return None
    col = COLS.find(text[0])
    if col < 0:
        return None
    try:
        rank = int(text[1:])
    except ValueError:
        return None
    if rank < 1 or rank > 19:
        return None
    return 19 - rank, col


def neighbors(r: int, c: int):
    """Yield the orthogonally adjacent points of ``(r, c)`` on a 19x19 grid.

    Candidates are produced in the order up, down, left, right; any that
    falls outside rows/columns 0..18 is omitted.
    """
    for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        rr = r + d_row
        cc = c + d_col
        if rr < 0 or rr >= 19 or cc < 0 or cc >= 19:
            continue
        yield rr, cc


def group_and_liberties(board: list[list[int]], r: int, c: int):
    """Find the connected group containing ``(r, c)`` and its liberties.

    The group is every point reachable from ``(r, c)`` through orthogonally
    adjacent points holding the same board value. A liberty is any point
    adjacent to the group whose board value is 0.

    Returns:
        ``(group, liberties)`` as two sets of ``(row, col)`` tuples.
    """
    color = board[r][c]
    members = {(r, c)}
    liberties = set()
    frontier = [(r, c)]
    while frontier:
        cur_r, cur_c = frontier.pop()
        for point in neighbors(cur_r, cur_c):
            occupant = board[point[0]][point[1]]
            if occupant == 0:
                liberties.add(point)
            elif occupant == color and point not in members:
                # Mark on discovery so each point is queued only once.
                members.add(point)
                frontier.append(point)
    return members, liberties


def board_from_reference(reference: dict[str, Any]) -> list[list[int]]:
    """Build a 19x19 board from a row's reference data.

    Cells start at 0. Coordinates in ``black_points`` are set to 1, then
    coordinates in ``white_points`` are set to -1. Coordinates that
    ``coord_to_rc`` cannot convert are ignored.
    """
    grid = [[0] * 19 for _ in range(19)]
    for key, stone in (("black_points", 1), ("white_points", -1)):
        for coord in reference.get(key, []):
            cell = coord_to_rc(coord)
            if cell is None:
                continue
            row_idx, col_idx = cell
            grid[row_idx][col_idx] = stone
    return grid


def play_move(board: list[list[int]], move: str, color: int) -> bool:
    """Try to play ``move`` for ``color`` on ``board``, mutating it in place.

    ``PASS`` always succeeds and leaves the board untouched. Otherwise the
    stone is placed, adjacent opposing groups left without liberties are
    removed, and a placement that captures nothing and has no liberties
    itself is taken back.

    Returns:
        ``True`` if the move was applied (or was a pass). ``False`` if the
        coordinate could not be converted, the point was not empty, or the
        move was a suicide; in those cases the board is left as it was.
        Only the current board is examined; no earlier positions are
        consulted.
    """
    text = normalize_move(move)
    if text == "PASS":
        return True
    point = coord_to_rc(text)
    if point is None:
        return False
    row, col = point
    if board[row][col] != 0:
        return False

    board[row][col] = color
    enemy = -color
    made_capture = False
    for nr, nc in neighbors(row, col):
        if board[nr][nc] == enemy:
            stones, libs = group_and_liberties(board, nr, nc)
            if not libs:
                made_capture = True
                for sr, sc in stones:
                    board[sr][sc] = 0

    if made_capture:
        return True
    _, own_libs = group_and_liberties(board, row, col)
    if own_libs:
        return True
    # No capture and no liberties: undo the placement.
    board[row][col] = 0
    return False


def find_legality_violations(sample: dict[str, Any], row: dict[str, Any]) -> list[str]:
    """Replay the model's suggested line and report every illegal move.

    The line is ``best_move`` followed by ``pv_top1``; when ``pv_top1``
    already starts with ``best_move`` that first entry is not replayed twice.
    Moves are played with alternating colours, starting with black when the
    row metadata's ``initial_player`` is ``"B"`` and white otherwise, on the
    board built from ``row["reference"]``. A rejected move leaves the board
    unchanged and checking continues with the next move and colour.

    Args:
        sample: Decoded model output.
        row: Dataset row providing ``reference`` and ``metadata``.

    Returns:
        One message per problem; an empty list when the whole line is legal.
        ``pv_top1[k]`` labels give the 1-based position ``k`` of the move
        inside ``pv_top1``. A ``sample`` that is not a JSON object, or a
        ``pv_top1`` that is not a list, yields a single message.

    Raises:
        KeyError: If ``row`` has no ``reference`` entry.
    """
    if not isinstance(sample, dict):
        # Valid JSON that is not an object has no fields to inspect.
        return ["output is not a JSON object"]
    pv = sample.get("pv_top1", [])
    if not isinstance(pv, list):
        return ["pv_top1 is not a list"]

    best_move = normalize_move(sample.get("best_move", ""))
    pv_moves = [normalize_move(move) for move in pv]
    # Skip the leading PV entry when it merely repeats best_move.
    skip = 1 if pv_moves and pv_moves[0] == best_move else 0
    line = [("best_move", best_move)]
    line.extend(
        (f"pv_top1[{position}]", move)
        for position, move in enumerate(pv_moves[skip:], start=skip + 1)
    )

    color = 1 if row.get("metadata", {}).get("initial_player") == "B" else -1
    board = board_from_reference(row["reference"])
    violations: list[str] = []
    for label, move in line:
        if not play_move(board, move, color):
            violations.append(f"{label} {move} is illegal at that point in the line")
        color = -color
    return violations


def call_model(
    client: OpenAI,
    model: str,
    messages: list[dict[str, str]],
    reasoning_effort: str,
    max_completion_tokens: int,
) -> tuple[str, dict[str, Any] | None, Any]:
    """Send one chat-completions request and decode the reply as JSON.

    Args:
        client: OpenAI client used for the request.
        model: Model name.
        messages: Chat messages to send.
        reasoning_effort: Forwarded as the ``reasoning_effort`` request field.
        max_completion_tokens: Forwarded as ``max_completion_tokens``.

    Returns:
        ``(content, parsed, usage)``: the first choice's message content
        (``""`` when it is empty or missing), that content decoded with
        ``json.loads`` (``None`` if decoding fails for any reason), and the
        response's ``usage`` attribute.
    """
    request = {
        "model": model,
        "messages": messages,
        "response_format": response_format(),
        "reasoning_effort": reasoning_effort,
        "max_completion_tokens": max_completion_tokens,
    }
    completion = client.chat.completions.create(**request)
    text = completion.choices[0].message.content or ""
    try:
        decoded = json.loads(text)
    except Exception:
        # Any decoding failure is reported to the caller as None.
        decoded = None
    return text, decoded, completion.usage


def _build_repair_prompt(prompt: str, violations: list[str], previous_output: str) -> str:
    """Compose the follow-up prompt asking the model to fix its last answer."""
    bullet_list = "\n".join(f"- {violation}" for violation in violations)
    return "".join(
        (
            prompt,
            "\n\nYour previous JSON output was illegal because:\n",
            bullet_list,
            "\n\nPrevious output:\n",
            previous_output,
            "\n\nRepair the JSON. Every move must be legal at the point it is played; "
            "playing on a point is allowed only if it has become empty after a capture. "
            "Return strict JSON only with the same fields.",
        )
    )


def run_one(
    args: argparse.Namespace,
    row: dict[str, Any],
    source_path: str,
    grade_fn=None,
) -> dict[str, Any]:
    """Query the model for one row, re-asking while its answer is illegal.

    The first request sends the row's own messages. After each reply the
    output is checked; if it did not parse or contains illegal moves, a new
    single-message conversation is sent that repeats the first message's
    content, lists the violations and includes the previous output. At most
    ``args.max_repairs + 1`` requests are made.

    Args:
        args: Parsed options (model, reasoning effort, token limit, repairs).
        row: Dataset row with ``messages``, ``metadata`` and ``reference``.
        source_path: Dataset file the row came from; copied into the result.
        grade_fn: Optional callable invoked as ``grade_fn(parsed, row)``.

    Returns:
        A result dict holding the inputs, every attempt, the last attempt's
        output/violations/usage, and ``grader_score`` (``0.0`` when no grader
        was supplied or the last output did not parse).
    """
    original_messages = row["messages"]
    prompt = original_messages[0]["content"]
    client = OpenAI()

    attempts: list[dict[str, Any]] = []
    conversation = original_messages
    # Values reported if the loop below never runs (max_repairs below zero).
    raw_text = ""
    decoded = None
    usage = None
    problems: list[str] = []

    for attempt_number in range(1, args.max_repairs + 2):
        raw_text, decoded, usage = call_model(
            client,
            model=args.model,
            messages=conversation,
            reasoning_effort=args.reasoning_effort,
            max_completion_tokens=args.max_completion_tokens,
        )
        if decoded is None:
            problems = ["output did not parse as JSON"]
        else:
            problems = find_legality_violations(decoded, row)
        attempts.append(
            {
                "attempt": attempt_number,
                "raw_output": raw_text,
                "parsed_output": decoded,
                "occupied_violations": problems,
                "usage": usage.model_dump() if usage else None,
            }
        )
        if not problems:
            break
        # The repair request replaces the conversation instead of extending it.
        conversation = [
            {"role": "user", "content": _build_repair_prompt(prompt, problems, raw_text)}
        ]

    result = {
        "source_path": source_path,
        "metadata": row["metadata"],
        "prompt_messages": original_messages,
        "occupied_points": sorted(row["reference"].get("occupied_points", [])),
        "reference": row["reference"],
        "model": args.model,
        "reasoning_effort": args.reasoning_effort,
        "max_completion_tokens": args.max_completion_tokens,
        "raw_output": raw_text,
        "parsed_output": decoded,
        "occupied_violations": problems,
        "attempts": attempts,
        "usage": usage.model_dump() if usage else None,
    }
    can_grade = grade_fn is not None and decoded is not None
    result["grader_score"] = grade_fn(decoded, row) if can_grade else 0.0
    return result


def _run_sample_batch(args: argparse.Namespace) -> int:
    """Run inference on a random sample and append results to a JSONL file."""
    items = sample_items(args)
    output_path = Path(args.output)
    if output_path.suffix != ".jsonl":
        output_path = output_path.with_suffix(".jsonl")
    # Create the destination up front so a missing directory fails before any
    # model request is made instead of after the first result is ready.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    done = completed_keys(output_path) if args.resume else set()
    grade_fn = build_grader()
    written = 0
    skipped = 0
    for idx, (row, source_path) in enumerate(items, start=1):
        meta = row.get("metadata", {})
        if item_key(row) in done:
            skipped += 1
            continue
        print(
            f"Running {idx}/{len(items)}: {meta.get('game_file')} move {meta.get('move_number')}",
            flush=True,
        )
        out = run_one(args, row, source_path, grade_fn=grade_fn)
        # Append after every item so finished work survives an interruption.
        with output_path.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(out, ensure_ascii=False) + "\n")
        written += 1
        print(
            f"  attempts={len(out['attempts'])} "
            f"violations={out['occupied_violations']} "
            f"score={out['grader_score']}",
            flush=True,
        )
    print(f"saved {output_path}")
    print(f"written={written} skipped={skipped}")
    return 0


def _run_single(args: argparse.Namespace) -> int:
    """Run inference on the one requested position and write a JSON file."""
    row, source_path = find_item(args)
    output_path = Path(args.output)
    # Make sure the result can be saved before spending a model request.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    out = run_one(args, row, source_path, grade_fn=build_grader())
    output_path.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"saved {output_path}")
    print("violations", out["occupied_violations"])
    print("output", out["raw_output"])
    return 0


def main() -> int:
    """Entry point: run sampled inference or a single requested inference.

    With ``--sample-n`` greater than zero, a seeded sample of dataset rows is
    processed and each result is appended to the output path (its suffix
    forced to ``.jsonl``); with ``--resume``, rows whose key is already in
    that file are skipped. Otherwise the row selected by ``--game-file`` and
    ``--move-number`` is processed and written as one indented JSON document.
    In both modes the output file's parent directory is created if missing.

    Returns:
        ``0`` on completion.
    """
    args = parse_args()
    if args.sample_n > 0:
        return _run_sample_batch(args)
    return _run_single(args)


# When run as a script, call main() and use its integer return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
