#!/usr/bin/env python3
"""Run OpenAI o4-mini on random Go review positions."""

from __future__ import annotations

import argparse
import json
import os
import random
import re
from pathlib import Path
from typing import Any

from openai import OpenAI


# Pinned model snapshot: used for every request and recorded in each output row.
MODEL = "o4-mini-2025-04-16"

# Task text placed at the top of every prompt. main() accepts a reply only when it
# is non-empty and under the same 80-word ceiling stated here.
INSTRUCTION = " ".join(
    (
        "Explain this Go position.",
        "Be detailed and specific in your explanations,",
        "and keep your explanation under 80 words.",
        "Then teach it to a student.",
        "Aim for 55 to 75 words.",
    )
)

# Board column letters, left to right: A through T with the letter I omitted (19 letters).
GO_COLUMNS = "ABCDEFGH" "JKLMNOPQRST"


def load_rows(path: Path) -> list[dict[str, Any]]:
    """Read a JSON Lines file and return one parsed object per non-blank line.

    Lines consisting only of whitespace are skipped. A malformed line raises the
    ``json`` decoding error unchanged.
    """
    with path.open(encoding="utf-8") as stream:
        return [json.loads(raw_line) for raw_line in stream if raw_line.strip()]


def board_state_to_matrix(board_state: str, board_size: int = 19) -> list[list[int]]:
    """Parse a comma-separated stone list such as ``B1:Q16,W2:D3`` into a square grid.

    Each entry is a colour (``B``/``W``), a number, a colon, and a coordinate made of
    a column letter (A-T without I) and a row number. The grid has ``board_size``
    rows. Index 0 holds board row ``board_size``, so the rows read from the top of
    the board down. Black stones are 1, white stones -1, and every other point 0.

    Entries that do not match the pattern, or whose column/row falls outside a
    ``board_size`` board, are ignored. If a point appears more than once, the later
    entry wins.

    NOTE(review): the number after the colour looks like a move number. If
    ``board_state`` is a move history rather than the current stones, captured
    stones are not removed. Confirm the upstream format.
    """
    grid = [[0] * board_size for _ in range(board_size)]
    if not board_state:
        return grid

    column_index = dict(zip(GO_COLUMNS[:board_size], range(board_size)))
    stone_value = {"B": 1, "W": -1}
    for entry in board_state.split(","):
        parsed = re.fullmatch(r"([BW])\d+:([A-HJ-T])(\d{1,2})", entry.strip())
        if parsed is None:
            continue
        color, letter, digits = parsed.groups()
        row_number = int(digits)
        column = column_index.get(letter)
        if column is None or not 1 <= row_number <= board_size:
            continue
        # Board row N (1 = bottom) lands at grid index board_size - N.
        grid[board_size - row_number][column] = stone_value[color]
    return grid


def matrix_to_text(matrix: list[list[int]]) -> str:
    """Render a matrix as text: one line per row, cells width-2 and space-separated."""
    rendered_rows = []
    for cells in matrix:
        rendered_rows.append(" ".join(format(cell, "2d") for cell in cells))
    return "\n".join(rendered_rows)


def build_prompt(row: dict[str, Any]) -> str:
    """Build the model prompt for one review row.

    The prompt is INSTRUCTION, followed by a description of the matrix encoding,
    followed by the matrix rendered with matrix_to_text.

    Args:
        row: A review row whose ``"board_matrix"`` entry is already populated.
            main() fills it from ``board_state`` before calling this function.

    Returns:
        The full prompt text.

    Raises:
        KeyError: if ``row`` has no ``"board_matrix"`` entry.
    """
    board_matrix_text = matrix_to_text(row["board_matrix"])
    # A 0 cell is simply an unoccupied intersection. "Neutral point" is a Go term of
    # art (dame), so labelling 0 that way would mislead the model about empty points.
    return (
        f"{INSTRUCTION}\n\n"
        "Position format: 19 rows from board row 19 down to row 1, with columns "
        "A through T excluding I. Values: 1 = black, -1 = white, 0 = empty.\n\n"
        f"Position matrix:\n{board_matrix_text}"
    )


def word_count(text: str) -> int:
    """Return how many maximal runs of non-whitespace characters *text* contains."""
    return sum(1 for _token in re.finditer(r"\S+", text))


def generate_explanation(client: OpenAI, prompt: str) -> str:
    """Send *prompt* to MODEL as a single user message and return the reply text.

    The reply is stripped of surrounding whitespace. A missing (``None``) message
    content is returned as an empty string, which main() treats as a failed attempt.
    """
    # NOTE(review): per the OpenAI API reference, max_completion_tokens also bounds
    # reasoning tokens, so an exhausted budget can yield empty content. Confirm 1000
    # is enough headroom at this reasoning effort.
    request = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_completion_tokens": 1000,
        "reasoning_effort": "low",
    }
    completion = client.chat.completions.create(**request)
    content = completion.choices[0].message.content
    if content is None:
        return ""
    return content.strip()


def _as_text(value: Any) -> str:
    """Render an optional field for Markdown: ``None`` becomes "", anything else ``str()``."""
    return "" if value is None else str(value)


def write_outputs(rows: list[dict[str, Any]], jsonl_path: Path, md_path: Path) -> None:
    """Write result rows as JSON Lines and as a human-readable Markdown report.

    Args:
        rows: Result records. ``id``, ``game_file``, ``move_number`` and
            ``reviewer_rank`` are shown as metadata. ``board_matrix_text``,
            ``model_explanation`` and the reviewer comment (``original_comment``,
            falling back to ``comment``) fill the report sections. Missing or ``None``
            section values render as empty text.
        jsonl_path: Destination for one JSON object per row, written in row order.
        md_path: Destination for the Markdown report.

    Parent directories of both paths are created if missing. Existing files are
    overwritten.
    """
    for path in (jsonl_path, md_path):
        path.parent.mkdir(parents=True, exist_ok=True)

    with jsonl_path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row, ensure_ascii=False) + "\n")

    lines = ["# o4-mini Go Position Explanations", ""]
    for idx, row in enumerate(rows, start=1):
        # Result rows built by main() store the reviewer's text under
        # "original_comment". "comment" is the field name on the raw input rows.
        comment = row.get("original_comment", row.get("comment"))
        lines.extend(
            [
                f"## {idx}. id={row.get('id')}",
                "",
                f"- game_file: `{row.get('game_file')}`",
                f"- move_number: `{row.get('move_number')}`",
                f"- reviewer_rank: `{row.get('reviewer_rank')}`",
                "",
                "### Position",
                "",
                "```text",
                _as_text(row.get("board_matrix_text")),
                "```",
                "",
                "### o4-mini explanation",
                "",
                _as_text(row.get("model_explanation")),
                "",
                "### Original review comment",
                "",
                _as_text(comment),
                "",
            ]
        )
    md_path.write_text("\n".join(lines), encoding="utf-8")


def _explain_with_retries(
    client: OpenAI, prompt: str, row_id: Any, max_attempts: int = 3
) -> str:
    """Query the model until it returns a non-empty reply of fewer than 80 words.

    Every attempt after the first appends a corrective note to ``prompt``.

    Args:
        client: OpenAI client passed through to generate_explanation.
        prompt: The base prompt for this row.
        row_id: Row identifier, used only in progress messages.
        max_attempts: Total number of requests allowed for this row.

    Returns:
        The first acceptable reply. If every attempt fails the check, returns the last
        reply received (possibly an empty string) so the caller still records it.
    """
    explanation = ""
    for attempt in range(1, max_attempts + 1):
        request_prompt = prompt
        if attempt > 1:
            request_prompt += (
                "\n\nYour previous answer was empty or exceeded 80 words. "
                "Answer with one paragraph, 55 to 75 words, and no preamble."
            )
        explanation = generate_explanation(client, request_prompt)
        words = word_count(explanation)
        if explanation and words < 80:
            return explanation
        if attempt < max_attempts:
            print(f"Retrying id={row_id} after attempt {attempt}: {words} words", flush=True)
        else:
            # No further attempt follows, so report the give-up rather than a retry.
            print(
                f"Giving up on id={row_id} after {attempt} attempts; "
                f"keeping last answer ({words} words)",
                flush=True,
            )
    return explanation


def main() -> int:
    """Explain a random sample of review positions with MODEL and save the results.

    Loads JSONL rows from ``--input`` and keeps those with a non-empty
    ``board_state``. Samples ``--n`` of them reproducibly using ``--seed``, asks the
    model to explain each, and writes JSONL plus Markdown via write_outputs. If the
    run fails or is interrupted part-way, the rows completed so far are still written
    before the error propagates.

    Returns:
        0 on success.

    Raises:
        RuntimeError: if OPENAI_API_KEY is unset, or if fewer than ``--n`` rows have
            a ``board_state``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", default="openairft/dan_reviews_1d_plus.jsonl")
    parser.add_argument("--output-jsonl", default="openairft/o4_mini_20_positions.jsonl")
    parser.add_argument("--output-md", default="openairft/o4_mini_20_positions.md")
    parser.add_argument("--n", type=int, default=20)
    parser.add_argument("--seed", type=int, default=20260503)
    args = parser.parse_args()
    # random.sample rejects a negative count with an opaque message, and 0 samples
    # would only produce empty output files.
    if args.n < 1:
        parser.error("--n must be a positive integer")

    if not os.environ.get("OPENAI_API_KEY"):
        raise RuntimeError("OPENAI_API_KEY is not set")

    rows = load_rows(Path(args.input))
    eligible = [row for row in rows if row.get("board_state")]
    if len(eligible) < args.n:
        raise RuntimeError(f"Only {len(eligible)} rows have board_state; need {args.n}")

    rng = random.Random(args.seed)
    sample = rng.sample(eligible, args.n)
    client = OpenAI()
    output_jsonl = Path(args.output_jsonl)
    output_md = Path(args.output_md)
    outputs: list[dict[str, Any]] = []

    try:
        for idx, row in enumerate(sample, start=1):
            row["board_matrix"] = board_state_to_matrix(str(row.get("board_state", "")))
            row["board_matrix_text"] = matrix_to_text(row["board_matrix"])
            prompt = build_prompt(row)
            print(f"Running {idx}/{args.n}: id={row.get('id')}", flush=True)
            explanation = _explain_with_retries(client, prompt, row.get("id"))
            outputs.append(
                {
                    "id": row.get("id"),
                    "game_file": row.get("game_file"),
                    "move_number": row.get("move_number"),
                    "reviewer_rank": row.get("reviewer_rank"),
                    "board_matrix": row.get("board_matrix"),
                    "board_matrix_text": row.get("board_matrix_text"),
                    "prompt": prompt,
                    "model": MODEL,
                    "model_explanation": explanation.strip(),
                    "model_explanation_word_count": word_count(explanation),
                    "original_comment": row.get("comment"),
                }
            )
    except BaseException:
        # Each completed row cost an API call. Save them before propagating the error
        # (including Ctrl-C) so a late failure does not discard the whole run.
        if outputs:
            write_outputs(outputs, output_jsonl, output_md)
            print(f"Saved {len(outputs)} partial rows to {args.output_jsonl}", flush=True)
        raise

    write_outputs(outputs, output_jsonl, output_md)
    print(f"Saved {len(outputs)} rows to {args.output_jsonl}")
    print(f"Saved readable Markdown to {args.output_md}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
