from __future__ import annotations

import argparse
import json
import math
import os
import random
import inspect
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Iterable

# Unsloth-related environment defaults. They are set before `unsloth` is
# imported below, and `setdefault` keeps any value the caller already exported.
# NOTE(review): presumably these turn off Unsloth's compile step and its
# fast-generation path -- confirm the exact semantics against the Unsloth docs.
os.environ.setdefault("UNSLOTH_COMPILE_DISABLE", "1")
os.environ.setdefault("UNSLOTH_DISABLE_FAST_GENERATION", "1")

# Unsloth is optional. When installed it is imported ahead of torch/transformers
# (presumably so it can patch them -- TODO confirm); when missing, the name is
# bound to None so the module still imports.
try:
    import unsloth  # noqa: F401
except ImportError:
    unsloth = None

import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    set_seed,
)


# Model identifier used when --model-name is not supplied.
DEFAULT_MODEL = "unsloth/gpt-oss-20b"
# Go column letters with "I" skipped; a letter's position in this string is the
# 0-based column index used by coord_to_row_col.
BOARD_COLUMNS = "ABCDEFGHJKLMNOPQRSTUVWXYZ"
# Tokens that terminate an assistant message: used as extra EOS ids during
# generation and as cut points when trimming decoded text.
HARMONY_STOP_TOKENS = ["<|return|>", "<|end|>", "<|call|>"]


@dataclass
class DataStats:
    """Summary of the data and model settings for one fine-tuning run.

    Built in ``main`` once the datasets are loaded and filtered, then serialized
    with ``asdict`` into ``run_metadata.json`` by ``write_run_metadata``.
    """

    # Number of training rows kept after the board-size filter in main.
    train_rows: int
    # Number of evaluation rows kept after the board-size filter in main.
    eval_rows: int
    # Maximum tokenized sequence length (the --max-seq-len value).
    max_seq_len: int
    # Model identifier (the --model-name value).
    model_name: str


def parse_args() -> argparse.Namespace:
    """Parse the fine-tuning script's command-line options from ``sys.argv``.

    ``--train-path``, ``--eval-path`` and ``--output-dir`` are required; every
    other option has a default. Options are registered in a fixed order so the
    generated ``--help`` text stays stable.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description="LoRA fine-tune GPT-OSS 20B on the KataGo large dataset")

    # Model and data locations.
    cli.add_argument("--model-name", default=DEFAULT_MODEL)
    for required_flag in ("--train-path", "--eval-path", "--output-dir"):
        cli.add_argument(required_flag, required=True)

    # Numeric hyper-parameters as (flag, converter, default).
    numeric_options = (
        ("--max-seq-len", int, 2048),
        ("--max-train-examples", int, 0),
        ("--max-eval-examples", int, 0),
        ("--num-train-epochs", float, 2.0),
        ("--max-steps", int, -1),
        ("--per-device-train-batch-size", int, 1),
        ("--per-device-eval-batch-size", int, 1),
        ("--gradient-accumulation-steps", int, 16),
        ("--learning-rate", float, 2e-4),
        ("--weight-decay", float, 0.01),
        ("--warmup-ratio", float, 0.03),
        ("--logging-steps", int, 10),
        ("--eval-steps", int, 100),
        ("--save-steps", int, 250),
        ("--lora-r", int, 16),
        ("--lora-alpha", int, 32),
        ("--lora-dropout", float, 0.05),
        ("--seed", int, 42),
    )
    for flag, converter, default in numeric_options:
        cli.add_argument(flag, type=converter, default=default)

    # Boolean switches, all off unless given.
    for switch in ("--bf16", "--no-4bit", "--smoke-test"):
        cli.add_argument(switch, action="store_true", default=False)

    # Post-training sample generation.
    for flag, default in (("--generate-samples", 3), ("--max-new-tokens", 160)):
        cli.add_argument(flag, type=int, default=default)

    return cli.parse_args()


def read_jsonl(path: Path, limit: int = 0) -> list[dict[str, Any]]:
    """Load rows from a JSON Lines file.

    Blank or whitespace-only lines are skipped. The file is always decoded as
    UTF-8 so the result does not depend on the platform's default locale.

    Args:
        path: Location of the ``.jsonl`` file.
        limit: Maximum number of rows to return. Zero, a negative value or
            ``None`` means "no limit".

    Returns:
        The decoded JSON value of each non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Normalise once so negative values behave like "unlimited" rather than
    # stopping after the first row.
    cap = limit if limit and limit > 0 else 0
    rows: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            rows.append(json.loads(line))
            if cap and len(rows) >= cap:
                break
    return rows


def coord_to_row_col(coord: str) -> tuple[int, int]:
    """Convert a Go coordinate such as ``"D4"`` into ``(row, col)`` indices.

    The input is stripped and upper-cased. The column is the position of the
    leading letter in ``BOARD_COLUMNS``; the row is ``19 - number`` so the
    highest-numbered board row maps to index 0. No range check is done here,
    so the result may lie outside a 19x19 board.

    Raises:
        IndexError: If ``coord`` is empty after stripping.
        ValueError: If the letter is not in ``BOARD_COLUMNS`` or the remainder
            is not an integer.
    """
    text = coord.strip().upper()
    letter, number = text[0], text[1:]
    column = BOARD_COLUMNS.index(letter)
    return 19 - int(number), column


def row_to_matrix(row: dict[str, Any]) -> list[list[int]]:
    """Render a dataset row's stone lists as a 19x19 integer grid.

    Cells start at 0. Black stones are written as 1 first, then white stones
    as -1, so a coordinate listed under both colours ends up as -1.
    Coordinates that cannot be parsed or that fall outside the grid are
    silently ignored.

    Args:
        row: Dataset row; ``row["stones"]`` is read as a mapping with optional
            ``"black"`` and ``"white"`` coordinate lists.

    Returns:
        A list of 19 lists, each holding 19 ints.
    """
    board = [[0] * 19 for _ in range(19)]
    placements = row.get("stones") or {}
    for color, marker in (("black", 1), ("white", -1)):
        for raw_coord in placements.get(color) or []:
            try:
                r, c = coord_to_row_col(str(raw_coord))
            except (ValueError, IndexError):
                # Unparseable entry: skip it rather than abort the row.
                continue
            if not (0 <= r < 19 and 0 <= c < 19):
                continue
            board[r][c] = marker
    return board


def matrix_for_prompt(matrix: list[list[int]]) -> str:
    """Format an integer grid as text: one line per row, cells space-separated.

    Each cell is right-aligned in a field of width 2 so that ``-1`` and ``1``
    line up in columns.
    """
    rendered_rows = []
    for cells in matrix:
        rendered_rows.append(" ".join(format(cell, "2d") for cell in cells))
    return "\n".join(rendered_rows)


def stone_list(row: dict[str, Any], color: str) -> str:
    """Return the coordinates of one colour as a comma-separated string.

    Args:
        row: Dataset row; ``row["stones"][color]`` is read if present.
        color: Key inside the ``stones`` mapping (e.g. ``"black"``).

    Returns:
        The coordinates joined with ``", "``, or the literal ``"none"`` when
        the row has no stones for that colour.
    """
    by_color = row.get("stones") or {}
    coords = by_color.get(color) or []
    if not coords:
        return "none"
    return ", ".join(map(str, coords))


def build_user_prompt(row: dict[str, Any]) -> str:
    """Compose the user-turn text describing one board position.

    The text states the matrix encoding, the side to move (``"Unknown"`` unless
    ``row["to_move"]`` is ``B`` or ``W``, case-insensitive), both stone lists
    and the rendered 19x19 matrix, followed by the explanation request.
    """
    side_code = str(row.get("to_move") or "").upper()
    if side_code == "B":
        to_move = "Black"
    elif side_code == "W":
        to_move = "White"
    else:
        to_move = "Unknown"

    board_text = matrix_for_prompt(row_to_matrix(row))
    black_stones = stone_list(row, "black")
    white_stones = stone_list(row, "white")
    return f"""You are a Go review assistant.

The board is represented as a 19x19 matrix.
1 means a black stone, -1 means a white stone, and 0 means an empty neutral point.
Rows are board rows 19 down to 1. Columns are A through T with I omitted.
Side to move: {to_move}
Black stones: {black_stones}
White stones: {white_stones}

Board matrix:
{board_text}

Explain this Go position concisely for a Go player. Use the coordinate lists to avoid misreading the matrix."""


def build_harmony_prompt(row: dict[str, Any]) -> str:
    """Wrap a row's user prompt in the chat markup used for training and generation.

    The result contains a fixed system turn, a fixed developer turn and the
    user turn from ``build_user_prompt``, each delimited by ``<|start|>`` /
    ``<|message|>`` / ``<|end|>``. It ends with an opened assistant turn on the
    ``final`` channel, so whatever follows is the answer text itself.
    """
    system_text = """You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2026-05-02

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message."""
    developer_text = """# Instructions

You are a Go review assistant. Explain Go board positions from matrix inputs.
Use only the board matrix, coordinate lists, and side to move supplied by the user.
Write the user-facing answer in the final channel as 2-4 concise sentences.
Do not restate the full matrix."""

    turns = (
        ("system", system_text),
        ("developer", developer_text),
        ("user", build_user_prompt(row)),
    )
    history = "".join(f"<|start|>{role}<|message|>{body}<|end|>" for role, body in turns)
    # Leave the assistant turn open so the model continues directly with the answer.
    return history + "<|start|>assistant<|channel|>final<|message|>"


def build_messages(row: dict[str, Any]) -> tuple[str, str]:
    """Return the ``(prompt, response)`` pair for one training row.

    The response is the stripped ``rationale_text`` field. When that field is
    missing or blank, a fixed placeholder sentence is used so the target is
    never empty.
    """
    prompt = build_harmony_prompt(row)
    rationale = str(row.get("rationale_text") or "").strip()
    reply = rationale or "No useful review comment is available for this position."
    return prompt, reply


class CausalPromptDataset(Dataset):
    """Dataset that tokenizes rows on access into prompt-masked causal-LM examples.

    Each item holds ``input_ids``, ``attention_mask`` and ``labels`` as plain
    lists. Labels equal the input ids except that prompt positions are set to
    -100, so only the response (and the trailing ``<|return|>``) carries a
    training target.
    """

    def __init__(self, rows: list[dict[str, Any]], tokenizer: Any, max_seq_len: int):
        """Store the raw rows, the tokenizer and the truncation length."""
        self.rows, self.tokenizer, self.max_seq_len = rows, tokenizer, max_seq_len

    def __len__(self) -> int:
        """Return the number of rows."""
        return len(self.rows)

    def __getitem__(self, index: int) -> dict[str, list[int]]:
        """Tokenize row ``index`` and mask its prompt tokens in the labels."""
        prompt, response = build_messages(self.rows[index])
        tokenize = self.tokenizer

        # The prompt is tokenized on its own only to learn how many leading
        # tokens of the full sequence belong to it.
        prompt_token_count = len(tokenize(prompt, add_special_tokens=False).input_ids)
        encoded = tokenize(
            prompt + response + "<|return|>",
            add_special_tokens=False,
            truncation=True,
            max_length=self.max_seq_len,
        )
        input_ids = encoded.input_ids

        # Truncation can make the sequence shorter than the prompt, so clamp.
        # NOTE(review): in that case every label is -100 and the example gives
        # no training signal -- confirm max_seq_len comfortably exceeds the
        # prompt length.
        masked = min(prompt_token_count, len(input_ids))
        labels = [-100] * masked + list(input_ids[masked:])
        return {"input_ids": input_ids, "attention_mask": encoded.attention_mask, "labels": labels}


class DataCollatorForCausalPrompt:
    """Right-pad variable-length examples to a common length and stack them.

    ``input_ids`` are padded with the tokenizer's ``pad_token_id``,
    ``attention_mask`` with 0 and ``labels`` with -100.
    """

    def __init__(self, tokenizer: Any):
        """Keep the tokenizer whose ``pad_token_id`` is used for padding."""
        self.tokenizer = tokenizer

    def __call__(self, features: list[dict[str, list[int]]]) -> dict[str, torch.Tensor]:
        """Collate ``features`` into ``torch.long`` tensors of shape (len(features), longest)."""
        widest = max(len(example["input_ids"]) for example in features)
        fill_values = (
            ("input_ids", self.tokenizer.pad_token_id),
            ("attention_mask", 0),
            ("labels", -100),
        )
        # The pad amount is always derived from input_ids, for every field.
        return {
            key: torch.tensor(
                [example[key] + [fill] * (widest - len(example["input_ids"])) for example in features],
                dtype=torch.long,
            )
            for key, fill in fill_values
        }


def lora_target_modules(model: torch.nn.Module) -> list[str]:
    """Choose which ``torch.nn.Linear`` submodule names LoRA should target.

    The last component of every Linear submodule's qualified name is collected.
    If any of the usual attention/MLP projection names are among them, only
    those are returned; otherwise every collected name is returned.

    Returns:
        A sorted list of module-name suffixes (possibly empty).
    """
    wanted = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
    linear_leaf_names = set()
    for qualified_name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            linear_leaf_names.add(qualified_name.rsplit(".", 1)[-1])
    matched = wanted.intersection(linear_leaf_names)
    return sorted(matched) if matched else sorted(linear_leaf_names)


def load_model_and_tokenizer(args: argparse.Namespace) -> tuple[Any, Any]:
    """Load the base model with LoRA adapters attached, plus its tokenizer.

    The Unsloth loader is tried first. If anything in that path raises, a
    warning is printed and the model is loaded with plain transformers + PEFT.

    Args:
        args: Parsed CLI options; reads ``model_name``, ``max_seq_len``,
            ``bf16``, ``no_4bit``, ``seed`` and the ``lora_*`` settings.

    Returns:
        ``(model, tokenizer)`` where ``model`` is already wrapped for LoRA
        training and the tokenizer has a pad token.
    """
    # Compute dtype for the fallback path only; the Unsloth call below passes dtype=None.
    dtype = torch.bfloat16 if args.bf16 or torch.cuda.is_bf16_supported() else torch.float16
    try:
        from unsloth import FastLanguageModel

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=args.model_name,
            max_seq_length=args.max_seq_len,
            dtype=None,
            load_in_4bit=not args.no_4bit,
            full_finetuning=False,
        )
        # The collator pads with pad_token_id, so reuse EOS when no pad token is defined.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = FastLanguageModel.get_peft_model(
            model,
            r=args.lora_r,
            target_modules=lora_target_modules(model),
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=args.seed,
        )
        model.print_trainable_parameters()
        return model, tokenizer
    except Exception as exc:
        # Deliberately broad: a missing package and any runtime failure in the
        # Unsloth path both lead to the fallback below.
        # NOTE(review): this also hides errors raised after the weights were
        # loaded (e.g. inside get_peft_model) -- confirm that is intended.
        print(f"[WARN] Unsloth load failed or is unavailable; falling back to HuggingFace PEFT. Reason: {exc}")

    # ---- Fallback: plain transformers + PEFT ----
    # NOTE(review): trust_remote_code=True is passed to both loaders below --
    # confirm the model repository is trusted.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 4-bit NF4 quantization with double quantization unless --no-4bit was given.
    quantization_config = None
    if not args.no_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto",
        quantization_config=quantization_config,
    )
    # The KV cache is switched off for training; generate_eval_samples passes
    # use_cache=True explicitly when it generates.
    model.config.use_cache = False
    # Only the quantized model goes through the k-bit training preparation step.
    if not args.no_4bit:
        model = prepare_model_for_kbit_training(model)

    targets = lora_target_modules(model)
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=targets,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model, tokenizer


def write_run_metadata(output_dir: Path, args: argparse.Namespace, stats: DataStats) -> None:
    """Write ``run_metadata.json`` with the CLI arguments and data statistics.

    ``output_dir`` (and any missing parents) is created first. The file holds a
    JSON object with an ``"args"`` key (``vars(args)``) and a ``"data"`` key
    (``asdict(stats)``), indented by two spaces.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    payload = json.dumps({"args": vars(args), "data": asdict(stats)}, indent=2)
    target = output_dir / "run_metadata.json"
    target.write_text(payload)


def token_id(tokenizer: Any, token: str) -> int | None:
    ids = tokenizer.encode(token, add_special_tokens=False)
    return ids[0] if len(ids) == 1 else None


def extract_final_message(generated: str) -> str:
    """Reduce decoded model output to the plain text of the last message.

    Steps, in order:
      1. keep what follows the last final-channel header, or else the last
         ``<|message|>`` marker, if either occurs;
      2. cut at the first occurrence of each stop token;
      3. remove leftover start/channel markers and surrounding whitespace.
    """
    final_marker = "<|channel|>final<|message|>"
    if final_marker in generated:
        text = generated.rpartition(final_marker)[2]
    elif "<|message|>" in generated:
        text = generated.rpartition("<|message|>")[2]
    else:
        text = generated

    # partition returns the whole string as the head when the separator is
    # absent, so no membership test is needed.
    for stop in HARMONY_STOP_TOKENS:
        text = text.partition(stop)[0]

    for leftover in ("<|start|>", "<|channel|>analysis", "<|channel|>commentary", "<|channel|>final"):
        text = text.replace(leftover, "")
    return text.strip()


def generate_eval_samples(
    model: Any,
    tokenizer: Any,
    rows: list[dict[str, Any]],
    output_dir: Path,
    num_samples: int,
    max_new_tokens: int,
) -> None:
    """Greedy-generate explanations for the first rows and save them next to the gold text.

    For each of the first ``num_samples`` rows the prompt is built, the model
    generates up to ``max_new_tokens`` tokens without sampling, and the decoded
    continuation is trimmed with ``extract_final_message``. Each sample is
    printed as JSON, and all samples are written to ``eval_generations.json``,
    ``eval_generations.jsonl`` and ``eval_generations.md`` in ``output_dir``.

    Args:
        model: Model exposing ``eval()``, ``generate()`` and ``device``.
        tokenizer: Tokenizer used for encoding prompts and decoding output.
        rows: Evaluation rows; only the first ``num_samples`` are used.
        output_dir: Directory for the output files; created if missing.
        num_samples: Number of rows to generate for. Zero or a negative value
            makes the function return immediately without touching anything.
        max_new_tokens: Generation length cap per sample.
    """
    if num_samples <= 0:
        return
    # Create the directory up front so a missing path fails before the
    # (expensive) generation loop rather than after it.
    output_dir.mkdir(parents=True, exist_ok=True)
    model.eval()

    # Stop on any stop token that maps to a single id, plus the tokenizer's own
    # EOS. Duplicates are skipped because EOS may coincide with a stop token.
    eos_ids: list[int] = []
    candidates = [token_id(tokenizer, token) for token in HARMONY_STOP_TOKENS]
    candidates.append(tokenizer.eos_token_id)
    for candidate in candidates:
        if candidate is not None and candidate not in eos_ids:
            eos_ids.append(candidate)

    samples = []
    markdown_parts = ["# Fine-Tune Eval Generations", ""]
    for index, row in enumerate(rows[:num_samples]):
        prompt = build_harmony_prompt(row)
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
        with torch.inference_mode():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=eos_ids or tokenizer.eos_token_id,
            )
        # Drop the echoed prompt; keep only the newly generated tokens.
        generated_ids = output_ids[0][inputs["input_ids"].shape[1] :]
        # Special tokens are kept so extract_final_message can find its markers.
        generated = tokenizer.decode(generated_ids, skip_special_tokens=False)
        model_text = extract_final_message(generated)
        gold_text = str(row.get("rationale_text") or "").strip()
        sample = {
            "index": index,
            "id": row.get("id"),
            "to_move": row.get("to_move"),
            "prompt": prompt,
            "model_explanation": model_text,
            "gold_explanation": gold_text,
        }
        samples.append(sample)
        print(json.dumps(sample, indent=2), flush=True)
        markdown_parts.extend(
            [
                f"## Sample {index}: {row.get('id')}",
                "",
                f"- to_move: {row.get('to_move')}",
                "",
                "### Model",
                "",
                model_text or "<empty>",
                "",
                "### Gold",
                "",
                gold_text or "<empty>",
                "",
            ]
        )

    # The Markdown file holds raw model/gold text, which may contain non-ASCII
    # characters, so every file is written explicitly as UTF-8.
    (output_dir / "eval_generations.json").write_text(
        json.dumps({"samples": samples}, indent=2), encoding="utf-8"
    )
    (output_dir / "eval_generations.jsonl").write_text(
        "".join(json.dumps(sample) + "\n" for sample in samples), encoding="utf-8"
    )
    (output_dir / "eval_generations.md").write_text("\n".join(markdown_parts), encoding="utf-8")


def main() -> None:
    """Run the full pipeline: load data, fine-tune, evaluate, save, generate samples.

    Steps:
      1. Parse CLI options; ``--smoke-test`` shrinks the data and step counts.
      2. Seed the RNGs and require a CUDA device.
      3. Read the train/eval JSONL files and keep only 19x19 rows.
      4. Load the model and tokenizer and write ``run_metadata.json``.
      5. Train with ``Trainer``, run a final evaluation, and save the adapter
         and tokenizer under ``<output_dir>/adapter``.
      6. Write ``metrics.json`` and generate a few evaluation samples.

    Raises:
        RuntimeError: If CUDA is unavailable, or if either dataset is empty
            after filtering.
    """
    args = parse_args()
    if args.smoke_test:
        # Explicit user values win; otherwise use tiny defaults for a quick end-to-end check.
        args.max_train_examples = args.max_train_examples or 8
        args.max_eval_examples = args.max_eval_examples or 4
        args.max_steps = 1 if args.max_steps < 0 else args.max_steps
        args.logging_steps = 1
        args.eval_steps = 1
        args.save_steps = 1

    set_seed(args.seed)
    random.seed(args.seed)

    if args.bf16 and not torch.cuda.is_available():
        raise RuntimeError("--bf16 was requested but CUDA is not available")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Run this script in the Modal GPU function.")

    # Note: the example limits apply while reading, before the board-size
    # filter, so fewer rows than requested may remain.
    train_rows = read_jsonl(Path(args.train_path), args.max_train_examples)
    eval_rows = read_jsonl(Path(args.eval_path), args.max_eval_examples)
    train_rows = [row for row in train_rows if int(row.get("board_size") or 19) == 19]
    eval_rows = [row for row in eval_rows if int(row.get("board_size") or 19) == 19]
    if not train_rows:
        raise RuntimeError(f"No train rows found at {args.train_path}")
    if not eval_rows:
        raise RuntimeError(f"No eval rows found at {args.eval_path}")

    print(f"[GPU] {torch.cuda.get_device_name(0)}")
    print(f"[DATA] train_rows={len(train_rows)} eval_rows={len(eval_rows)}")
    model, tokenizer = load_model_and_tokenizer(args)
    train_dataset = CausalPromptDataset(train_rows, tokenizer, args.max_seq_len)
    eval_dataset = CausalPromptDataset(eval_rows, tokenizer, args.max_seq_len)

    output_dir = Path(args.output_dir)
    stats = DataStats(
        train_rows=len(train_rows),
        eval_rows=len(eval_rows),
        max_seq_len=args.max_seq_len,
        model_name=args.model_name,
    )
    write_run_metadata(output_dir, args, stats)

    use_bf16 = args.bf16 or torch.cuda.is_bf16_supported()
    training_kwargs = {
        "output_dir": str(output_dir),
        "overwrite_output_dir": True,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "per_device_eval_batch_size": args.per_device_eval_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "learning_rate": args.learning_rate,
        "weight_decay": args.weight_decay,
        "warmup_ratio": args.warmup_ratio,
        "num_train_epochs": args.num_train_epochs,
        "max_steps": args.max_steps,
        "logging_steps": args.logging_steps,
        "eval_strategy": "steps",
        "eval_steps": args.eval_steps,
        "save_steps": args.save_steps,
        "save_total_limit": 2,
        "bf16": use_bf16,
        "fp16": not use_bf16,
        "gradient_checkpointing": True,
        "report_to": [],
        "remove_unused_columns": False,
    }
    # Only pass options the installed TrainingArguments accepts, so the script
    # works across transformers versions.
    valid_training_args = set(inspect.signature(TrainingArguments).parameters)
    # Releases that predate `eval_strategy` call the same option
    # `evaluation_strategy`. Without this mapping the filter would silently
    # drop it and periodic evaluation would never run.
    if "eval_strategy" not in valid_training_args and "evaluation_strategy" in valid_training_args:
        training_kwargs["evaluation_strategy"] = training_kwargs.pop("eval_strategy")
    dropped = sorted(key for key in training_kwargs if key not in valid_training_args)
    if dropped:
        print(f"[WARN] TrainingArguments does not accept these options; ignoring them: {', '.join(dropped)}")
    training_args = TrainingArguments(
        **{key: value for key, value in training_kwargs.items() if key in valid_training_args}
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForCausalPrompt(tokenizer),
    )
    train_result = trainer.train()
    eval_metrics = trainer.evaluate()

    trainer.save_model(str(output_dir / "adapter"))
    tokenizer.save_pretrained(str(output_dir / "adapter"))
    metrics = {
        "train": train_result.metrics,
        "eval": eval_metrics,
        # Perplexity is reported only when eval_loss is present and below 20;
        # a missing loss defaults to 100 and therefore yields None.
        "perplexity": math.exp(eval_metrics["eval_loss"]) if eval_metrics.get("eval_loss", 100) < 20 else None,
    }
    (output_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
    generate_eval_samples(model, tokenizer, eval_rows, output_dir, args.generate_samples, args.max_new_tokens)
    print(json.dumps(metrics, indent=2))


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
