#!/usr/bin/env python3
"""Validate grader, upload KataGo RFT files, and optionally create an RFT job."""

from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Any

import requests

from rft_katago_grader import grader_config, response_format

# Base URL that post_json() and upload_file() prepend to every endpoint path.
API_BASE = "https://api.openai.com/v1"
# Model identifier sent as the "model" field of the fine-tuning job request
# and recorded in the artifacts written by main().
MODEL = "o4-mini-2025-04-16"


def parse_args() -> argparse.Namespace:
    """Parse the command line for this script.

    Four options take a value (input files, output path, job suffix) and
    four are boolean switches that each enable one API action.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()

    # Options that carry a value, paired with their defaults.
    valued_options = (
        ("--train-file", "openairft/rft_katago/katago_rft_train.jsonl"),
        ("--validation-file", "openairft/rft_katago/katago_rft_validation.jsonl"),
        ("--out", "openairft/rft_katago/rft_job_payloads.json"),
        ("--suffix", "katago-pv-rft"),
    )
    for flag, fallback in valued_options:
        cli.add_argument(flag, default=fallback)

    # Opt-in switches; each one defaults to False.
    for switch in ("--validate-grader", "--run-grader-smoke", "--upload", "--create-job"):
        cli.add_argument(switch, action="store_true")

    return cli.parse_args()


def headers() -> dict[str, str]:
    """Build the bearer-token ``Authorization`` header.

    The token is read from the ``OPENAI_API_KEY`` environment variable on
    every call.

    Raises:
        RuntimeError: if the variable is unset or empty.
    """
    token = os.environ.get("OPENAI_API_KEY", "")
    if token:
        return {"Authorization": f"Bearer {token}"}
    raise RuntimeError("OPENAI_API_KEY is not set")


def post_json(path: str, payload: dict[str, Any]) -> dict[str, Any]:
    """POST ``payload`` as JSON to ``API_BASE + path`` and return the decoded body.

    Args:
        path: Endpoint path appended to ``API_BASE`` (e.g. ``"/fine_tuning/jobs"``).
        payload: JSON-serialisable request body.

    Returns:
        The decoded JSON response. If the body is not valid JSON, a
        ``{"text": <raw body>}`` dict is returned instead.

    Raises:
        RuntimeError: if the HTTP status code is 400 or higher; the message
            includes the status code and the (decoded or raw) body.
    """
    response = requests.post(
        API_BASE + path,
        headers={**headers(), "Content-Type": "application/json"},
        json=payload,
        timeout=120,
    )
    try:
        data = response.json()
    except ValueError:
        # The decode errors raised by Response.json() are ValueError
        # subclasses; catching only those keeps unrelated failures visible.
        data = {"text": response.text}
    if response.status_code >= 400:
        raise RuntimeError(f"POST {path} failed {response.status_code}: {json.dumps(data, indent=2)}")
    return data


def upload_file(path: Path) -> dict[str, Any]:
    """Upload a local file to the ``/files`` endpoint with purpose ``fine-tune``.

    Args:
        path: File to upload; it is sent under its base name with the
            ``application/jsonl`` content type.

    Returns:
        The decoded JSON response. If the body is not valid JSON, a
        ``{"text": <raw body>}`` dict is returned instead.

    Raises:
        RuntimeError: if the HTTP status code is 400 or higher; the message
            includes the status code and the (decoded or raw) body.
    """
    with path.open("rb") as handle:
        response = requests.post(
            API_BASE + "/files",
            headers=headers(),
            files={"file": (path.name, handle, "application/jsonl")},
            data={"purpose": "fine-tune"},
            timeout=300,
        )
    try:
        data = response.json()
    except ValueError:
        # Error responses are not always JSON (e.g. a gateway's HTML page).
        # Fall back to the raw text, as post_json() does, so the status
        # check below still reports a RuntimeError with the status code
        # instead of an opaque decode error.
        data = {"text": response.text}
    if response.status_code >= 400:
        raise RuntimeError(f"File upload failed {response.status_code}: {json.dumps(data, indent=2)}")
    return data


def load_first_item(path: Path) -> dict[str, Any]:
    """Return the first non-blank line of a JSONL file, parsed as JSON.

    Args:
        path: UTF-8 encoded JSONL file to read.

    Returns:
        The decoded JSON value of the first line that is not only whitespace.

    Raises:
        ValueError: if the file has no non-blank line. A line that is not
            valid JSON raises ``json.JSONDecodeError``, itself a
            ``ValueError`` subclass.
    """
    with path.open(encoding="utf-8") as handle:
        first_line = next((line for line in handle if line.strip()), None)
    if first_line is None:
        # Without the default, next() would leak a bare StopIteration that
        # says nothing about which file was empty.
        raise ValueError(f"{path} contains no non-blank lines")
    return json.loads(first_line)


def smoke_sample_for_item(item: dict[str, Any]) -> dict[str, Any]:
    """Build a model sample for the grader smoke test from an item's reference.

    ``best_move`` is the first entry of ``item["reference"]["top_moves"]``;
    ``pv_top1``, ``winrate_black`` and ``score_lead_black`` are copied from
    the reference unchanged.

    Raises:
        KeyError: if ``reference`` or any of the fields read is missing.
        IndexError: if ``top_moves`` is an empty sequence.
    """
    reference = item["reference"]
    sample: dict[str, Any] = {"best_move": reference["top_moves"][0]}
    for field in ("pv_top1", "winrate_black", "score_lead_black"):
        sample[field] = reference[field]
    return sample


def job_payload(train_file_id: str, validation_file_id: str, suffix: str) -> dict[str, Any]:
    """Assemble the request body for creating the reinforcement fine-tuning job.

    Args:
        train_file_id: ID of the uploaded training file.
        validation_file_id: ID of the uploaded validation file.
        suffix: Value sent as the job's ``suffix`` field.

    Returns:
        A dict with the file IDs, ``MODEL``, the suffix, and a ``method``
        section of type ``"reinforcement"`` holding the grader configuration
        and response format.
    """
    reinforcement_settings = {
        "grader": grader_config(),
        "response_format": response_format(),
    }
    method_section = {
        "type": "reinforcement",
        "reinforcement": reinforcement_settings,
    }
    request_body: dict[str, Any] = {
        "training_file": train_file_id,
        "validation_file": validation_file_id,
        "model": MODEL,
        "suffix": suffix,
    }
    request_body["method"] = method_section
    return request_body


def main() -> int:
    """Run the requested API actions and write their payloads/results to JSON.

    The steps run in a fixed order, each gated by its command-line switch:
    validate the grader, run the grader once on the first validation item,
    upload the training and validation files, and create the fine-tuning
    job. ``--create-job`` implies the upload step. The collected artifacts
    are written to the ``--out`` path afterwards, even when no switch was
    given; in that case a hint is also printed to stderr.

    Returns:
        ``0`` once the artifacts file has been written.
    """
    opts = parse_args()
    train_jsonl = Path(opts.train_file)
    validation_jsonl = Path(opts.validation_file)
    report_path = Path(opts.out)
    report_path.parent.mkdir(parents=True, exist_ok=True)

    # Always-present artifacts, recorded before any network call is made.
    results: dict[str, Any] = {}
    results["grader"] = grader_config()
    results["response_format"] = response_format()
    results["model"] = MODEL
    results["train_file_path"] = str(train_jsonl)
    results["validation_file_path"] = str(validation_jsonl)

    if opts.validate_grader:
        validation_request = {"grader": grader_config()}
        results["grader_validation"] = post_json(
            "/fine_tuning/alpha/graders/validate", validation_request
        )
        print("Validated grader")

    if opts.run_grader_smoke:
        # Grade a sample derived from the first validation item's reference.
        first_item = load_first_item(validation_jsonl)
        reference_sample = smoke_sample_for_item(first_item)
        smoke_request = {
            "grader": grader_config(),
            "item": first_item,
            "model_sample": json.dumps(reference_sample),
        }
        results["grader_smoke"] = post_json("/fine_tuning/alpha/graders/run", smoke_request)
        print("Ran grader smoke test")

    uploaded_train = None
    uploaded_validation = None
    # Creating a job needs file IDs, so --create-job also triggers the upload.
    needs_upload = opts.upload or opts.create_job
    if needs_upload:
        uploaded_train = upload_file(train_jsonl)
        uploaded_validation = upload_file(validation_jsonl)
        results["uploaded_training_file"] = uploaded_train
        results["uploaded_validation_file"] = uploaded_validation
        print(f"Uploaded train file: {uploaded_train['id']}")
        print(f"Uploaded validation file: {uploaded_validation['id']}")

    if opts.create_job:
        # Internal invariant: the upload branch above ran for --create-job.
        assert uploaded_train is not None and uploaded_validation is not None
        job_request = job_payload(uploaded_train["id"], uploaded_validation["id"], opts.suffix)
        results["job_request"] = job_request
        results["job"] = post_json("/fine_tuning/jobs", job_request)
        print(f"Created fine-tuning job: {results['job']['id']}")

    serialized = json.dumps(results, indent=2)
    report_path.write_text(serialized, encoding="utf-8")
    print(f"Wrote payloads/results to {report_path}")

    any_action_requested = (
        opts.validate_grader or opts.run_grader_smoke or opts.upload or opts.create_job
    )
    if not any_action_requested:
        print(
            "No API actions requested. Add --validate-grader, --run-grader-smoke, "
            "--upload, or --create-job.",
            file=sys.stderr,
        )
    return 0


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
