from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Any

# Default the Unsloth toggles before torch is imported below (unsloth itself is
# imported later, inside load_model). setdefault leaves any value that is
# already exported in the environment untouched.
# NOTE(review): presumably Unsloth reads these variables at import time to turn
# off its compiled / fast-generation paths — confirm against the Unsloth docs.
os.environ.setdefault("UNSLOTH_COMPILE_DISABLE", "1")
os.environ.setdefault("UNSLOTH_DISABLE_FAST_GENERATION", "1")

import torch


# Default value of the --model-name command-line option.
MODEL_NAME = "unsloth/gpt-oss-20b"
# Column letters that coord_to_row_col indexes into; the letter "I" is absent.
BOARD_COLUMNS = "ABCDEFGHJKLMNOPQRSTUVWXYZ"
# Special tokens treated as end-of-answer markers: generate() looks up their
# ids to stop decoding, and extract_final_message() truncates the text at them.
HARMONY_STOP_TOKENS = ["<|return|>", "<|end|>", "<|call|>"]


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the matrix inference script.

    Returns:
        The namespace produced by argparse from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Run simple GPT-OSS matrix inference on one KataGo dataset row.")
    cli.add_argument("--model-name", default=MODEL_NAME)

    # Mandatory path options.
    for required_flag in ("--data-path", "--output-dir"):
        cli.add_argument(required_flag, required=True)

    # Integer options together with their defaults.
    integer_options = (
        ("--sample-index", 0),
        ("--num-samples", 1),
        ("--max-new-tokens", 160),
        ("--max-seq-length", 2048),
    )
    for int_flag, default_value in integer_options:
        cli.add_argument(int_flag, type=int, default=default_value)

    # Boolean switches.
    cli.add_argument("--no-4bit", action="store_true")
    cli.add_argument("--fast-inference", action="store_true", help="Use Unsloth fast/compiled inference path.")
    return cli.parse_args()


def read_samples(path: Path, sample_index: int, num_samples: int) -> list[dict[str, Any]]:
    """Read consecutive 19x19 rows from a JSON Lines dataset file.

    Blank lines are skipped. Rows whose ``board_size`` is present and not 19
    are skipped as well and do not count towards ``sample_index``; a missing
    or falsy ``board_size`` is treated as 19.

    Args:
        path: Location of the JSON Lines file.
        sample_index: Zero-based index, among the 19x19 rows, of the first row
            to return.
        num_samples: Number of rows to return.

    Returns:
        Exactly ``num_samples`` parsed rows, in file order.

    Raises:
        ValueError: If ``sample_index`` is negative or ``num_samples`` is not
            positive.
        IndexError: If the file holds fewer matching rows than requested.
    """
    if sample_index < 0:
        raise ValueError("--sample-index must be non-negative")
    if num_samples <= 0:
        raise ValueError("--num-samples must be positive")
    samples: list[dict[str, Any]] = []
    # JSON text is UTF-8; decode explicitly instead of relying on the locale
    # default so non-ASCII rows load identically on every platform.
    with path.open(encoding="utf-8") as f:
        seen = 0  # number of 19x19 rows encountered so far
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            if int(row.get("board_size") or 19) != 19:
                continue
            if seen >= sample_index:
                samples.append(row)
                # Stop reading as soon as the request is satisfied.
                if len(samples) >= num_samples:
                    return samples
            seen += 1
    raise IndexError(f"Found {len(samples)} samples starting at index {sample_index}; needed {num_samples}")


def coord_to_row_col(coord: str) -> tuple[int, int]:
    """Translate a board coordinate such as ``"D4"`` into matrix indices.

    The first character selects the column by its position in
    ``BOARD_COLUMNS``; the remaining characters are the board row number,
    which is flipped so that board row 19 maps to matrix row 0.

    Args:
        coord: Coordinate text; surrounding whitespace and letter case are
            ignored.

    Returns:
        A ``(row, col)`` pair. No range check is applied here.
    """
    normalized = coord.strip().upper()
    letter, digits = normalized[0], normalized[1:]
    column_index = BOARD_COLUMNS.index(letter)
    row_index = 19 - int(digits)
    return row_index, column_index


def row_to_matrix(row: dict[str, Any]) -> list[list[int]]:
    """Build the 19x19 integer board for one dataset row.

    Reads ``row["stones"]["black"]`` and ``row["stones"]["white"]``; missing
    or empty entries are treated as no stones. Coordinates that fall outside
    the 19x19 grid are ignored.

    Args:
        row: Parsed dataset row.

    Returns:
        A list of 19 lists of 19 ints: 1 for black, -1 for white, 0 for empty.
    """
    board = [[0] * 19 for _ in range(19)]
    placements = row.get("stones") or {}
    # Black is applied before white, so white wins if both list the same point.
    for color, marker in (("black", 1), ("white", -1)):
        for coord in placements.get(color) or []:
            r, c = coord_to_row_col(str(coord))
            if r < 0 or r >= 19 or c < 0 or c >= 19:
                continue
            board[r][c] = marker
    return board


def matrix_for_prompt(matrix: list[list[int]]) -> str:
    """Render the board matrix as text, one line per matrix row.

    Each value is formatted as a width-2 integer and values on a line are
    separated by a single space.
    """
    rendered_lines = []
    for cells in matrix:
        rendered_lines.append(" ".join(format(cell, "2d") for cell in cells))
    return "\n".join(rendered_lines)


def stone_list(row: dict[str, Any], color: str) -> str:
    """Return the stones of one colour as a comma-separated string.

    Args:
        row: Parsed dataset row; stones are read from ``row["stones"][color]``.
        color: Key inside the ``stones`` mapping.

    Returns:
        The coordinates joined with ``", "``, or ``"none"`` when the row has
        no stones for that colour.
    """
    by_color = row.get("stones") or {}
    coords = by_color.get(color) or []
    if not coords:
        return "none"
    return ", ".join(map(str, coords))


def build_user_prompt(row: dict[str, Any], matrix: list[list[int]]) -> str:
    """Compose the user message describing one board position.

    Args:
        row: Parsed dataset row; ``to_move`` and ``stones`` are read from it.
        matrix: Board matrix rendered into the prompt.

    Returns:
        The prompt text containing the side to move, both coordinate lists
        and the rendered matrix.
    """
    # Any value other than "B"/"W" (case-insensitive) is reported as "Unknown".
    side_code = str(row.get("to_move") or "").upper()
    if side_code == "B":
        side_name = "Black"
    elif side_code == "W":
        side_name = "White"
    else:
        side_name = "Unknown"

    black_stones = stone_list(row, "black")
    white_stones = stone_list(row, "white")
    board_text = matrix_for_prompt(matrix)
    return f"""You are a Go review assistant.

The board is represented as a 19x19 matrix.
1 means a black stone, -1 means a white stone, and 0 means an empty neutral point.
Rows are board rows 19 down to 1. Columns are A through T with I omitted.
Side to move: {side_name}
Black stones: {black_stones}
White stones: {white_stones}

Board matrix:
{board_text}

Explain this Go position concisely for a Go player. Use the coordinate lists to avoid misreading the matrix."""


def build_harmony_prompt(row: dict[str, Any], matrix: list[list[int]]) -> str:
    """Assemble the full prompt string sent to the model.

    The prompt is a system message, a developer message and the user message
    from ``build_user_prompt``, each wrapped in ``<|start|>``/``<|end|>``
    markers, followed by an opened assistant message on the ``final`` channel
    so the model's output begins directly with the answer.

    Args:
        row: Parsed dataset row, forwarded to ``build_user_prompt``.
        matrix: Board matrix, forwarded to ``build_user_prompt``.

    Returns:
        The concatenated prompt text.
    """
    user_text = build_user_prompt(row, matrix)
    system_text = """You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2026-05-02

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message."""
    developer_text = """# Instructions

You are a Go review assistant. Explain Go board positions from matrix inputs.
Use only the board matrix, coordinate lists, and side to move supplied by the user.
Write the user-facing answer in the final channel as 2-4 concise sentences.
Do not restate the full matrix."""
    segments = [
        f"<|start|>system<|message|>{system_text}<|end|>",
        f"<|start|>developer<|message|>{developer_text}<|end|>",
        f"<|start|>user<|message|>{user_text}<|end|>",
        # Left open on purpose: generation continues this assistant message.
        "<|start|>assistant<|channel|>final<|message|>",
    ]
    return "".join(segments)


def load_model(model_name: str, max_seq_length: int, load_in_4bit: bool, fast_inference: bool) -> tuple[Any, Any]:
    """Load the model and tokenizer through Unsloth's ``FastLanguageModel``.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        max_seq_length: Maximum sequence length passed to ``from_pretrained``.
        load_in_4bit: Whether to request 4-bit loading.
        fast_inference: Whether to request the fast inference path; when true
            ``FastLanguageModel.for_inference`` is also applied to the model.

    Returns:
        A ``(model, tokenizer)`` pair. The tokenizer's pad token falls back to
        its EOS token when none is configured.
    """
    # Imported lazily so unsloth is only loaded when a model is needed.
    from unsloth import FastLanguageModel

    load_kwargs = dict(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=None,
        load_in_4bit=load_in_4bit,
        fast_inference=bool(fast_inference),
    )
    try:
        loaded = FastLanguageModel.from_pretrained(**load_kwargs)
    except TypeError:
        # Retry without the keyword.
        # NOTE(review): presumably for Unsloth versions whose from_pretrained
        # does not accept fast_inference — confirm.
        del load_kwargs["fast_inference"]
        loaded = FastLanguageModel.from_pretrained(**load_kwargs)
    model, tokenizer = loaded

    if fast_inference:
        FastLanguageModel.for_inference(model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def token_id(tokenizer: Any, token: str) -> int | None:
    ids = tokenizer.encode(token, add_special_tokens=False)
    return ids[0] if len(ids) == 1 else None


def generate(model: Any, tokenizer: Any, prompt: str, max_new_tokens: int) -> str:
    """Greedily generate a completion for ``prompt`` and return the answer text.

    Args:
        model: Model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Fully formatted prompt; no special tokens are added to it.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation after ``extract_final_message`` clean-up.
    """
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # Stop on every stop token that maps to a single id, plus the tokenizer's
    # own EOS id when it has one.
    eos_ids = []
    for stop_token in HARMONY_STOP_TOKENS:
        stop_id = token_id(tokenizer, stop_token)
        if stop_id is not None:
            eos_ids.append(stop_id)
    if tokenizer.eos_token_id is not None:
        eos_ids.append(tokenizer.eos_token_id)

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        use_cache=True,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=eos_ids or tokenizer.eos_token_id,
    )
    with torch.inference_mode():
        output_ids = model.generate(**encoded, **generation_kwargs)

    # Drop the prompt tokens; keep special tokens so the channel markers are
    # still visible to extract_final_message.
    prompt_length = encoded["input_ids"].shape[1]
    continuation_ids = output_ids[0][prompt_length:]
    decoded = tokenizer.decode(continuation_ids, skip_special_tokens=False)
    return extract_final_message(decoded)


def extract_final_message(generated: str) -> str:
    """Reduce decoded model output to the plain answer text.

    Steps, in order: keep what follows the last final-channel message marker
    (or, failing that, the last ``<|message|>`` marker); cut at the first
    occurrence of each stop token; delete leftover start/channel markers;
    strip surrounding whitespace.

    Args:
        generated: Decoded continuation, special tokens included.

    Returns:
        The cleaned answer text, possibly empty.
    """
    text = generated
    # Only the first marker that is present is applied.
    for marker in ("<|channel|>final<|message|>", "<|message|>"):
        if marker in text:
            text = text.rpartition(marker)[2]
            break

    # partition leaves the text unchanged when the stop token is absent.
    for stop in HARMONY_STOP_TOKENS:
        text = text.partition(stop)[0]

    leftovers = ("<|start|>", "<|channel|>analysis", "<|channel|>commentary", "<|channel|>final")
    for leftover in leftovers:
        text = text.replace(leftover, "")
    return text.strip()


def main() -> None:
    """Run matrix inference on the selected rows and write the result files.

    Parses the command line, reads the requested dataset rows, loads the
    model, generates one explanation per row and writes three files into the
    output directory: ``matrix_inference_results.json`` (full payload),
    ``matrix_inference_results.jsonl`` (one result per line) and
    ``matrix_inference_results.md`` (readable report). Each result is also
    printed to stdout as soon as it is produced.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    args = parse_args()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Run this on a Modal GPU.")

    # Read the rows before loading the model so a bad path or index fails
    # without paying for the model load.
    rows = read_samples(Path(args.data_path), args.sample_index, args.num_samples)

    model, tokenizer = load_model(
        model_name=args.model_name,
        max_seq_length=args.max_seq_length,
        load_in_4bit=not args.no_4bit,
        fast_inference=args.fast_inference,
    )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    results = []
    markdown_parts = ["# GPT-OSS Matrix Inference", ""]
    for offset, row in enumerate(rows):
        # Index counted among the 19x19 rows returned by read_samples.
        current_index = args.sample_index + offset
        matrix = row_to_matrix(row)
        prompt = build_harmony_prompt(row, matrix)
        explanation = generate(model, tokenizer, prompt, args.max_new_tokens)
        result = {
            "model_name": args.model_name,
            "data_path": args.data_path,
            "sample_index": current_index,
            "id": row.get("id"),
            "to_move": row.get("to_move"),
            "matrix": matrix,
            "prompt": prompt,
            "model_explanation": explanation,
            "original_review_text": row.get("rationale_text"),
        }
        results.append(result)
        # Emit each result immediately so progress is visible while running.
        print(json.dumps(result, indent=2), flush=True)
        markdown_parts.extend(
            [
                f"## Sample {current_index}: {row.get('id')}",
                "",
                f"- to_move: {row.get('to_move')}",
                "",
                "```text",
                matrix_for_prompt(matrix),
                "```",
                "",
                "### Model Explanation",
                "",
                explanation or "<empty>",
                "",
                "### Original Review Text",
                "",
                str(row.get("rationale_text") or "<empty>"),
                "",
            ]
        )

    payload = {
        "model_name": args.model_name,
        "data_path": args.data_path,
        "sample_index": args.sample_index,
        "num_samples": args.num_samples,
        "results": results,
    }
    # Write explicitly as UTF-8: the Markdown report embeds raw model and
    # review text, which may not be encodable in the locale's default encoding.
    (output_dir / "matrix_inference_results.json").write_text(
        json.dumps(payload, indent=2), encoding="utf-8"
    )
    (output_dir / "matrix_inference_results.jsonl").write_text(
        "".join(json.dumps(result) + "\n" for result in results), encoding="utf-8"
    )
    (output_dir / "matrix_inference_results.md").write_text(
        "\n".join(markdown_parts), encoding="utf-8"
    )


# Run the inference pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
