#!/usr/bin/env python3
"""Run many parallel o4-mini samples per already-tested RFT row and grade them."""

from __future__ import annotations

import argparse
import concurrent.futures
import json
import os
import threading
import time
from pathlib import Path
from typing import Any

from openai import OpenAI

from rft_katago_grader import GRADER_SOURCE, response_format
from run_rft_inference_with_repair import (
    call_model,
    find_legality_violations,
    read_jsonl,
)


def parse_args() -> argparse.Namespace:
    """Build the command-line interface for the sampler and parse ``sys.argv``.

    Path, model-name and reasoning-effort options are kept as strings; the
    count-like options (samples per row, workers, token limit, repairs) are
    converted with ``int``.
    """
    cli = argparse.ArgumentParser()
    # (flag, default) in registration order; an int default marks an int-typed option.
    option_defaults = (
        ("--rows-file", "openairft/rft_katago/o4_mini_repair_sample1000.jsonl"),
        ("--train-path", "openairft/rft_katago/katago_rft_train.jsonl"),
        ("--validation-path", "openairft/rft_katago/katago_rft_validation.jsonl"),
        ("--output", "openairft/rft_katago/o4_mini_222x20_samples.jsonl"),
        ("--summary-output", "openairft/rft_katago/o4_mini_222x20_summary.json"),
        ("--samples-per-row", 20),
        ("--max-workers", 20),
        ("--model", "o4-mini-2025-04-16"),
        ("--reasoning-effort", "medium"),
        ("--max-completion-tokens", 10000),
        ("--max-repairs", 2),
    )
    for flag, default in option_defaults:
        if isinstance(default, int):
            cli.add_argument(flag, type=int, default=default)
        else:
            cli.add_argument(flag, default=default)
    return cli.parse_args()


def item_key(meta: dict[str, Any]) -> str:
    """Return the ``game_file::move_number::id`` identity string for a row's metadata.

    Missing fields are rendered as ``None`` rather than raising, so rows with
    incomplete metadata still get a (possibly colliding) key.
    """
    fields = ("game_file", "move_number", "id")
    return "::".join(str(meta.get(field)) for field in fields)


def build_item_index(paths: list[str]) -> dict[str, dict[str, Any]]:
    """Index every row of the given JSONL files by ``item_key`` of its metadata.

    Each row is tagged in place with ``_source_path`` (the path string it was
    read from). Files are processed in the order given, so when two files
    contain the same key the row from the later file wins.
    """
    index: dict[str, dict[str, Any]] = {}
    for source in paths:
        for record in read_jsonl(Path(source)):
            record["_source_path"] = source
            index[item_key(record.get("metadata", {}))] = record
    return index


def requested_keys(rows_file: Path) -> list[str]:
    """Return the distinct ``item_key`` values in ``rows_file``, in first-seen order."""
    # dict preserves insertion order, so fromkeys() de-duplicates while keeping order.
    unique = dict.fromkeys(
        item_key(record.get("metadata", {})) for record in read_jsonl(rows_file)
    )
    return list(unique)


def completed_sample_keys(path: Path) -> set[tuple[str, int]]:
    """Return the ``(row_key, sample_index)`` pairs already recorded in ``path``.

    Used to resume an interrupted run. A missing file yields an empty set.
    Some lines are skipped rather than aborting the resume scan, so those
    samples are simply re-run:

    * blank lines;
    * lines that are not valid JSON, e.g. a record truncated by a crash mid-write;
    * records without a usable ``row_key`` / integer ``sample_index``.
    """
    if not path.exists():
        return set()
    done: set[tuple[str, int]] = set()
    with path.open(encoding="utf-8") as handle:
        for line in handle:
            if not line.strip():
                continue
            try:
                record = json.loads(line)
                done.add((record["row_key"], int(record["sample_index"])))
            except (ValueError, KeyError, TypeError):
                # ValueError: bad JSON (JSONDecodeError) or non-numeric index;
                # KeyError: field missing; TypeError: record not a dict,
                # null index, or unhashable row_key.
                continue
    return done


def build_grader():
    """Materialize the grader defined by ``GRADER_SOURCE`` and return its ``grade`` function.

    ``GRADER_SOURCE`` is a constant imported from the project's own
    ``rft_katago_grader`` module. It is executed in a fresh, private
    namespace on every call, so each caller gets its own function object.
    """
    # exec is acceptable here only because the source is project-controlled, not external input.
    grader_scope: dict[str, Any] = {}
    exec(GRADER_SOURCE, grader_scope)
    grade = grader_scope["grade"]
    return grade


def run_sample(
    row_key: str,
    row: dict[str, Any],
    sample_index: int,
    args: argparse.Namespace,
) -> dict[str, Any]:
    """Draw one model sample for ``row``, retry illegal output, then grade the result.

    The first call sends ``row["messages"]`` unchanged. Whenever the output
    does not parse or ``find_legality_violations`` reports problems, a
    single-message repair prompt is sent. It contains the original prompt,
    the violations, and the previous output. At most ``args.max_repairs``
    repairs are attempted. The last attempt is graded with the exec'd
    grader; an unparseable final output scores 0.0.

    Args:
        row_key: identifier produced by ``item_key`` for ``row``.
        row: dataset row; must provide ``messages``, ``metadata`` and ``reference``.
        sample_index: index of this sample among the samples drawn for ``row_key``.
        args: parsed CLI options (model, reasoning effort, token and repair limits).

    Returns:
        A JSON-serializable record of the final answer, its score and every attempt.
    """
    grade_fn = build_grader()
    # NOTE(review): repair prompts reuse only the first message's content; this
    # assumes rows hold a single user message — confirm against the dataset.
    prompt = row["messages"][0]["content"]
    messages = row["messages"]
    attempts: list[dict[str, Any]] = []
    final_content = ""
    final_parsed = None
    final_usage = None
    final_violations = []

    # One client per sample: close its HTTP connection pool deterministically
    # when the sample finishes instead of relying on garbage collection.
    with OpenAI() as client:
        for attempt in range(args.max_repairs + 1):
            content, parsed, usage = call_model(
                client,
                model=args.model,
                messages=messages,
                reasoning_effort=args.reasoning_effort,
                max_completion_tokens=args.max_completion_tokens,
            )
            violations = ["output did not parse as JSON"] if parsed is None else find_legality_violations(parsed, row)
            attempts.append(
                {
                    "attempt": attempt + 1,
                    "raw_output": content,
                    "parsed_output": parsed,
                    "legality_violations": violations,
                    "usage": usage.model_dump() if usage else None,
                }
            )
            final_content = content
            final_parsed = parsed
            final_usage = usage
            final_violations = violations
            if not violations or attempt == args.max_repairs:
                # Legal output, or no repair budget left: keep this attempt as final.
                break
            repair_prompt = (
                prompt
                + "\n\nYour previous JSON output was illegal because:\n"
                + "\n".join(f"- {violation}" for violation in violations)
                + "\n\nPrevious output:\n"
                # Guard against a missing completion text (call_model's contract is
                # not visible here); concatenating None would raise TypeError.
                + (content or "")
                + "\n\nRepair the JSON. Every move must be legal at the point it is played; "
                "playing on a point is allowed only if it has become empty after a capture. "
                "Return strict JSON only with the same fields."
            )
            messages = [{"role": "user", "content": repair_prompt}]

    score = grade_fn(final_parsed, row) if final_parsed is not None else 0.0
    meta = row["metadata"]
    return {
        "row_key": row_key,
        "sample_index": sample_index,
        "metadata": meta,
        "reference": row["reference"],
        "model": args.model,
        "reasoning_effort": args.reasoning_effort,
        "max_completion_tokens": args.max_completion_tokens,
        "grader_score": score,
        "legality_violations": final_violations,
        "parsed_output": final_parsed,
        "raw_output": final_content,
        "attempt_count": len(attempts),
        "attempts": attempts,
        "usage": final_usage.model_dump() if final_usage else None,
    }


def summarize(output_path: Path, summary_path: Path, target_samples: int = 20) -> dict[str, Any]:
    """Aggregate per-sample grading records into per-row and overall statistics.

    Reads every record from ``output_path`` and groups the records by
    ``row_key``. It then writes the summary as indented JSON to
    ``summary_path``.

    Args:
        output_path: JSONL file of sample records; a missing file gives an empty summary.
        summary_path: destination for the JSON summary; its parent directory is
            created if needed, and the file is replaced atomically.
        target_samples: sample count at which a row counts as complete. The count
            is reported under ``rows_complete_<target_samples>``; the default
            keeps the historical ``rows_complete_20`` key.

    Returns:
        The summary dict that was written.
    """
    rows = read_jsonl(output_path) if output_path.exists() else []
    by_key: dict[str, list[dict[str, Any]]] = {}
    for row in rows:
        by_key.setdefault(row["row_key"], []).append(row)

    row_summaries = []
    for key, samples in by_key.items():
        # Every group holds at least one sample, so max()/division are safe.
        scores = [float(sample.get("grader_score", 0.0)) for sample in samples]
        meta = samples[0]["metadata"]
        row_summaries.append(
            {
                "row_key": key,
                "metadata": meta,
                "n": len(samples),
                "max_score": max(scores),
                "mean_score": sum(scores) / len(scores),
                "count_gt_0_3": sum(score > 0.3 for score in scores),
                "count_gt_0_4": sum(score > 0.4 for score in scores),
                "count_gt_0_5": sum(score > 0.5 for score in scores),
                "count_gt_0_6": sum(score > 0.6 for score in scores),
                "final_legality_failures": sum(bool(sample.get("legality_violations")) for sample in samples),
            }
        )
    max_scores = [row["max_score"] for row in row_summaries]
    summary = {
        "samples": len(rows),
        "rows": len(row_summaries),
        f"rows_complete_{target_samples}": sum(row["n"] >= target_samples for row in row_summaries),
        "mean_best_score": sum(max_scores) / len(max_scores) if max_scores else 0.0,
        "rows_best_gt_0_3": sum(score > 0.3 for score in max_scores),
        "rows_best_gt_0_4": sum(score > 0.4 for score in max_scores),
        "rows_best_gt_0_5": sum(score > 0.5 for score in max_scores),
        "rows_best_gt_0_6": sum(score > 0.6 for score in max_scores),
        "row_summaries": sorted(row_summaries, key=lambda row: row["max_score"], reverse=True),
    }
    # The summary may live in a different directory than the samples file.
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    # Write-then-rename so an interrupted write never leaves a truncated summary.
    tmp_path = summary_path.with_name(summary_path.name + ".tmp")
    tmp_path.write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
    os.replace(tmp_path, summary_path)
    return summary


def main() -> int:
    """Sample and grade every pending (row, sample_index) pair, appending results.

    The run is resumable: pairs already present in the output JSONL are
    skipped. Work fans out over a thread pool. Each finished sample is
    appended to the output file immediately, and the summary is refreshed
    every 10 completions. A sample whose worker raised is logged and left
    pending, so a later run retries it.

    Returns:
        0 if every pending sample succeeded, 1 if any sample failed.
    """
    args = parse_args()
    output_path = Path(args.output)
    summary_path = Path(args.summary_output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    index = build_item_index([args.validation_path, args.train_path])
    keys = requested_keys(Path(args.rows_file))
    done = completed_sample_keys(output_path)

    tasks = []
    missing_rows = 0
    for key in keys:
        row = index.get(key)
        if row is None:
            missing_rows += 1
            continue
        tasks.extend(
            (key, row, sample_index)
            for sample_index in range(args.samples_per_row)
            if (key, sample_index) not in done
        )

    print(f"rows={len(keys)} pending_samples={len(tasks)} already_done={len(done)}", flush=True)
    if missing_rows:
        print(f"warning: {missing_rows} requested rows not found in train/validation data; skipped", flush=True)

    completed = 0
    failed = 0
    started = time.time()
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_task = {
            executor.submit(run_sample, key, row, sample_index, args): (key, sample_index)
            for key, row, sample_index in tasks
        }
        # Results are consumed (and written) only on this thread, so no lock is needed.
        for future in concurrent.futures.as_completed(future_to_task):
            try:
                result = future.result()
            except Exception as exc:
                # Letting this propagate would discard every later result while the
                # executor still finishes (and pays for) all queued samples.
                failed += 1
                key, sample_index = future_to_task[future]
                print(f"failed row_key={key} sample_index={sample_index}: {exc!r}", flush=True)
                continue
            with output_path.open("a", encoding="utf-8") as handle:
                handle.write(json.dumps(result, ensure_ascii=False) + "\n")
            completed += 1
            if completed % 10 == 0 or completed + failed == len(tasks):
                elapsed = max(1.0, time.time() - started)
                print(
                    f"completed={completed}/{len(tasks)} "
                    f"failed={failed} "
                    f"rate={completed / elapsed:.3f}/s "
                    f"score={result['grader_score']}",
                    flush=True,
                )
                summarize(output_path, summary_path)

    summary = summarize(output_path, summary_path)
    print(json.dumps({k: v for k, v in summary.items() if k != "row_summaries"}, indent=2))
    print(f"saved {output_path}")
    print(f"saved {summary_path}")
    if failed:
        print(f"{failed} samples failed; rerun to retry them", flush=True)
        return 1
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
