from __future__ import annotations

import argparse
import json
import math
import os
import random
import shutil
from collections import Counter, defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any

# These are set before the torch / tqdm imports below and before Unsloth is
# imported inside load_unsloth_or_hf_model(). setdefault() keeps any value the
# user already exported in the environment.
# NOTE(review): these look like Unsloth switches that turn off its compilation
# and fast-generation paths — confirm the exact semantics against Unsloth docs.
os.environ.setdefault("UNSLOTH_COMPILE_DISABLE", "1")
os.environ.setdefault("UNSLOTH_DISABLE_FAST_GENERATION", "1")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

try:
    from bec_mask import build_bec_mask, combine_with_padding_mask
except ImportError:
    from .bec_mask import build_bec_mask, combine_with_padding_mask


# Default backbone checkpoint; overridable via --model-name / TrainConfig.model_name.
MODEL_NAME = "unsloth/gpt-oss-20b"
# Structured claim fields. Each name is a key read from every JSONL row, gets its
# own anchor token in SPECIAL_TOKENS, and its own linear classification head.
# The order here is also the order of the anchor tokens appended to each sequence.
CLAIM_FIELDS = [
    "win_prob_bin",
    "score_lead_bin",
    "phase_estimate",
    "main_control_region",
    "main_contested_region",
    "global_contestedness",
    "best_move_region",
    "move_urgency",
    "search_surprise",
]
# Canonical label vocabulary per claim field. build_label_maps() assigns class
# indices in exactly this list order, then appends any other labels found in the
# data (sorted) after them. Reordering these lists therefore changes class ids and
# would no longer line up with previously saved label maps / claim heads.
CLAIM_LABEL_VALUES = {
    # Ten bins, stored as the strings "0".."9".
    "win_prob_bin": [str(i) for i in range(10)],
    "score_lead_bin": [
        "LEAD_CLOSE",
        "LEAD_N10_20",
        "LEAD_N20P",
        "LEAD_N5_10",
        "LEAD_P10_20",
        "LEAD_P20P",
        "LEAD_P5_10",
    ],
    "phase_estimate": [
        "PHASE_LATE",
        "PHASE_MID",
        "PHASE_OPENING",
        "PHASE_SETTLED",
    ],
    "main_control_region": [
        "CTRL_BOTTOM_B",
        "CTRL_BOTTOM_W",
        "CTRL_CENTER_B",
        "CTRL_CENTER_W",
        "CTRL_LEFT_B",
        "CTRL_LEFT_W",
        "CTRL_NONE",
        "CTRL_RIGHT_B",
        "CTRL_RIGHT_W",
        "CTRL_TOP_B",
        "CTRL_TOP_W",
    ],
    "main_contested_region": [
        "CONTEST_BOTTOM",
        "CONTEST_CENTER",
        "CONTEST_LEFT",
        "CONTEST_NONE",
        "CONTEST_RIGHT",
        "CONTEST_TOP",
    ],
    "global_contestedness": [
        "CONTEST_HIGH",
        "CONTEST_LOW",
        "CONTEST_MED",
    ],
    "best_move_region": [
        "BESTREG_BOTTOM",
        "BESTREG_CENTER",
        "BESTREG_LEFT",
        "BESTREG_PASS",
        "BESTREG_RIGHT",
        "BESTREG_TOP",
    ],
    "move_urgency": [
        "URG_HIGH",
        "URG_LOW",
        "URG_MED",
        "URG_MUST",
    ],
    "search_surprise": [
        "SURPRISE_HIGH",
        "SURPRISE_LOW",
        "SURPRISE_MED",
    ],
}
# Registered on the tokenizer as additional special tokens. The dataset requires
# each claim anchor to encode to exactly one token id.
SPECIAL_TOKENS = [
    "<|board|>",
    "<|explanation|>",
    *[f"<|claim_{field}|>" for field in CLAIM_FIELDS],
]
# Label value for positions excluded from the LM loss (board prompt, explanation
# prefix, claim anchors, padding). -100 is the default ignore_index of PyTorch's
# cross-entropy loss.
IGNORE_INDEX = -100
# Column letters for board coordinates (note: no "I"); a letter's index in this
# string is its column index.
BOARD_COLUMNS = "ABCDEFGHJKLMNOPQRSTUVWXYZ"


@dataclass
class TrainConfig:
    """Paths and hyper-parameters for one training run.

    Normally built by ``parse_args``: every field mirrors the command-line flag
    of the same name (dashes become underscores), except ``load_in_4bit``, which
    is the negation of ``--no-4bit``.
    """

    output_dir: str
    # Data sources: either ``data_path`` (one file, split by game) or both
    # ``train_data_path`` and ``eval_data_path`` (explicit, pre-split files).
    data_path: str | None = None
    train_data_path: str | None = None
    eval_data_path: str | None = None
    model_name: str = MODEL_NAME
    # Number of games for each split; only used when splitting ``data_path``.
    train_games: int = 100
    eval_games: int = 20
    # Caps on positions kept per split after splitting; 0 means "no cap".
    max_train_positions: int = 0
    max_eval_positions: int = 0
    seed: int = 42
    # Upper bound on tokens per example; the rationale is truncated to fit.
    max_seq_length: int = 1024
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 2
    gradient_accumulation_steps: int = 8
    num_train_epochs: int = 3
    learning_rate: float = 2e-4
    weight_decay: float = 0.0
    warmup_ratio: float = 0.03
    # Weights of the two loss terms: total = lambda_lm * lm_loss + lambda_claim * claim_loss.
    lambda_lm: float = 1.0
    lambda_claim: float = 1.0
    max_grad_norm: float = 1.0
    # LoRA adapter settings.
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    load_in_4bit: bool = True
    # Forces a bfloat16 dtype in the HF fallback loader even if bf16 support is not detected.
    bf16: bool = False
    log_every: int = 10
    # Presumably passed to qualitative_generation (num_samples / max_new_tokens) — confirm in main().
    generate_samples: int = 5
    max_new_tokens: int = 160


def parse_args() -> TrainConfig:
    """Build the CLI, validate the data-source flags and return a TrainConfig.

    Exactly one data layout must be chosen: ``--data-path`` (one JSONL file
    that is split by game later) or both ``--train-data-path`` and
    ``--eval-data-path``. Violations are reported through ``parser.error``,
    which prints usage and exits.
    """
    parser = argparse.ArgumentParser(description="Train GPT-OSS 20B with LM + structured claim consistency losses.")

    # Data sources are the only options that carry help text.
    parser.add_argument("--data-path", default=None, help="Single JSONL file to split by game.")
    parser.add_argument("--train-data-path", default=None, help="Explicit train JSONL file. Use with --eval-data-path.")
    parser.add_argument("--eval-data-path", default=None, help="Explicit eval JSONL file. Use with --train-data-path.")
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--model-name", default=MODEL_NAME)

    # Typed scalar options, registered in the original order so --help output is unchanged.
    numeric_options = (
        ("--train-games", int, 100),
        ("--eval-games", int, 20),
        ("--max-train-positions", int, 0),
        ("--max-eval-positions", int, 0),
        ("--seed", int, 42),
        ("--max-seq-length", int, 1024),
        ("--per-device-train-batch-size", int, 2),
        ("--per-device-eval-batch-size", int, 2),
        ("--gradient-accumulation-steps", int, 8),
        ("--num-train-epochs", int, 3),
        ("--learning-rate", float, 2e-4),
        ("--weight-decay", float, 0.0),
        ("--warmup-ratio", float, 0.03),
        ("--lambda-lm", float, 1.0),
        ("--lambda-claim", float, 1.0),
        ("--max-grad-norm", float, 1.0),
        ("--lora-r", int, 16),
        ("--lora-alpha", int, 32),
        ("--lora-dropout", float, 0.05),
    )
    for flag, value_type, default in numeric_options:
        parser.add_argument(flag, type=value_type, default=default)

    parser.add_argument("--no-4bit", action="store_true")
    parser.add_argument("--bf16", action="store_true")

    for flag, value_type, default in (
        ("--log-every", int, 10),
        ("--generate-samples", int, 5),
        ("--max-new-tokens", int, 160),
    ):
        parser.add_argument(flag, type=value_type, default=default)

    namespace = parser.parse_args()
    has_train_file = bool(namespace.train_data_path)
    has_eval_file = bool(namespace.eval_data_path)
    if has_train_file != has_eval_file:
        parser.error("--train-data-path and --eval-data-path must be provided together")
    if not namespace.data_path and not (has_train_file and has_eval_file):
        parser.error("provide either --data-path or both --train-data-path and --eval-data-path")

    # Every parsed option maps 1:1 onto a TrainConfig field except --no-4bit,
    # which is inverted into load_in_4bit.
    options = dict(vars(namespace))
    load_in_4bit = not options.pop("no_4bit")
    return TrainConfig(load_in_4bit=load_in_4bit, **options)


def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load a JSON Lines file into a list of dicts, skipping blank lines.

    The file is decoded as UTF-8 explicitly. Without an ``encoding`` argument,
    ``open()`` falls back to the platform locale encoding, which can mis-decode
    non-ASCII rationale text on some systems.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def split_by_game(rows: list[dict[str, Any]], train_games: int, eval_games: int, seed: int) -> tuple[list[dict], list[dict]]:
    """Split rows into train/eval so that no game contributes to both sides.

    Game ids (stringified ``row["game_id"]``) are sorted and then shuffled with
    ``random.Random(seed)``, so the split is reproducible for a given seed. The
    first ``train_games`` shuffled games form the train split and the next
    ``eval_games`` form the eval split. Within each split, rows keep their
    original per-game order, and games follow the shuffled order.

    Raises:
        ValueError: if fewer than ``train_games + eval_games`` distinct games exist.
    """
    grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for row in rows:
        grouped[str(row["game_id"])].append(row)

    shuffled_ids = sorted(grouped)
    random.Random(seed).shuffle(shuffled_ids)

    total_needed = train_games + eval_games
    if len(shuffled_ids) < total_needed:
        raise ValueError(f"Need at least {total_needed} games, found {len(shuffled_ids)}")

    # Game ids are unique, so the slices can be used directly.
    train_order = shuffled_ids[:train_games]
    eval_order = shuffled_ids[train_games:total_needed]
    train_rows = [row for game_id in train_order for row in grouped[game_id]]
    eval_rows = [row for game_id in eval_order for row in grouped[game_id]]

    print(f"train games: {len(train_order)}")
    print(f"eval games: {len(eval_order)}")
    print(f"train positions: {len(train_rows)}")
    print(f"eval positions: {len(eval_rows)}")
    return train_rows, eval_rows


def load_train_eval_rows(cfg: TrainConfig) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Load 19x19 train/eval rows from either explicit files or one file split by game.

    Explicit mode (``train_data_path`` and ``eval_data_path`` both set) checks
    that no game_id appears in both files. Single-file mode delegates to
    ``split_by_game``. In both modes the ``max_*_positions`` caps (0 = no cap)
    are applied after the split.

    Raises:
        ValueError: on train/eval game overlap, or when no data path is configured.
    """

    def capped(rows: list[dict[str, Any]], limit: int) -> list[dict[str, Any]]:
        # A falsy limit (0) means "keep everything".
        return rows[:limit] if limit else rows

    explicit_files = bool(cfg.train_data_path and cfg.eval_data_path)

    if not explicit_files:
        if not cfg.data_path:
            raise ValueError("data_path is required when explicit train/eval files are not provided")
        all_rows = filter_19x19_rows(read_jsonl(Path(cfg.data_path)))
        train_rows, eval_rows = split_by_game(all_rows, cfg.train_games, cfg.eval_games, cfg.seed)
        train_rows = capped(train_rows, cfg.max_train_positions)
        eval_rows = capped(eval_rows, cfg.max_eval_positions)
        print(f"limited train positions: {len(train_rows)}")
        print(f"limited eval positions: {len(eval_rows)}")
        return train_rows, eval_rows

    train_rows = filter_19x19_rows(read_jsonl(Path(cfg.train_data_path)))
    eval_rows = filter_19x19_rows(read_jsonl(Path(cfg.eval_data_path)))
    train_game_ids = {str(row["game_id"]) for row in train_rows}
    eval_game_ids = {str(row["game_id"]) for row in eval_rows}
    shared_games = train_game_ids & eval_game_ids
    if shared_games:
        raise ValueError(f"Train/eval game_id overlap detected for {len(shared_games)} games")
    print(f"train games: {len(train_game_ids)}")
    print(f"eval games: {len(eval_game_ids)}")
    train_rows = capped(train_rows, cfg.max_train_positions)
    eval_rows = capped(eval_rows, cfg.max_eval_positions)
    print(f"train positions: {len(train_rows)}")
    print(f"eval positions: {len(eval_rows)}")
    return train_rows, eval_rows


def filter_19x19_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Keep only rows whose ``board_size`` is 19 and report how many were dropped.

    A missing or falsy ``board_size`` counts as 19. Values are converted with
    ``int()``, so numeric strings such as ``"19"`` are accepted.
    """

    def is_full_board(row: dict[str, Any]) -> bool:
        return int(row.get("board_size") or 19) == 19

    kept = [row for row in rows if is_full_board(row)]
    dropped = len(rows) - len(kept)
    if dropped > 0:
        print(f"[INFO] Skipped {dropped} non-19x19 positions before game split.")
    return kept


def normalize_label(value: Any) -> str:
    """Return ``value`` as a string, mapping ``None`` to the ``"<MISSING>"`` sentinel."""
    if value is None:
        return "<MISSING>"
    return str(value)


def build_label_maps(rows: list[dict[str, Any]]) -> dict[str, dict[str, int]]:
    """Map each claim field's label strings to contiguous class indices.

    Canonical labels from CLAIM_LABEL_VALUES come first, in their listed order.
    Any other (normalized) labels seen in ``rows`` follow, sorted, which
    includes ``"<MISSING>"`` for absent values.
    """
    label_maps: dict[str, dict[str, int]] = {}
    for field in CLAIM_FIELDS:
        canonical = CLAIM_LABEL_VALUES[field]
        observed = {normalize_label(row.get(field)) for row in rows}
        ordered = [*canonical, *sorted(observed.difference(canonical))]
        label_maps[field] = {label: position for position, label in enumerate(ordered)}
    return label_maps


def invert_label_maps(label_maps: dict[str, dict[str, int]]) -> dict[str, dict[int, str]]:
    """Flip every per-field ``label -> index`` map into ``index -> label``."""
    inverted: dict[str, dict[int, str]] = {}
    for field, label_to_index in label_maps.items():
        inverted[field] = dict(zip(label_to_index.values(), label_to_index.keys()))
    return inverted


def class_frequency_stats(rows: list[dict[str, Any]], label_maps: dict[str, dict[str, int]]) -> dict[str, dict[str, int]]:
    """Count how often each known label occurs per claim field.

    Every label in ``label_maps[field]`` is reported, with 0 for labels that
    never occur in ``rows``. Labels absent from the map are not reported.
    """
    stats: dict[str, dict[str, int]] = {}
    for field in CLAIM_FIELDS:
        tally = Counter(map(normalize_label, (row.get(field) for row in rows)))
        # Counter returns 0 for unseen keys without inserting them.
        stats[field] = {label: tally[label] for label in label_maps[field]}
    return stats


class GoConsistencyDataset(Dataset):
    """Turn Go-position rows into token sequences for joint LM + claim training.

    Each example is laid out as ``board | explanation | claim anchors``:

    * board: ``<|board|>`` and a newline, then a 19x19 matrix of ``1`` (black),
      ``-1`` (white) and ``0`` (empty), one board row per line;
    * explanation: ``<|explanation|>`` and a newline, then the stripped
      rationale text followed by the tokenizer's EOS token;
    * claims: one ``<|claim_<field>|>`` anchor token per entry of CLAIM_FIELDS.

    Only the rationale tokens (including EOS) carry LM labels; every other
    position is IGNORE_INDEX. If the sequence would exceed ``max_seq_length``,
    only the rationale is truncated. The board, prefix and anchors are always
    kept, so a truncated rationale loses its trailing EOS.
    """

    def __init__(
        self,
        rows: list[dict[str, Any]],
        tokenizer: Any,
        label_maps: dict[str, dict[str, int]],
        max_seq_length: int,
    ):
        self.rows = rows
        self.tokenizer = tokenizer
        self.label_maps = label_maps
        self.max_seq_length = max_seq_length

    def __len__(self) -> int:
        return len(self.rows)

    def _tokenize(self, text: str) -> list[int]:
        # add_special_tokens=False: the layout supplies its own markers, so no BOS/EOS is injected here.
        return self.tokenizer(text, add_special_tokens=False).input_ids

    def _coord_to_row_col(self, coord: str, board_size: int) -> tuple[int, int]:
        """Convert a coordinate like ``"D4"`` to ``(row, col)`` matrix indices.

        Row 0 corresponds to coordinate number ``board_size``. The column is the
        letter's index in BOARD_COLUMNS, which has no "I". Raises ValueError or
        IndexError on malformed input.
        """
        coord = coord.strip().upper()
        col = BOARD_COLUMNS.index(coord[0])
        row = board_size - int(coord[1:])
        return row, col

    def _board_matrix_text(self, row: dict[str, Any]) -> str:
        """Render ``row["stones"]`` as 19 space-separated lines of 1 / -1 / 0.

        Coordinates that fail to parse or fall off the board are skipped silently.

        Raises:
            ValueError: if the row's board_size is not 19.
        """
        board_size = int(row.get("board_size") or 19)
        if board_size != 19:
            raise ValueError(f"Expected 19x19 boards for matrix input, got board_size={board_size} in {row.get('id')}")
        matrix = [[0 for _ in range(19)] for _ in range(19)]
        stones = row.get("stones") or {}
        for color, stone_value in [("black", 1), ("white", -1)]:
            for coord in stones.get(color) or []:
                try:
                    r, c = self._coord_to_row_col(str(coord), 19)
                except (ValueError, IndexError):
                    continue
                if 0 <= r < 19 and 0 <= c < 19:
                    matrix[r][c] = stone_value
        lines = [" ".join(str(cell) for cell in board_row) for board_row in matrix]
        return "\n".join(lines)

    def _claim_anchor_ids(self) -> list[int]:
        """Return the single token id of each claim anchor, in CLAIM_FIELDS order.

        Raises:
            RuntimeError: if any anchor does not encode to exactly one token,
                e.g. because SPECIAL_TOKENS were not added to the tokenizer.
        """
        claim_ids_by_field = [self._tokenize(f"<|claim_{field}|>") for field in CLAIM_FIELDS]
        if any(len(ids) != 1 for ids in claim_ids_by_field):
            lengths = {field: len(ids) for field, ids in zip(CLAIM_FIELDS, claim_ids_by_field)}
            raise RuntimeError(f"Claim anchors must tokenize to one token each after add_special_tokens: {lengths}")
        return [ids[0] for ids in claim_ids_by_field]

    def __getitem__(self, index: int) -> dict[str, Any]:
        row = self.rows[index]
        board_text = "<|board|>\n" + self._board_matrix_text(row) + "\n"
        explanation_prefix = "<|explanation|>\n"
        rationale = str(row.get("rationale_text") or "").strip()
        explanation_text = rationale + self.tokenizer.eos_token

        board_ids = self._tokenize(board_text)
        explanation_prefix_ids = self._tokenize(explanation_prefix)
        rationale_ids = self._tokenize(explanation_text)
        claim_anchor_ids = self._claim_anchor_ids()

        # Decide truncation up front, then assemble the sequence exactly once.
        # This avoids duplicating the assembly logic for the truncated path.
        c_len = len(CLAIM_FIELDS)
        fixed_len = len(board_ids) + len(explanation_prefix_ids) + c_len
        if fixed_len + len(rationale_ids) > self.max_seq_length:
            max_rationale = self.max_seq_length - fixed_len
            if max_rationale <= 0:
                raise ValueError(f"max_seq_length={self.max_seq_length} is too small for example {row.get('id')}")
            rationale_ids = rationale_ids[:max_rationale]

        explanation_ids = explanation_prefix_ids + rationale_ids
        input_ids = board_ids + explanation_ids + claim_anchor_ids
        labels = [IGNORE_INDEX] * len(input_ids)
        rationale_start = len(board_ids) + len(explanation_prefix_ids)
        rationale_end = rationale_start + len(rationale_ids)
        labels[rationale_start:rationale_end] = input_ids[rationale_start:rationale_end]

        b_len = len(board_ids)
        e_len = len(explanation_ids)
        claim_start = b_len + e_len
        claim_positions = {field: claim_start + idx for idx, field in enumerate(CLAIM_FIELDS)}
        # Per-anchor window (query_start, query_end, span_start, span_end) handed to
        # build_bec_mask. e_end is the last explanation index, so the span looks
        # inclusive — NOTE(review): confirm inclusivity against bec_mask.
        e_start = b_len
        e_end = b_len + e_len - 1
        windows = [(pos, pos, e_start, e_end) for pos in claim_positions.values()]
        claim_labels = {
            field: self.label_maps[field][normalize_label(row.get(field))]
            for field in CLAIM_FIELDS
        }
        return {
            "id": row.get("id", ""),
            "game_id": row.get("game_id", ""),
            "rationale_text": rationale,
            "raw_claims": {field: normalize_label(row.get(field)) for field in CLAIM_FIELDS},
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
            "B": b_len,
            "E": e_len,
            "C": c_len,
            "windows": windows,
            "claim_positions": claim_positions,
            "claim_labels": claim_labels,
            # Prompt used for free-running generation: board plus explanation prefix, no rationale.
            "board_prompt_ids": board_ids + explanation_prefix_ids,
        }


class GoConsistencyCollator:
    """Pad GoConsistencyDataset items into one batch dict.

    Sequences are right-padded to the longest ``input_ids`` in the batch, using
    ``pad_token_id`` for inputs, IGNORE_INDEX for labels and 0 for the attention
    mask. Scalar layout lengths (B/E/C), claim positions and claim labels become
    ``torch.long`` tensors. Windows, ids, texts, raw claims and prompt ids stay
    Python lists.
    """

    def __init__(self, tokenizer: Any):
        self.tokenizer = tokenizer

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        longest = max(len(item["input_ids"]) for item in features)
        pad_id = self.tokenizer.pad_token_id

        def padded(key: str, fill: int) -> list[list[int]]:
            # Pad length always comes from input_ids so all three sequences stay aligned.
            return [item[key] + [fill] * (longest - len(item["input_ids"])) for item in features]

        def column(key: str) -> list[Any]:
            return [item[key] for item in features]

        def as_long(values: Any) -> torch.Tensor:
            return torch.tensor(values, dtype=torch.long)

        def per_field(key: str) -> dict[str, torch.Tensor]:
            return {field: as_long([item[key][field] for item in features]) for field in CLAIM_FIELDS}

        return {
            "input_ids": as_long(padded("input_ids", pad_id)),
            "labels": as_long(padded("labels", IGNORE_INDEX)),
            "attention_mask": as_long(padded("attention_mask", 0)),
            "B": as_long(column("B")),
            "E": as_long(column("E")),
            "C": as_long(column("C")),
            "windows": column("windows"),
            "ids": column("id"),
            "game_ids": column("game_id"),
            "rationale_texts": column("rationale_text"),
            "raw_claims": column("raw_claims"),
            "board_prompt_ids": column("board_prompt_ids"),
            "claim_positions": per_field("claim_positions"),
            "claim_labels": per_field("claim_labels"),
        }


def lora_target_modules(model: nn.Module) -> list[str]:
    """Choose LoRA target names from the leaf names of the model's ``nn.Linear`` modules.

    Prefer the usual attention/MLP projection names when any of them exist.
    Otherwise fall back to every Linear leaf name. The result is sorted.
    """
    preferred = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
    linear_leaf_names = set()
    for qualified_name, submodule in model.named_modules():
        if isinstance(submodule, nn.Linear):
            linear_leaf_names.add(qualified_name.rsplit(".", 1)[-1])
    matched = preferred & linear_leaf_names
    return sorted(matched if matched else linear_leaf_names)


def load_unsloth_or_hf_model(cfg: TrainConfig) -> tuple[Any, Any]:
    """Load the backbone + tokenizer and wrap the backbone with LoRA adapters.

    Unsloth's FastLanguageModel is tried first. If it is unavailable or fails
    for any reason, a warning is printed and the HuggingFace transformers +
    PEFT path is used instead (4-bit NF4 via bitsandbytes when
    ``cfg.load_in_4bit``). In both paths SPECIAL_TOKENS are added to the
    tokenizer, ``pad_token`` falls back to ``eos_token``, and the embedding
    matrix is resized to the new vocabulary before LoRA is applied.

    Returns:
        ``(peft_model, tokenizer)``.
    """
    # torch.cuda.is_bf16_supported() can raise on a machine without a usable CUDA
    # device (older torch versions initialise CUDA inside it). Guard it the same
    # way autocast_dtype() does. --bf16 still forces bfloat16.
    bf16_available = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if cfg.bf16 or bf16_available else torch.float16
    try:
        from unsloth import FastLanguageModel

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=cfg.model_name,
            max_seq_length=cfg.max_seq_length,
            # NOTE(review): None presumably lets Unsloth auto-select the dtype; `dtype` above is unused here.
            dtype=None,
            load_in_4bit=cfg.load_in_4bit,
            full_finetuning=False,
        )
        tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model.resize_token_embeddings(len(tokenizer))
        model = FastLanguageModel.get_peft_model(
            model,
            r=cfg.lora_r,
            target_modules=lora_target_modules(model),
            lora_alpha=cfg.lora_alpha,
            lora_dropout=cfg.lora_dropout,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=cfg.seed,
        )
        return model, tokenizer
    except Exception as exc:
        # Deliberate best-effort: any Unsloth failure (missing package, unsupported
        # GPU, adapter error) falls back to plain HF + PEFT below.
        print(f"[WARN] Unsloth load failed or is unavailable; falling back to HuggingFace PEFT. Reason: {exc}")

    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, trust_remote_code=True, use_fast=True)
    tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    quantization_config = None
    if cfg.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto",
        quantization_config=quantization_config,
    )
    # New special-token rows must exist before LoRA wraps the model.
    model.resize_token_embeddings(len(tokenizer))
    model.config.use_cache = False
    if cfg.load_in_4bit:
        model = prepare_model_for_kbit_training(model)
    peft_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=lora_target_modules(model),
    )
    model = get_peft_model(model, peft_config)
    return model, tokenizer


class GoConsistencyModel(nn.Module):
    """Causal-LM backbone plus one linear classifier ("claim head") per claim field.

    The total loss is ``lambda_lm * lm_loss + lambda_claim * claim_loss``:

    * ``lm_loss`` is the backbone's own loss on ``batch["labels"]``;
    * ``claim_loss`` is the mean, over CLAIM_FIELDS, of each head's
      cross-entropy. Each head reads the last hidden layer at its field's
      anchor position.
    """

    def __init__(self, backbone: Any, label_maps: dict[str, dict[str, int]], lambda_lm: float, lambda_claim: float):
        super().__init__()
        self.backbone = backbone
        self.label_maps = label_maps
        self.lambda_lm = lambda_lm
        self.lambda_claim = lambda_claim
        hidden_size = backbone.config.hidden_size
        # One head per field; output size = number of classes in that field's label map.
        self.claim_heads = nn.ModuleDict(
            {field: nn.Linear(hidden_size, len(label_maps[field])) for field in CLAIM_FIELDS}
        )

    def forward(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Compute the combined loss and per-field claim logits for a collated batch.

        Returns:
            Dict with ``loss`` (differentiable total), detached ``lm_loss`` and
            ``claim_loss``, and ``logits`` mapping each field to a
            ``(batch, num_classes)`` tensor.
        """
        input_ids = batch["input_ids"]
        labels = batch["labels"]
        attention_mask = batch["attention_mask"]
        # Structured board/explanation/claim mask from bec_mask, merged with padding.
        # NOTE(review): the exact mask format (additive vs boolean, 2D vs 4D) is defined in bec_mask.
        bec_mask = build_bec_mask(
            batch["B"],
            batch["E"],
            batch["C"],
            batch["windows"],
            device=input_ids.device,
            dtype=torch.float32,
        )
        final_mask = combine_with_padding_mask(bec_mask, attention_mask)

        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=final_mask,
            labels=labels,
            output_hidden_states=True,
            use_cache=False,
        )
        lm_loss = outputs.loss
        hidden = outputs.hidden_states[-1]

        claim_losses = []
        logits_by_field: dict[str, torch.Tensor] = {}
        batch_idx = torch.arange(input_ids.shape[0], device=input_ids.device)
        for field in CLAIM_FIELDS:
            head = self.claim_heads[field]
            positions = batch["claim_positions"][field].to(input_ids.device)
            targets = batch["claim_labels"][field].to(input_ids.device)
            # One hidden vector per example, taken at that example's anchor position.
            anchor_hidden = hidden[batch_idx, positions]
            # Cast to the head's dtype explicitly. Heads are created in the default
            # dtype (normally float32), while a bf16/fp16 backbone yields half-precision
            # hidden states. Without autocast (maybe_autocast is a no-op off-CUDA) that
            # is a dtype-mismatch error. Under autocast the cast is numerically a no-op.
            logits = head(anchor_hidden.to(head.weight.dtype))
            logits_by_field[field] = logits
            claim_losses.append(F.cross_entropy(logits.float(), targets))

        claim_loss = torch.stack(claim_losses).mean()
        total_loss = self.lambda_lm * lm_loss + self.lambda_claim * claim_loss
        return {
            "loss": total_loss,
            "lm_loss": lm_loss.detach(),
            "claim_loss": claim_loss.detach(),
            "logits": logits_by_field,
        }


def move_batch_to_device(batch: dict[str, Any], device: torch.device) -> dict[str, Any]:
    """Return a shallow copy of ``batch`` with tensors moved to ``device``.

    Top-level tensors and tensors one level down inside nested dicts are moved.
    All other values (lists, strings, ...) are passed through unchanged.
    """

    def relocate(value: Any) -> Any:
        return value.to(device) if isinstance(value, torch.Tensor) else value

    moved: dict[str, Any] = {}
    for key, value in batch.items():
        if isinstance(value, dict):
            moved[key] = {inner_key: relocate(inner) for inner_key, inner in value.items()}
        else:
            moved[key] = relocate(value)
    return moved


def autocast_dtype() -> torch.dtype:
    """Return bfloat16 when a CUDA device supports it, otherwise float16."""
    if not torch.cuda.is_available():
        return torch.float16
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


def maybe_autocast(device: torch.device):
    """Return a CUDA autocast context for CUDA devices, else a no-op context.

    Uses ``torch.amp.autocast("cuda", ...)`` when available and falls back to
    the older ``torch.cuda.amp.autocast`` API otherwise.
    """
    if device.type == "cuda":
        dtype = autocast_dtype()
        amp_module = getattr(torch, "amp", None)
        if amp_module is not None and hasattr(amp_module, "autocast"):
            return amp_module.autocast("cuda", dtype=dtype)
        return torch.cuda.amp.autocast(dtype=dtype)
    return nullcontext()


def make_grad_scaler(enabled: bool):
    """Build a CUDA GradScaler, preferring ``torch.amp`` over the legacy ``torch.cuda.amp``."""
    amp_module = getattr(torch, "amp", None)
    scaler_cls = getattr(amp_module, "GradScaler", None) if amp_module is not None else None
    if scaler_cls is not None:
        return scaler_cls("cuda", enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)


def macro_f1(golds: list[int], preds: list[int], num_labels: int) -> float:
    """Macro-averaged F1 over class ids ``0 .. num_labels - 1``.

    Classes that appear in neither ``golds`` nor ``preds`` are skipped rather
    than counted as 0. Ids outside ``range(num_labels)`` are ignored. Returns
    0.0 when no class qualifies. Pairs are formed with ``zip``, so extra
    elements in the longer list are ignored.

    Counts are gathered in one pass, O(n + num_labels), instead of three full
    scans per class. The per-class arithmetic is unchanged, so results are identical.
    """
    true_pos: Counter[int] = Counter()
    gold_count: Counter[int] = Counter()
    pred_count: Counter[int] = Counter()
    for g, p in zip(golds, preds):
        gold_count[g] += 1
        pred_count[p] += 1
        if g == p:
            true_pos[g] += 1

    f1s = []
    for label in range(num_labels):
        tp = true_pos[label]
        fp = pred_count[label] - tp
        fn = gold_count[label] - tp
        if tp == 0 and fp == 0 and fn == 0:
            continue
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * precision * recall / (precision + recall) if precision + recall else 0.0)
    return sum(f1s) / len(f1s) if f1s else 0.0


def evaluate(
    model: GoConsistencyModel,
    loader: DataLoader,
    device: torch.device,
    inv_label_maps: dict[str, dict[int, str]],
    output_predictions_path: Path | None = None,
) -> dict[str, Any]:
    """Evaluate ``model`` on ``loader``: averaged losses plus per-claim metrics.

    Loss averages are weighted by the number of examples in each batch, so a
    short final batch contributes proportionally instead of counting as much
    as a full batch. ``lm_loss`` is still whatever per-batch mean the backbone
    returns, so the result is example-weighted, not token-weighted.

    Args:
        model: Backbone + claim heads. Switched to eval mode and left there.
        loader: Yields batches produced by GoConsistencyCollator.
        device: Device every tensor in a batch is moved to.
        inv_label_maps: Per-field ``class index -> label string`` maps.
        output_predictions_path: If given, one JSON line per example is written
            with gold/predicted claims, per-claim correctness and the gold rationale.

    Returns:
        Dict with ``lm_loss``, ``perplexity`` (``None`` when ``lm_loss >= 20``),
        ``claim_loss``, ``total_loss``, per-field ``claims`` accuracy/macro-F1,
        and their unweighted means over CLAIM_FIELDS.
    """
    model.eval()
    total_lm_loss = 0.0
    total_claim_loss = 0.0
    total_loss = 0.0
    num_examples = 0
    golds: dict[str, list[int]] = {field: [] for field in CLAIM_FIELDS}
    preds: dict[str, list[int]] = {field: [] for field in CLAIM_FIELDS}
    prediction_rows: list[dict[str, Any]] = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="eval", leave=False):
            batch = move_batch_to_device(batch, device)
            with maybe_autocast(device):
                out = model(batch)
            batch_size = int(batch["input_ids"].shape[0])
            # Weight each batch-mean loss by its size so partial batches are not over-counted.
            total_lm_loss += float(out["lm_loss"].item()) * batch_size
            total_claim_loss += float(out["claim_loss"].item()) * batch_size
            total_loss += float(out["loss"].item()) * batch_size
            num_examples += batch_size
            batch_preds: dict[str, list[int]] = {}
            for field in CLAIM_FIELDS:
                pred = out["logits"][field].argmax(dim=-1).detach().cpu().tolist()
                gold = batch["claim_labels"][field].detach().cpu().tolist()
                preds[field].extend(pred)
                golds[field].extend(gold)
                batch_preds[field] = pred

            for i, sample_id in enumerate(batch["ids"]):
                gold_claims = batch["raw_claims"][i]
                pred_claims = {
                    field: inv_label_maps[field][batch_preds[field][i]]
                    for field in CLAIM_FIELDS
                }
                prediction_rows.append(
                    {
                        "id": sample_id,
                        "game_id": batch["game_ids"][i],
                        "gold_claims": gold_claims,
                        "predicted_claims": pred_claims,
                        "per_claim_correct": {
                            field: gold_claims[field] == pred_claims[field]
                            for field in CLAIM_FIELDS
                        },
                        "gold_rationale_text": batch["rationale_texts"][i],
                    }
                )

    claim_metrics: dict[str, dict[str, float]] = {}
    for field in CLAIM_FIELDS:
        correct = sum(1 for g, p in zip(golds[field], preds[field]) if g == p)
        total = len(golds[field])
        claim_metrics[field] = {
            "accuracy": correct / total if total else 0.0,
            "macro_f1": macro_f1(golds[field], preds[field], len(inv_label_maps[field])),
        }

    avg_claim_accuracy = sum(m["accuracy"] for m in claim_metrics.values()) / len(CLAIM_FIELDS)
    avg_macro_f1 = sum(m["macro_f1"] for m in claim_metrics.values()) / len(CLAIM_FIELDS)
    denominator = max(1, num_examples)  # an empty loader yields zeros instead of dividing by zero
    lm_loss = total_lm_loss / denominator
    metrics = {
        "lm_loss": lm_loss,
        "perplexity": math.exp(lm_loss) if lm_loss < 20 else None,
        "claim_loss": total_claim_loss / denominator,
        "total_loss": total_loss / denominator,
        "claims": claim_metrics,
        "avg_claim_accuracy": avg_claim_accuracy,
        "avg_macro_f1": avg_macro_f1,
    }
    if output_predictions_path is not None:
        with output_predictions_path.open("w", encoding="utf-8") as f:
            for row in prediction_rows:
                f.write(json.dumps(row) + "\n")
    return metrics


def majority_baseline(train_rows: list[dict[str, Any]], eval_rows: list[dict[str, Any]], label_maps: dict[str, dict[str, int]]) -> dict[str, Any]:
    """Score a "predict the most frequent training label" baseline on ``eval_rows``.

    For each claim field, the most common normalized label in ``train_rows`` is
    predicted for every eval row. Ties go to the label seen first, per
    ``Counter.most_common``. An empty eval split yields accuracy 0.0, the same
    convention used for claim accuracy in ``evaluate``, instead of dividing by zero.

    Returns:
        ``{"claims": {field: {"majority_label", "accuracy", "macro_f1"}},
        "avg_claim_accuracy": ..., "avg_macro_f1": ...}``.

    Raises:
        ValueError: if ``train_rows`` is empty, since no majority label exists.
        KeyError: if a label is missing from ``label_maps``.
    """
    if not train_rows:
        raise ValueError("majority_baseline requires at least one training row")
    metrics = {}
    for field in CLAIM_FIELDS:
        majority_label = Counter(normalize_label(row.get(field)) for row in train_rows).most_common(1)[0][0]
        majority_id = label_maps[field][majority_label]
        gold = [label_maps[field][normalize_label(row.get(field))] for row in eval_rows]
        pred = [majority_id] * len(gold)
        correct = sum(1 for g, p in zip(gold, pred) if g == p)
        metrics[field] = {
            "majority_label": majority_label,
            "accuracy": correct / len(gold) if gold else 0.0,
            "macro_f1": macro_f1(gold, pred, len(label_maps[field])),
        }
    return {
        "claims": metrics,
        "avg_claim_accuracy": sum(m["accuracy"] for m in metrics.values()) / len(CLAIM_FIELDS),
        "avg_macro_f1": sum(m["macro_f1"] for m in metrics.values()) / len(CLAIM_FIELDS),
    }


def save_checkpoint(model: GoConsistencyModel, tokenizer: Any, output_dir: Path, label_maps: dict[str, dict[str, int]], metrics: dict[str, Any]) -> None:
    """Persist everything needed to reload the trained model under ``output_dir``.

    Layout: ``adapter/`` (backbone ``save_pretrained``), ``tokenizer/``,
    ``claim_heads.pt`` (claim-head state dict), and ``label_maps.json`` /
    ``metrics.json`` (indented JSON). Parent directories are created as needed.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    adapter_dir = output_dir / "adapter"
    tokenizer_dir = output_dir / "tokenizer"
    model.backbone.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(tokenizer_dir)
    torch.save(model.claim_heads.state_dict(), output_dir / "claim_heads.pt")
    for filename, payload in (("label_maps.json", label_maps), ("metrics.json", metrics)):
        (output_dir / filename).write_text(json.dumps(payload, indent=2))


def print_metrics_table(metrics: dict[str, Any]) -> None:
    """Print the summary metrics from ``evaluate`` and a per-claim accuracy/F1 table.

    A ``None`` perplexity is shown as ``<overflow>``.
    """
    print("\nEvaluation metrics")
    print(f"lm_loss          {metrics['lm_loss']:.4f}")
    perplexity = metrics["perplexity"]
    if perplexity is None:
        print("perplexity       <overflow>")
    else:
        print(f"perplexity       {perplexity:.4f}")
    # Remaining summary rows share one format: fixed-width label, 4-decimal value.
    for prefix, key in (
        ("claim_loss       ", "claim_loss"),
        ("total_loss       ", "total_loss"),
        ("avg_claim_acc    ", "avg_claim_accuracy"),
        ("avg_macro_f1     ", "avg_macro_f1"),
    ):
        print(f"{prefix}{metrics[key]:.4f}")
    print("\nclaim_type                         accuracy  macro_f1")
    per_claim = metrics["claims"]
    for field in CLAIM_FIELDS:
        accuracy = per_claim[field]["accuracy"]
        field_f1 = per_claim[field]["macro_f1"]
        print(f"{field:<34} {accuracy:.4f}    {field_f1:.4f}")


def print_split_metrics(epoch: int, train_metrics: dict[str, Any], eval_metrics: dict[str, Any]) -> None:
    """Print a two-row (train, eval) summary of one epoch's claim metrics.

    Each row shows ``claim_loss``, ``total_loss``, ``avg_claim_accuracy`` and
    ``avg_macro_f1`` with four decimals, aligned under a fixed header.
    """
    print(f"\nEpoch {epoch} claim metrics")
    print("split      claim_loss  total_loss  avg_claim_acc  avg_macro_f1")
    by_split = {"train": train_metrics, "eval": eval_metrics}
    for split_name, split_metrics in by_split.items():
        losses = f"{split_name:<10} {split_metrics['claim_loss']:.4f}      {split_metrics['total_loss']:.4f}      "
        scores = f"{split_metrics['avg_claim_accuracy']:.4f}         {split_metrics['avg_macro_f1']:.4f}"
        print(losses + scores)


def qualitative_generation(
    model: GoConsistencyModel,
    tokenizer: Any,
    dataset: GoConsistencyDataset,
    device: torch.device,
    inv_label_maps: dict[str, dict[int, str]],
    output_path: Path,
    num_samples: int,
    max_new_tokens: int,
) -> None:
    """Greedy-decode explanations and predict claims for the first few samples.

    For each of the first ``min(num_samples, len(dataset))`` items this:
      * runs ``model.backbone.generate`` greedily (``do_sample=False``) on the
        item's ``board_prompt_ids`` and decodes only the new tokens;
      * runs ``model`` on a single-item collated batch and takes the argmax of
        ``out["logits"][field]`` for every field in ``CLAIM_FIELDS``;
      * prints the sample and collects it as a row.
    All rows are then written as JSON lines to ``output_path``, which is
    overwritten.

    Args:
        model: Model exposing ``backbone.generate`` and a forward pass that
            returns ``{"logits": {field: tensor}}``.
        tokenizer: Used for decoding and for the pad/eos token ids.
        dataset: Source of samples; only its leading items are used.
        device: Device that prompts and batches are moved to.
        inv_label_maps: Per-field mapping from class index to label string.
        output_path: Destination ``.jsonl`` file.
        num_samples: Number of samples; ``<= 0`` skips the step and writes nothing.
        max_new_tokens: Cap on generated tokens per sample.
    """
    if num_samples <= 0:
        return
    model.eval()
    # Build the collator once instead of once per sample; main() likewise
    # shares a single collator instance across all batches of a DataLoader.
    collator = GoConsistencyCollator(tokenizer)
    rows = []
    for idx in range(min(num_samples, len(dataset))):
        item = dataset[idx]
        prompt_ids = torch.tensor([item["board_prompt_ids"]], dtype=torch.long, device=device)
        # A single unpadded prompt: every position is a real token. Passing the
        # mask explicitly stops `generate` from inferring one heuristically from
        # `pad_token_id` (Hugging Face warns when pad and eos ids coincide).
        attention_mask = torch.ones_like(prompt_ids)
        with torch.no_grad():
            with maybe_autocast(device):
                generated = model.backbone.generate(
                    input_ids=prompt_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
        # Drop the echoed prompt; keep only the continuation.
        generated_text = tokenizer.decode(generated[0][prompt_ids.shape[1] :], skip_special_tokens=True).strip()

        single_batch = move_batch_to_device(collator([item]), device)
        with torch.no_grad():
            with maybe_autocast(device):
                out = model(single_batch)
        pred_claims = {
            field: inv_label_maps[field][int(out["logits"][field].argmax(dim=-1).item())]
            for field in CLAIM_FIELDS
        }
        row = {
            "id": item["id"],
            "gold_rationale_text": item["rationale_text"],
            "generated_explanation": generated_text,
            "gold_claims": item["raw_claims"],
            "predicted_claims": pred_claims,
        }
        rows.append(row)
        print("\nQUAL SAMPLE", row["id"])
        print("gold:", row["gold_rationale_text"])
        print("generated:", row["generated_explanation"])
        print("gold_claims:", row["gold_claims"])
        print("predicted_claims:", row["predicted_claims"])

    with output_path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def main() -> None:
    """Run the end-to-end fine-tuning experiment configured on the command line.

    Pipeline:
      1. Parse the config and write ``config.json``.
      2. Seed ``random`` and ``torch``.
      3. Load the train/eval rows and build the label maps. Write
         ``label_maps.json`` and ``class_frequency_stats.json``, which also
         holds a majority-class baseline.
      4. Load the backbone and tokenizer.
      5. Train for ``cfg.num_train_epochs`` epochs with gradient accumulation,
         gradient clipping and a linear warmup schedule.
      6. After every epoch, evaluate both splits. Save ``best_checkpoint``
         whenever eval ``avg_macro_f1`` improves.
      7. Write ``metrics.json`` and the qualitative generations.

    All artifacts go under ``cfg.output_dir``.

    Raises:
        ValueError: If ``cfg.num_train_epochs`` is smaller than 1.
    """
    cfg = parse_args()
    # Fail fast. With zero epochs there are no epoch metrics, so the final
    # `all_epoch_metrics[-1]` lookup would otherwise raise IndexError, and
    # only after the backbone had already been loaded.
    if cfg.num_train_epochs < 1:
        raise ValueError(f"num_train_epochs must be >= 1, got {cfg.num_train_epochs}")
    output_dir = Path(cfg.output_dir)
    if output_dir.exists():
        print(f"[INFO] Reusing output directory {output_dir}")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "config.json").write_text(json.dumps(asdict(cfg), indent=2))

    random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(cfg.seed)
    train_rows, eval_rows = load_train_eval_rows(cfg)
    label_maps = build_label_maps(train_rows + eval_rows)
    inv_label_maps = invert_label_maps(label_maps)
    stats = {
        "train": class_frequency_stats(train_rows, label_maps),
        "eval": class_frequency_stats(eval_rows, label_maps),
        "majority_baseline": majority_baseline(train_rows, eval_rows, label_maps),
    }
    (output_dir / "label_maps.json").write_text(json.dumps(label_maps, indent=2))
    (output_dir / "class_frequency_stats.json").write_text(json.dumps(stats, indent=2))

    backbone, tokenizer = load_unsloth_or_hf_model(cfg)
    from transformers import get_linear_schedule_with_warmup

    train_dataset = GoConsistencyDataset(train_rows, tokenizer, label_maps, cfg.max_seq_length)
    eval_dataset = GoConsistencyDataset(eval_rows, tokenizer, label_maps, cfg.max_seq_length)
    collator = GoConsistencyCollator(tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=cfg.per_device_train_batch_size, shuffle=True, collate_fn=collator)
    eval_loader = DataLoader(eval_dataset, batch_size=cfg.per_device_eval_batch_size, shuffle=False, collate_fn=collator)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GoConsistencyModel(backbone, label_maps, cfg.lambda_lm, cfg.lambda_claim)
    # Only the claim heads are moved here; backbone placement is left to
    # load_unsloth_or_hf_model.
    model.claim_heads.to(device)
    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    num_batches = len(train_loader)
    accum_steps = cfg.gradient_accumulation_steps
    # One optimizer step per accumulation group, including a trailing partial group.
    steps_per_epoch = math.ceil(num_batches / accum_steps)
    total_steps = max(1, steps_per_epoch * cfg.num_train_epochs)
    warmup_steps = int(total_steps * cfg.warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)
    # Loss scaling is enabled only on CUDA when bf16 is neither requested nor
    # supported. Presumably maybe_autocast falls back to fp16 then (confirm).
    scaler = make_grad_scaler(enabled=torch.cuda.is_available() and not (cfg.bf16 or torch.cuda.is_bf16_supported()))

    best_score = -1.0
    all_epoch_metrics: list[dict[str, Any]] = []
    for epoch in range(1, cfg.num_train_epochs + 1):
        model.train()
        optimizer.zero_grad(set_to_none=True)
        progress = tqdm(train_loader, desc=f"train epoch {epoch}")
        running = Counter()
        for step, batch in enumerate(progress, 1):
            batch = move_batch_to_device(batch, device)
            # Number of micro-batches in the accumulation group that contains
            # `step`. Only the last group of an epoch can be short, when
            # num_batches is not a multiple of accum_steps. Dividing by the
            # real group size keeps that update's gradient on the same scale
            # as a full group's instead of shrinking it.
            group_start = (step - 1) // accum_steps * accum_steps
            group_size = min(accum_steps, num_batches - group_start)
            with maybe_autocast(device):
                out = model(batch)
                loss = out["loss"] / group_size
            scaler.scale(loss).backward()
            # Log the unscaled per-batch losses.
            running["loss"] += float(out["loss"].detach().item())
            running["lm_loss"] += float(out["lm_loss"].detach().item())
            running["claim_loss"] += float(out["claim_loss"].detach().item())

            if step % accum_steps == 0 or step == num_batches:
                # Unscale first so the clipping threshold applies to true gradient norms.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # A non-positive log_every disables progress logging instead of
            # raising ZeroDivisionError in `step % 0`.
            if cfg.log_every > 0 and step % cfg.log_every == 0:
                denom = cfg.log_every
                progress.set_postfix(
                    loss=running["loss"] / denom,
                    lm=running["lm_loss"] / denom,
                    claim=running["claim_loss"] / denom,
                )
                running.clear()

        train_pred_path = output_dir / f"train_predictions_epoch_{epoch}.jsonl"
        eval_pred_path = output_dir / f"eval_predictions_epoch_{epoch}.jsonl"
        # train_loader shuffles, so train prediction order differs across epochs.
        train_metrics = evaluate(model, train_loader, device, inv_label_maps, train_pred_path)
        eval_metrics = evaluate(model, eval_loader, device, inv_label_maps, eval_pred_path)
        epoch_metrics = {
            "epoch": epoch,
            "train": train_metrics,
            "eval": eval_metrics,
        }
        all_epoch_metrics.append(epoch_metrics)
        print_split_metrics(epoch, train_metrics, eval_metrics)
        print("\nTrain metrics")
        print_metrics_table(train_metrics)
        print("\nEval metrics")
        print_metrics_table(eval_metrics)
        # Model selection uses eval macro-F1 averaged over claim fields.
        score = eval_metrics["avg_macro_f1"]
        if score > best_score:
            best_score = score
            save_checkpoint(model, tokenizer, output_dir / "best_checkpoint", label_maps, epoch_metrics)
            shutil.copy2(train_pred_path, output_dir / "best_train_predictions.jsonl")
            shutil.copy2(eval_pred_path, output_dir / "best_eval_predictions.jsonl")

    final_metrics = all_epoch_metrics[-1]
    (output_dir / "metrics.json").write_text(json.dumps({"epochs": all_epoch_metrics, "final": final_metrics}, indent=2))
    # Generation uses the in-memory final-epoch weights, not best_checkpoint.
    qualitative_generation(
        model,
        tokenizer,
        eval_dataset,
        device,
        inv_label_maps,
        output_dir / "qualitative_generations.jsonl",
        cfg.generate_samples,
        cfg.max_new_tokens,
    )
    print_split_metrics(final_metrics["epoch"], final_metrics["train"], final_metrics["eval"])
    print(f"\nSaved outputs to {output_dir}")


if __name__ == "__main__":
    # Script entry point: run the full training / evaluation pipeline.
    main()
