"""Python grader source for KataGo PV reinforcement fine-tuning."""

GRADER_SOURCE = r'''
import json
import math
import re

# Score awarded for the predicted best move, keyed by its 1-based position in
# the reference top-move list. move_reward() looks positions up here and
# falls back to 0.0 for any position that is not listed.
TOP_MOVE_SCORES = {
    1: 1.00,
    2: 0.85,
    3: 0.70,
    4: 0.50,
    5: 0.50,
}

# Accepts "PASS" or a GTP-style coordinate: one column letter from A-H or J-T
# (the letter I is not accepted) followed by a row number from 1 to 19.
# Only upper-case text matches; _valid_move() upper-cases its input first.
MOVE_RE = re.compile(r"^(?:PASS|[A-HJ-T](?:[1-9]|1[0-9]))$")


def _as_dict(sample):
    if isinstance(sample, dict):
        if isinstance(sample.get("output_json"), dict):
            return sample["output_json"]
        if isinstance(sample.get("json"), dict):
            return sample["json"]
        for key in ("output_text", "text", "content"):
            value = sample.get(key)
            if isinstance(value, str):
                try:
                    parsed = json.loads(value)
                    if isinstance(parsed, dict):
                        return parsed
                except Exception:
                    pass
        return sample
    if isinstance(sample, str):
        try:
            parsed = json.loads(sample)
            if isinstance(parsed, dict):
                return parsed
        except Exception:
            pass
    return {}


def _valid_move(move):
    """Return True when ``move`` is a string that, upper-cased, matches MOVE_RE."""
    if not isinstance(move, str):
        return False
    matched = MOVE_RE.fullmatch(move.upper())
    return matched is not None


def _valid_pv(pv):
    """Return True when ``pv`` is a list whose entries all pass _valid_move.

    An empty list is considered valid.
    """
    if not isinstance(pv, list):
        return False
    for entry in pv:
        if not _valid_move(entry):
            return False
    return True


COLS = "ABCDEFGHJKLMNOPQRST"


def _coord_to_rc(move):
    move = str(move).upper()
    if move == "PASS":
        return None
    col = COLS.find(move[0])
    try:
        row_num = int(move[1:])
    except Exception:
        return None
    if col < 0 or row_num < 1 or row_num > 19:
        return None
    return 19 - row_num, col


def _neighbors(r, c):
    for rr, cc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
        if 0 <= rr < 19 and 0 <= cc < 19:
            yield rr, cc


def _group_and_liberties(board, r, c):
    """Flood-fill the connected group containing ``(r, c)``.

    Returns a pair of sets ``(points, liberties)``: ``points`` holds every
    point reachable from ``(r, c)`` through orthogonally adjacent cells that
    carry the same value as ``board[r][c]``; ``liberties`` holds every empty
    (0) point adjacent to one of those points.
    """
    owner = board[r][c]
    members = {(r, c)}
    open_points = set()
    frontier = [(r, c)]
    while frontier:
        cur_r, cur_c = frontier.pop()
        for adjacent in _neighbors(cur_r, cur_c):
            cell = board[adjacent[0]][adjacent[1]]
            if cell == 0:
                open_points.add(adjacent)
            elif cell == owner and adjacent not in members:
                # Mark on discovery so each point is queued at most once.
                members.add(adjacent)
                frontier.append(adjacent)
    return members, open_points


def _board_from_occupied(item):
    """Build a 19x19 board from the stone lists in ``item["reference"]``.

    Cells are 0 for empty, 1 for points listed under "black_points" and -1
    for points listed under "white_points". Coordinates that _coord_to_rc
    cannot convert are skipped. Black stones are placed first, so a point
    listed under both keys ends up white.
    """
    grid = [[0] * 19 for _ in range(19)]
    for key, stone in (("black_points", 1), ("white_points", -1)):
        for coord in item["reference"].get(key, []):
            position = _coord_to_rc(coord)
            if position is None:
                continue
            row, col = position
            grid[row][col] = stone
    return grid


def _play_move(board, move, color):
    """Try to play ``move`` for ``color`` on ``board``, mutating it in place.

    ``color`` is 1 or -1 (the opponent is ``-color``); empty cells are 0.
    A pass always succeeds and leaves the board untouched. A stone is
    rejected when the text is not a board coordinate, when the point is
    already occupied, or when the stone would be left without liberties
    and captures nothing (suicide). No ko or repetition check is done here.

    Returns True if the move was applied, False if it was rejected. A
    rejected move leaves the board as it was before the call.
    """
    move = str(move).upper()
    if move == "PASS":
        return True
    rc = _coord_to_rc(move)
    if rc is None:
        return False
    r, c = rc
    if board[r][c] != 0:
        return False
    # Place the stone first so liberty counts below take it into account.
    board[r][c] = color
    opponent = -color
    captured = False
    # Remove every adjacent opponent group that now has no liberties.
    for nr, nc in _neighbors(r, c):
        if board[nr][nc] != opponent:
            continue
        group, liberties = _group_and_liberties(board, nr, nc)
        if not liberties:
            captured = True
            for gr, gc in group:
                board[gr][gc] = 0
    # Liberties are counted after captures, so a capturing move is never
    # treated as suicide.
    _, own_liberties = _group_and_liberties(board, r, c)
    if not own_liberties and not captured:
        # Suicide: undo the placement and report the move as rejected.
        board[r][c] = 0
        return False
    return True


def _illegal_pv_moves(pred_move, pred_pv, item):
    """Replay the predicted line on the reference position and list rejects.

    The sequence played is ``pred_move`` followed by ``pred_pv``; when
    ``pred_pv`` already starts with ``pred_move`` that first entry is not
    repeated. The first move is played as Black (1) when
    ``item["metadata"]["initial_player"]`` equals "B" and as White (-1)
    otherwise. Colours alternate after every move, including rejected ones.

    Returns the upper-cased text of every move that _play_move rejected,
    in playing order.
    """
    board = _board_from_occupied(item)
    starts_as_black = item.get("metadata", {}).get("initial_player") == "B"
    to_play = 1 if starts_as_black else -1

    if pred_pv and pred_pv[0] == pred_move:
        continuation = pred_pv[1:]
    else:
        continuation = pred_pv

    rejected = []
    for move in [pred_move, *continuation]:
        if not _play_move(board, move, to_play):
            rejected.append(str(move).upper())
        to_play = -to_play
    return rejected


def move_reward(pred_move, ref_top_moves):
    """Score ``pred_move`` by its position in ``ref_top_moves``.

    The 1-based position of the first occurrence is looked up in
    TOP_MOVE_SCORES. Moves that are absent from the list, or whose position
    has no entry in the table, score 0.0.
    """
    try:
        position = ref_top_moves.index(pred_move)
    except ValueError:
        return 0.0
    return TOP_MOVE_SCORES.get(position + 1, 0.0)


def pv_reward(pred_pv, ref_pv, alpha=0.85, max_len=12):
    """Score how much of the reference variation the prediction reproduces.

    Ply ``t`` carries weight ``alpha ** t``. The reward is the weight of the
    leading run of plies on which ``pred_pv`` agrees with ``ref_pv``, divided
    by the total weight of the scoring horizon, where the horizon is the
    first ``min(len(ref_pv), max_len)`` plies of the reference.

    The horizon depends only on the reference: plies the prediction omits
    count as misses, so a short prediction cannot earn full credit, and a
    mismatch does not shrink the denominator. Predicted plies beyond the
    horizon are ignored.

    Returns a float in [0.0, 1.0]; 0.0 when either sequence is empty or
    ``max_len`` is not positive.
    """
    horizon = min(len(ref_pv), max_len)
    if horizon <= 0 or len(pred_pv) == 0:
        return 0.0

    total = 0.0
    matched = 0.0
    prefix_intact = True
    for t in range(horizon):
        weight = alpha ** t
        # Every ply of the horizon counts toward the denominator, whether
        # or not the prediction got that far.
        total += weight
        if prefix_intact and t < len(pred_pv) and pred_pv[t] == ref_pv[t]:
            matched += weight
        else:
            # Only the unbroken prefix earns credit.
            prefix_intact = False
    return matched / total if total > 0 else 0.0


def bounded_linear_reward(error, tol):
    """Map an absolute error onto a reward that falls linearly from 1 to 0.

    The reward is 1.0 at ``error == 0``, decreases linearly, and is 0.0 once
    ``error`` reaches ``tol``. The result is always within [0.0, 1.0].

    Degenerate inputs are handled explicitly instead of raising or leaking
    out-of-range values: a NaN ``error`` or ``tol`` scores 0.0, and a
    non-positive ``tol`` means only an exact match (``error == 0``) scores.
    """
    if math.isnan(error) or math.isnan(tol):
        return 0.0
    if tol <= 0:
        # No tolerance band to interpolate over; also avoids dividing by 0.
        return 1.0 if error == 0 else 0.0
    # The upper clamp keeps a negative error from scoring above 1.0.
    return min(1.0, max(0.0, 1.0 - error / tol))


def grade(sample, item):
    """Score one model prediction against the reference values in ``item``.

    sample: model output dict, JSON string, or a wrapper dict containing
        output_json/output_text (see _as_dict).
    item: reference dict; ``item["reference"]`` supplies top_moves, pv_top1,
        winrate_black and score_lead_black.

    Returns 0.0 when a required prediction field is missing or cannot be
    converted, when the best move or any PV entry is not a valid move, when
    replaying the line yields a rejected move, or when the explanation has
    no words. Otherwise returns a weighted sum of the move, PV, winrate,
    score and format components, clipped to [0.0, 1.0].
    """
    payload = _as_dict(sample)
    try:
        explanation = str(payload["explanation"])
        pred_move = str(payload["best_move"]).upper()
        pred_pv = [str(step).upper() for step in payload["pv_top1"]]
        pred_wr = float(payload["winrate_black"])
        pred_score = float(payload["score_lead_black"])
    except Exception:
        # Any missing or unconvertible field makes the answer ungradable.
        return 0.0

    ref_top_moves = [str(step).upper() for step in item["reference"]["top_moves"]]
    ref_pv = [str(step).upper() for step in item["reference"]["pv_top1"]]
    ref_wr = float(item["reference"]["winrate_black"])
    ref_score = float(item["reference"]["score_lead_black"])

    # Hard gates: malformed or unplayable lines earn nothing.
    if not (_valid_move(pred_move) and _valid_pv(pred_pv)):
        return 0.0
    if _illegal_pv_moves(pred_move, pred_pv, item):
        return 0.0

    word_count = len(re.findall(r"\S+", explanation))
    if word_count == 0:
        return 0.0
    # Explanations longer than 150 words keep only half the format credit.
    if word_count <= 150:
        r_format = 1.0
    else:
        r_format = 0.5

    r_move = move_reward(pred_move, ref_top_moves)
    r_pv = pv_reward(pred_pv, ref_pv)
    wr_error = abs(pred_wr - ref_wr)
    r_wr = bounded_linear_reward(wr_error, tol=15.0)
    score_error = abs(pred_score - ref_score)
    r_score = bounded_linear_reward(score_error, tol=10.0)

    # Accumulate in a fixed order: move, PV, winrate, score, format.
    reward = 0.35 * r_move
    reward += 0.30 * r_pv
    reward += 0.15 * r_wr
    reward += 0.15 * r_score
    reward += 0.05 * r_format

    capped = min(1.0, reward)
    return float(max(0.0, capped))
'''


def grader_config() -> dict:
    """Return the grader definition: a Python grader carrying GRADER_SOURCE."""
    config = {"type": "python", "name": "katago_pv_reward"}
    config["source"] = GRADER_SOURCE
    return config


def response_format() -> dict:
    """Return the JSON-schema response format for a model prediction.

    The schema describes an object with exactly five required fields
    (best_move, explanation, pv_top1, winrate_black, score_lead_black) and
    no additional properties. A fresh dict is built on every call.
    """

    def _field(kind: str, description: str) -> dict:
        # Scalar properties all share the same two-key shape.
        return {"type": kind, "description": description}

    properties = {
        "best_move": _field(
            "string",
            "Predicted best move in GTP coordinate form, e.g. R15, or PASS.",
        ),
        "explanation": _field(
            "string",
            "Brief explanation, max 150 words, of key position features and why the best move is correct.",
        ),
        "pv_top1": {
            "type": "array",
            "items": {"type": "string"},
            "minItems": 1,
            "maxItems": 12,
            "description": "Predicted principal variation for the best move, up to 12 plies.",
        },
        "winrate_black": _field(
            "number",
            "Black winrate as a percentage from 0 to 100.",
        ),
        "score_lead_black": _field(
            "number",
            "KataGo score lead from Black's perspective, in points.",
        ),
    }

    required = [
        "explanation",
        "best_move",
        "pv_top1",
        "winrate_black",
        "score_lead_black",
    ]

    schema = {
        "type": "object",
        "properties": properties,
        "required": required,
        "additionalProperties": False,
    }

    json_schema = {
        "name": "katago_pv_prediction",
        "strict": True,
        "schema": schema,
    }
    return {"type": "json_schema", "json_schema": json_schema}
