from __future__ import annotations

import argparse
import json
import math
import os
import random
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Iterable

import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    set_seed,
)


# Default value of --model-name (passed to from_pretrained for model and tokenizer).
DEFAULT_MODEL: str = "unsloth/SmolLM3-3B-Base"

# Instruction placed in the "<|system|>" section of every training prompt.
SYSTEM_PROMPT: str = (
    "You are a Go review assistant. "
    "Given a board position and KataGo analysis tags, "
    "write concise review commentary for the position."
)


@dataclass
class DataStats:
    """Summary of the loaded data, saved under ``"data"`` in ``run_metadata.json``."""

    # Rows read from --train-path (capped by --max-train-examples when non-zero).
    train_rows: int
    # Rows read from --eval-path (capped by --max-eval-examples when non-zero).
    eval_rows: int
    # Per-example token limit (--max-seq-len).
    max_seq_len: int
    # Model id/path that was fine-tuned (--model-name).
    model_name: str


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for one LoRA fine-tuning run.

    ``--train-path``, ``--eval-path`` and ``--output-dir`` are required.
    Typed hyper-parameters are registered from the table below (flag, type,
    default) and the boolean switches ``--bf16``, ``--no-4bit`` and
    ``--smoke-test`` default to ``False``. Options are registered in a fixed
    order, which is also the key order of the resulting namespace.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="LoRA fine-tune SmolLM3-3B-Base on the KataGo large dataset")
    parser.add_argument("--model-name", default=DEFAULT_MODEL)

    for required_flag in ("--train-path", "--eval-path", "--output-dir"):
        parser.add_argument(required_flag, required=True)

    typed_options = (
        ("--max-seq-len", int, 1024),
        ("--max-train-examples", int, 0),
        ("--max-eval-examples", int, 0),
        ("--num-train-epochs", float, 2.0),
        ("--max-steps", int, -1),
        ("--per-device-train-batch-size", int, 1),
        ("--per-device-eval-batch-size", int, 1),
        ("--gradient-accumulation-steps", int, 16),
        ("--learning-rate", float, 2e-4),
        ("--weight-decay", float, 0.01),
        ("--warmup-ratio", float, 0.03),
        ("--logging-steps", int, 10),
        ("--eval-steps", int, 100),
        ("--save-steps", int, 250),
        ("--lora-r", int, 16),
        ("--lora-alpha", int, 32),
        ("--lora-dropout", float, 0.05),
        ("--seed", int, 42),
    )
    for flag, value_type, default_value in typed_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    for switch in ("--bf16", "--no-4bit", "--smoke-test"):
        parser.add_argument(switch, action="store_true", default=False)

    return parser.parse_args()


def read_jsonl(path: Path, limit: int = 0) -> list[dict[str, Any]]:
    """Load records from a JSON Lines file.

    Blank (whitespace-only) lines are skipped. The file is decoded as UTF-8
    explicitly, so the result does not depend on the platform locale.

    Args:
        path: File containing one JSON value per line.
        limit: Stop after this many records; ``0`` (or any non-positive value)
            reads the whole file.

    Returns:
        The decoded records, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The message
            is prefixed with ``<path>:<line number>`` to locate the bad record.
    """
    max_rows = limit if limit and limit > 0 else None
    rows: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                # Keep the exception type callers may catch, but say where it happened.
                raise json.JSONDecodeError(f"{path}:{lineno}: {err.msg}", err.doc, err.pos) from err
            rows.append(record)
            if max_rows is not None and len(rows) >= max_rows:
                break
    return rows


def format_position(row: dict[str, Any]) -> str:
    """Render one dataset row as the plain-text user turn of the prompt.

    The result has four newline-separated lines:

    1. ``KataGo tags:`` followed by ``label=value`` pairs joined with ``" | "``,
    2. ``Position tokens:`` followed by the space-joined position tokens,
    3. ``Black stones:`` and 4. ``White stones:`` with the space-joined stone lists.

    Absent tag keys render with an empty value; absent or null token/stone
    lists render as an empty string.
    """
    # (label shown to the model, key in the row dict), in display order.
    tag_fields = (
        ("id", "id"),
        ("turn", "turn_number"),
        ("board_size", "board_size"),
        ("to_move", "to_move"),
        ("komi", "komi"),
        ("rules", "rules"),
        ("win_prob", "win_prob"),
        ("score_lead", "score_lead"),
        ("score_bin", "score_lead_bin"),
        ("phase", "phase_estimate"),
        ("control", "main_control_region"),
        ("contest", "main_contested_region"),
        ("best_move", "best_move"),
        ("best_move_region", "best_move_region"),
        ("urgency", "move_urgency"),
        ("surprise", "search_surprise"),
    )
    tag_text = " | ".join(f"{label}={row.get(key, '')}" for label, key in tag_fields)

    def spaced(items: Any) -> str:
        # Null / empty sequences collapse to "" instead of raising.
        return " ".join(items or [])

    stone_sets = row.get("stones") or {}
    return "\n".join(
        (
            f"KataGo tags: {tag_text}",
            f"Position tokens: {spaced(row.get('position_tokens'))}",
            f"Black stones: {spaced(stone_sets.get('black'))}",
            f"White stones: {spaced(stone_sets.get('white'))}",
        )
    )


def build_messages(row: dict[str, Any]) -> tuple[str, str]:
    """Split one dataset row into ``(prompt, response)`` training text.

    The prompt is three role sections separated by blank lines: the system
    instruction, the rendered position (``format_position``) and an empty
    assistant header that ends in a newline, so the response follows it
    directly. The response is the row's stripped ``rationale_text``, or a fixed
    placeholder sentence when that field is missing or blank.
    """
    system_part = "<|system|>\n" + SYSTEM_PROMPT
    user_part = "<|user|>\n" + format_position(row)
    assistant_header = "<|assistant|>\n"
    prompt = "\n\n".join((system_part, user_part, assistant_header))

    rationale = str(row.get("rationale_text") or "").strip()
    if rationale:
        response = rationale
    else:
        response = "No useful review comment is available for this position."
    return prompt, response


class CausalPromptDataset(Dataset):
    """Map-style dataset of prompt-masked causal-LM training examples.

    Each row is rendered by ``build_messages`` into a prompt and a response.
    An example is ``prompt_ids + response_ids + [eos_token_id]``. Prompt
    positions get label ``-100``, so the loss covers only the response and EOS.

    Prompt and response are tokenized separately and then concatenated. This
    makes the label-mask boundary exact, because no token can straddle the two
    parts.

    Examples longer than ``max_seq_len`` are cut. The prompt keeps its full
    length when possible. It is never allowed to leave fewer than
    ``min(min_response_tokens, len(response_ids))`` response tokens; in that
    case tokens are dropped from the start of the prompt, which keeps the
    assistant header next to the response. This guarantees supervised tokens
    even for very long prompts. A row whose labels are all ``-100`` carries no
    training signal and can make a mean-reduced loss NaN.

    Args:
        rows: Parsed JSONL records.
        tokenizer: Tokenizer callable returning an object with ``input_ids``;
            its ``eos_token_id`` (if not ``None``) terminates every response.
        max_seq_len: Maximum tokens per example; must be positive.
        min_response_tokens: Response tokens reserved before the prompt may use
            the remaining budget (non-negative; ``0`` disables the reservation).

    Raises:
        ValueError: if ``max_seq_len`` is not positive.
    """

    def __init__(
        self,
        rows: list[dict[str, Any]],
        tokenizer: Any,
        max_seq_len: int,
        min_response_tokens: int = 32,
    ):
        if max_seq_len < 1:
            raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
        self.rows = rows
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.min_response_tokens = max(0, min_response_tokens)

    def __len__(self) -> int:
        return len(self.rows)

    def _encode(self, text: str) -> list[int]:
        """Tokenize *text* without BOS/EOS and return a new, mutable id list."""
        return list(self.tokenizer(text, add_special_tokens=False).input_ids)

    def __getitem__(self, index: int) -> dict[str, list[int]]:
        prompt, response = build_messages(self.rows[index])
        prompt_ids = self._encode(prompt)
        response_ids = self._encode(response)
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None:
            # Append by id so the terminator is a single EOS token, not re-tokenized text.
            response_ids.append(eos_id)

        # Reserve room for the start of the response before the prompt claims the rest.
        reserve = min(len(response_ids), self.min_response_tokens, self.max_seq_len)
        prompt_budget = self.max_seq_len - reserve
        if len(prompt_ids) > prompt_budget:
            # Keep the tail of the prompt (it ends with the assistant header).
            # Slice by start index: prompt_ids[-0:] would keep the whole list.
            prompt_ids = prompt_ids[len(prompt_ids) - prompt_budget :]
        response_ids = response_ids[: self.max_seq_len - len(prompt_ids)]

        input_ids = prompt_ids + response_ids
        labels = [-100] * len(prompt_ids) + response_ids
        # Single unpadded sequence, so every position is a real token.
        attention_mask = [1] * len(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


class DataCollatorForCausalPrompt:
    """Right-pad variable-length examples into a batch of ``torch.long`` tensors.

    Every example is extended to the longest ``input_ids`` in the batch. The
    padding values are the tokenizer's ``pad_token_id`` for ``input_ids``,
    ``0`` for ``attention_mask`` and ``-100`` for ``labels`` (the same value
    the dataset uses for masked positions).
    """

    def __init__(self, tokenizer: Any):
        self.tokenizer = tokenizer

    def __call__(self, features: list[dict[str, list[int]]]) -> dict[str, torch.Tensor]:
        target_len = max(len(example["input_ids"]) for example in features)
        fill_by_key = {
            "input_ids": self.tokenizer.pad_token_id,
            "attention_mask": 0,
            "labels": -100,
        }
        padded: dict[str, torch.Tensor] = {}
        for key, fill in fill_by_key.items():
            # Pad length is always measured on input_ids, for every key.
            rows = [
                example[key] + [fill] * (target_len - len(example["input_ids"]))
                for example in features
            ]
            padded[key] = torch.tensor(rows, dtype=torch.long)
        return padded


def lora_target_modules(model: torch.nn.Module) -> list[str]:
    """Choose the leaf module names that LoRA adapters should wrap.

    The function collects the last dotted component of the name of every
    ``torch.nn.Linear`` submodule. If any of the preferred projection names
    (``q/k/v/o_proj``, ``gate/up/down_proj``) are present, only those are
    returned. Otherwise every Linear leaf name is returned. The result is
    sorted and free of duplicates.
    """
    projection_names = frozenset(
        ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
    )
    linear_names: set[str] = set()
    for qualified_name, submodule in model.named_modules():
        if isinstance(submodule, torch.nn.Linear):
            linear_names.add(qualified_name.rsplit(".", 1)[-1])
    matched = sorted(linear_names & projection_names)
    return matched or sorted(linear_names)


def load_model_and_tokenizer(args: argparse.Namespace) -> tuple[Any, Any]:
    """Load the tokenizer and base model, then wrap the model with LoRA adapters.

    The base model is loaded in 4-bit NF4 with double quantization unless
    ``args.no_4bit`` is set. The load/compute dtype is bfloat16 when
    ``args.bf16`` is set or the GPU supports it, and float16 otherwise. LoRA
    adapters are attached to the modules selected by ``lora_target_modules``.

    Args:
        args: Parsed CLI namespace. Uses ``model_name``, ``bf16``, ``no_4bit``,
            ``lora_r``, ``lora_alpha`` and ``lora_dropout``.

    Returns:
        ``(peft_model, tokenizer)``. The tokenizer uses EOS as its pad token
        when it has none.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True, use_fast=True)
    if tokenizer.pad_token is None:
        # Reuse EOS for padding; padded positions are masked via attention_mask / labels.
        tokenizer.pad_token = tokenizer.eos_token

    dtype = torch.bfloat16 if args.bf16 or torch.cuda.is_bf16_supported() else torch.float16
    quantization_config = None
    if not args.no_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto",
        quantization_config=quantization_config,
    )
    # The KV cache is not needed for training.
    model.config.use_cache = False
    if args.no_4bit:
        # Training runs with gradient_checkpointing=True. Every base weight is
        # frozen, so nothing entering a checkpointed block requires grad. With
        # (default) reentrant checkpointing the LoRA weights would then get no
        # gradient. Making the embedding output require grad fixes this;
        # prepare_model_for_kbit_training already does the same on the 4-bit path.
        model.enable_input_require_grads()
    else:
        model = prepare_model_for_kbit_training(model)

    targets = lora_target_modules(model)
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=targets,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model, tokenizer


def write_run_metadata(output_dir: Path, args: argparse.Namespace, stats: DataStats) -> None:
    """Record the CLI arguments and dataset summary as ``run_metadata.json``.

    Creates *output_dir* (including parents) if needed. It then writes
    ``{"args": vars(args), "data": asdict(stats)}`` as JSON with a 2-space
    indent, replacing any existing file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    record = {
        "args": vars(args),
        "data": asdict(stats),
    }
    serialized = json.dumps(record, indent=2)
    metadata_path = output_dir.joinpath("run_metadata.json")
    metadata_path.write_text(serialized)


def main() -> None:
    """Run the LoRA fine-tuning pipeline from the command line.

    Steps:
    1. Parse arguments, shrinking the run when ``--smoke-test`` is set.
    2. Seed the RNGs and check the GPU.
    3. Load the train/eval JSONL files.
    4. Build the (4-bit unless ``--no-4bit``) LoRA model.
    5. Train and evaluate with ``Trainer``.
    6. Save the adapter, tokenizer and metrics.

    The following are written under ``--output-dir``: ``run_metadata.json``,
    Trainer checkpoints, ``adapter/`` (adapter weights and tokenizer) and
    ``metrics.json``.

    Raises:
        RuntimeError: if CUDA is unavailable, if ``--bf16`` is requested on a
            GPU without bfloat16 support, or if either JSONL file yields no rows.
    """
    args = parse_args()
    if args.smoke_test:
        # Shrink the run so train, eval and checkpointing each execute once, quickly.
        args.max_train_examples = args.max_train_examples or 8
        args.max_eval_examples = args.max_eval_examples or 4
        args.max_steps = 1 if args.max_steps < 0 else args.max_steps
        args.logging_steps = 1
        args.eval_steps = 1
        args.save_steps = 1

    set_seed(args.seed)
    random.seed(args.seed)

    if not torch.cuda.is_available():
        if args.bf16:
            raise RuntimeError("--bf16 was requested but CUDA is not available")
        raise RuntimeError("CUDA is not available. Run this script in the Modal GPU function.")
    if args.bf16 and not torch.cuda.is_bf16_supported():
        # Fail before the slow model load instead of partway through the run.
        raise RuntimeError("--bf16 was requested but this GPU does not support bfloat16")
    # Mixed-precision mode, decided once: bf16 when forced or supported, otherwise fp16.
    # load_model_and_tokenizer applies the same rule to pick the model dtype.
    use_bf16 = args.bf16 or torch.cuda.is_bf16_supported()

    train_rows = read_jsonl(Path(args.train_path), args.max_train_examples)
    eval_rows = read_jsonl(Path(args.eval_path), args.max_eval_examples)
    if not train_rows:
        raise RuntimeError(f"No train rows found at {args.train_path}")
    if not eval_rows:
        raise RuntimeError(f"No eval rows found at {args.eval_path}")

    print(f"[GPU] {torch.cuda.get_device_name(0)}")
    print(f"[DATA] train_rows={len(train_rows)} eval_rows={len(eval_rows)}")
    model, tokenizer = load_model_and_tokenizer(args)
    train_dataset = CausalPromptDataset(train_rows, tokenizer, args.max_seq_len)
    eval_dataset = CausalPromptDataset(eval_rows, tokenizer, args.max_seq_len)

    output_dir = Path(args.output_dir)
    stats = DataStats(
        train_rows=len(train_rows),
        eval_rows=len(eval_rows),
        max_seq_len=args.max_seq_len,
        model_name=args.model_name,
    )
    write_run_metadata(output_dir, args, stats)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        overwrite_output_dir=True,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps,
        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        save_total_limit=2,
        bf16=use_bf16,
        fp16=not use_bf16,
        gradient_checkpointing=True,
        # PeftModel wraps the base model's forward(), so Trainer may be unable to
        # infer the label column itself; without it, evaluate() can omit eval_loss.
        label_names=["labels"],
        report_to=[],
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForCausalPrompt(tokenizer),
    )
    train_result = trainer.train()
    eval_metrics = trainer.evaluate()

    trainer.save_model(str(output_dir / "adapter"))
    tokenizer.save_pretrained(str(output_dir / "adapter"))

    eval_loss = eval_metrics.get("eval_loss")
    # Report no perplexity for a missing or implausibly large loss (also avoids exp overflow).
    perplexity = math.exp(eval_loss) if eval_loss is not None and eval_loss < 20 else None
    metrics = {
        "train": train_result.metrics,
        "eval": eval_metrics,
        "perplexity": perplexity,
    }
    (output_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
    print(json.dumps(metrics, indent=2))


# Run the training pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
