from __future__ import annotations

import argparse
import json
import random
from pathlib import Path
from typing import Any

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed

from finetune_smollm import DEFAULT_MODEL


# Column labels for board coordinates, in left-to-right order. The sequence
# omits the letter "I" (A-H, then J-Z), giving 25 labels in total.
BOARD_COLUMNS = "ABCDEFGHJKLMNOPQRSTUVWXYZ"
# System message placed at the top of every prompt built by
# build_board_only_prompt.
INFERENCE_SYSTEM_PROMPT = (
    "You are a Go review assistant. Given only a board position, write concise commentary "
    "about the position for a Go player."
)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the explanation-sampling script.

    Returns:
        The populated namespace. ``--data-path`` and ``--output-dir`` are
        mandatory; every other option has a default.
    """
    # One (flag, add_argument keyword arguments) pair per option, listed in
    # the order they should appear in ``--help``.
    option_specs: list[tuple[str, dict[str, Any]]] = [
        ("--model-name", {"default": DEFAULT_MODEL}),
        ("--data-path", {"required": True}),
        ("--output-dir", {"required": True}),
        ("--adapter-path", {"default": ""}),
        ("--num-samples", {"type": int, "default": 100}),
        ("--seed", {"type": int, "default": 42}),
        ("--max-input-len", {"type": int, "default": 1024}),
        ("--max-new-tokens", {"type": int, "default": 160}),
        ("--temperature", {"type": float, "default": 0.7}),
        ("--top-p", {"type": float, "default": 0.9}),
        ("--no-4bit", {"action": "store_true", "default": False}),
    ]
    cli = argparse.ArgumentParser(description="Generate SmolLM explanations for random KataGo positions")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load every non-blank line of a JSON Lines file.

    Args:
        path: Location of the JSONL file.

    Returns:
        The decoded JSON value of each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def load_model(args: argparse.Namespace) -> tuple[Any, Any]:
    """Load the tokenizer and causal LM named by the CLI options.

    Args:
        args: Parsed options; reads ``model_name``, ``no_4bit`` and
            ``adapter_path``.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True, use_fast=True)
    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Prefer bfloat16 when the GPU reports support for it, else float16.
    if torch.cuda.is_bf16_supported():
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float16

    # 4-bit NF4 quantization with double quantization, unless --no-4bit.
    bnb_config = (
        None
        if args.no_4bit
        else BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        torch_dtype=compute_dtype,
        device_map="auto",
        quantization_config=bnb_config,
    )
    # Wrap the base model with the PEFT adapter only when a path was given.
    net = PeftModel.from_pretrained(base_model, args.adapter_path) if args.adapter_path else base_model
    net.eval()
    return net, tok


def coord_to_row_col(coord: str, board_size: int) -> tuple[int, int]:
    """Convert a coordinate label such as ``"D4"`` into grid indices.

    Args:
        coord: Column letter followed by a row number; surrounding whitespace
            and letter case are ignored.
        board_size: Number of rows on the board.

    Returns:
        ``(row, col)`` where ``row`` is ``board_size`` minus the row number
        and ``col`` is the letter's position in ``BOARD_COLUMNS``. The result
        is not range-checked against the board.

    Raises:
        IndexError: If ``coord`` is empty after stripping.
        ValueError: If the row part is not an integer or the letter is not
            in ``BOARD_COLUMNS``.
    """
    label = coord.strip().upper()
    letter, digits = label[0], label[1:]
    rank = int(digits)
    return board_size - rank, BOARD_COLUMNS.index(letter)


def render_board(row: dict[str, Any]) -> str:
    """Render a position record as a plain-text board diagram.

    Args:
        row: Position record. Reads ``board_size`` (19 when missing or falsy)
            and ``stones``, a mapping with optional ``"black"`` and
            ``"white"`` coordinate lists.

    Returns:
        A multi-line string: a column-label line, a legend line, a blank
        line, then one line per board row labelled with its row number,
        highest number first.
    """
    size = int(row.get("board_size") or 19)
    cells = [["."] * size for _ in range(size)]
    placements = row.get("stones") or {}

    for color, mark in (("black", "X"), ("white", "O")):
        for point in placements.get(color) or []:
            try:
                y, x = coord_to_row_col(point, size)
            except (ValueError, IndexError):
                # Unparseable coordinates are skipped rather than failing.
                continue
            if y < 0 or y >= size or x < 0 or x >= size:
                # Off-board coordinates are ignored as well.
                continue
            cells[y][x] = mark

    header = "Columns: " + " ".join(BOARD_COLUMNS[:size])
    body = [f"{size - offset:>2} " + " ".join(line) for offset, line in enumerate(cells)]
    return "\n".join([header, "X = Black, O = White, . = empty", "", *body])


def build_board_only_prompt(row: dict[str, Any]) -> str:
    """Build the prompt that shows the model only the board position.

    Args:
        row: Position record. Reads ``to_move`` and ``board_size`` and passes
            the record to ``render_board``.

    Returns:
        System, user and assistant sections separated by blank lines, ending
        with the opening assistant tag so generation continues from there.
    """
    side_code = str(row.get("to_move") or "").upper()
    if side_code == "B":
        side_name = "Black"
    elif side_code == "W":
        side_name = "White"
    else:
        side_name = "Unknown"

    system_section = f"<|system|>\n{INFERENCE_SYSTEM_PROMPT}"
    user_section = (
        "<|user|>\n"
        "Write commentary about this Go position. Use only the board position below; "
        "do not assume any engine analysis is available.\n\n"
        f"Board size: {row.get('board_size', '')}\n"
        f"To move: {side_name}\n\n"
        f"{render_board(row)}"
    )
    assistant_section = "<|assistant|>\n"
    return f"{system_section}\n\n{user_section}\n\n{assistant_section}"


def generate_one(model: Any, tokenizer: Any, row: dict[str, Any], args: argparse.Namespace) -> str:
    """Generate commentary for one position and return only the new text.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        row: Position record handed to ``build_board_only_prompt``.
        args: Parsed options; reads ``max_input_len``, ``max_new_tokens``,
            ``temperature`` and ``top_p``.

    Returns:
        The decoded continuation (prompt tokens excluded), with special
        tokens removed and surrounding whitespace stripped.
    """
    prompt = build_board_only_prompt(row)
    # The prompt already carries its own role tags, so no special tokens are
    # added by the tokenizer.
    # NOTE(review): if a prompt exceeds max_input_len, truncation may cut off
    # the trailing assistant tag depending on the tokenizer's truncation
    # side - confirm the limit is never reached for the board sizes in use.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=args.max_input_len,
        add_special_tokens=False,
    ).to(model.device)

    generation_kwargs: dict[str, Any] = {
        "max_new_tokens": args.max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if args.temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=args.temperature, top_p=args.top_p)
    else:
        # Sampling requires a strictly positive temperature; treat a
        # non-positive value as a request for greedy decoding instead of
        # letting generate() reject the call.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True).strip()


def write_markdown(rows: list[dict[str, Any]], output_path: Path) -> None:
    """Write a Markdown report with one section per generated explanation.

    Args:
        rows: Result records, each with a ``"source"`` position record and a
            ``"model_explanation"`` string.
        output_path: Destination file; overwritten if it already exists.
    """
    parts = [
        "# SmolLM KataGo Random Position Explanations",
        "",
        "The model prompt for each sample included only the rendered board position, board size, side to move, and a request for commentary.",
        "",
    ]
    side_names = {"B": "Black", "W": "White"}
    for idx, row in enumerate(rows, 1):
        source = row["source"]
        to_move = str(source.get("to_move") or "").upper()
        to_move_text = side_names.get(to_move, "Unknown")
        parts.extend(
            [
                f"## {idx}. {source.get('id', '<unknown>')}",
                "",
                f"- board_size: {source.get('board_size', '')}",
                f"- to_move: {to_move_text}",
                "",
                "**Board given to model**",
                "",
                "```text",
                render_board(source),
                "```",
                "",
                "**Model explanation**",
                "",
                row["model_explanation"] or "<empty>",
                "",
                "**Original review text**",
                "",
                str(source.get("rationale_text") or "<empty>"),
                "",
            ]
        )
    # Model output and review text may contain non-ASCII characters, so write
    # UTF-8 explicitly instead of relying on the platform default encoding.
    output_path.write_text("\n".join(parts), encoding="utf-8")


def main() -> None:
    """Sample positions, generate an explanation for each, and save the results.

    Writes a JSONL file, a Markdown report and a ``manifest.json`` into the
    output directory, and prints one JSON progress line per sample.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    args = parse_args()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Run this script in the Modal GPU function.")

    # Seed the libraries covered by set_seed and a private RNG for sampling.
    set_seed(args.seed)
    sampler = random.Random(args.seed)
    dataset = read_jsonl(Path(args.data_path))
    sample_size = min(args.num_samples, len(dataset))
    chosen = sampler.sample(dataset, sample_size)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    jsonl_path = output_dir / "random_100_explanations.jsonl"
    markdown_path = output_dir / "random_100_explanations.md"
    model, tokenizer = load_model(args)

    results: list[dict[str, Any]] = []
    for position, record in enumerate(chosen, start=1):
        text = generate_one(model, tokenizer, record, args)
        results.append({"sample_index": position, "source": record, "model_explanation": text})
        # Emit progress immediately so each sample is visible as it finishes.
        progress = {"sample_index": position, "id": record.get("id"), "model_explanation": text}
        print(json.dumps(progress), flush=True)

    with jsonl_path.open("w") as sink:
        sink.writelines(json.dumps(entry) + "\n" for entry in results)
    write_markdown(results, markdown_path)

    manifest = {
        "model_name": args.model_name,
        "adapter_path": args.adapter_path,
        "data_path": args.data_path,
        "num_samples": sample_size,
        "seed": args.seed,
        "jsonl": str(jsonl_path),
        "markdown": str(markdown_path),
    }
    (output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2))


# Run the sampling pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
