from __future__ import annotations

import argparse
import os
import random
import re
import time
from collections import Counter, defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# FEVER verdict string -> integer class id used as the training target.
LABEL_MAP = {"SUPPORTS": 0, "REFUTES": 1, "NOT ENOUGH INFO": 2}
# Inverse lookup: class id -> verdict string.
ID_TO_LABEL = {class_id: verdict for verdict, class_id in LABEL_MAP.items()}
NUM_CLASSES = len(LABEL_MAP)
# Special tokens placed after [LABELSEP]; indexed by class id.
LABEL_TOKENS = ["[SUPPORTS]", "[REFUTES]", "[NEI]"]
# Function words dropped when building claim signatures.
STOPWORDS = {
    "the", "a", "an",
    "of", "in", "on", "at", "for", "to", "by", "with", "as", "from", "into", "about",
    "and", "or",
    "is", "was", "were", "are", "be", "been",
    "that", "this", "it",
}
# Variants whose consistency head receives hidden states zeroed from [SEP] onward.
STRICT_EVIDENCE_VARIANTS = {"evidence_only_strict", "evidence_only_random_labels"}


@dataclass
class ScratchTransformerConfig:
    """Hyper-parameters and run options for the from-scratch FEVER experiment.

    One instance is built from the command line and passed to
    ``run_experiment``, which trains one model per entry of ``pooling_modes``.
    """

    # --- data ---
    dataset_name: str = "copenlu/fever_gold_evidence"  # passed to datasets.load_dataset
    num_train_samples: int = 50_000  # cap on parsed training examples
    num_eval_samples: int = 5_000  # cap on parsed validation examples
    max_seq_len: int = 256  # padded sequence length; also the size of the position-embedding table
    seed: int = 42  # seeds the global RNGs and the per-split shuffles

    # --- model architecture ---
    d_model: int = 256
    n_layers: int = 4
    n_heads: int = 8
    d_ff: int = 1024
    dropout: float = 0.1

    # --- optimization ---
    batch_size: int = 32
    num_epochs: int = 10
    lr: float = 3e-4
    weight_decay: float = 0.01  # applied only to the "decay" parameter group
    consistency_loss_weight: float = 0.5  # multiplier on the consistency-head loss
    grad_clip: float = 1.0  # max gradient norm
    warmup_steps: int = 500  # run_experiment caps this at total_steps // 5

    # Variants to train, in order; each name selects the pooling/loss behaviour.
    pooling_modes: Tuple[str, ...] = (
        "no_consistency_loss",
        "evidence_only_pooling",
        "evidence_only_strict",
        "full_sequence_pooling",
        "claim_only_pooling",
        "evidence_only_random_labels",
    )

    # --- runtime ---
    model_name: str = "gpt2"  # name of the pretrained tokenizer to load
    results_path: str = "results_fever_scratch_transformer.csv"  # CSV written by run_experiment
    smoke_test: bool = False  # use tiny synthetic data instead of FEVER
    require_gpu: bool = False  # make resolve_device fail when CUDA is unavailable
    device: str = "auto"  # "auto" picks cuda if available, otherwise a torch device string


# Process-wide tokenizer cache; filled on the first get_tokenizer() call and reused afterwards.
_tokenizer = None


def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs (and every CUDA device when CUDA is available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def resolve_device(cfg: ScratchTransformerConfig) -> torch.device:
    """Pick the torch device for this run.

    Raises:
        RuntimeError: if ``cfg.require_gpu`` is set but CUDA is unavailable.

    Returns:
        ``cfg.device`` as a ``torch.device`` unless it is ``"auto"``, in which
        case CUDA is used when available and CPU otherwise.
    """
    has_cuda = torch.cuda.is_available()
    if cfg.require_gpu and not has_cuda:
        raise RuntimeError("CUDA is not available but --require-gpu was set.")
    if cfg.device != "auto":
        return torch.device(cfg.device)
    return torch.device("cuda" if has_cuda else "cpu")


def get_tokenizer(model_name: str = "gpt2"):
    """Return the shared GPT-2 tokenizer, creating it on first use.

    The tokenizer is extended with [BOS]/[EOS]/[PAD] and the task's separator
    and label tokens. It is cached in the module-level ``_tokenizer``; once
    cached, later calls return it regardless of ``model_name``.
    """
    global _tokenizer
    if _tokenizer is not None:
        return _tokenizer

    # Imported lazily so the module can be imported without transformers installed.
    from transformers import GPT2TokenizerFast

    special_tokens = {
        "bos_token": "[BOS]",
        "eos_token": "[EOS]",
        "pad_token": "[PAD]",
        "additional_special_tokens": ["[SEP]", "[LABELSEP]", "[SUPPORTS]", "[REFUTES]", "[NEI]"],
    }
    tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
    tokenizer.add_special_tokens(special_tokens)
    _tokenizer = tokenizer
    return _tokenizer


def load_fever_data(cfg: ScratchTransformerConfig):
    """Load the train and validation splits of ``cfg.dataset_name`` as lists of sample dicts.

    Each sample has the keys ``evidence``, ``claim``, ``label`` (class id) and
    ``label_str``. Rows whose label is not in ``LABEL_MAP`` are skipped.

    Returns:
        ``(train_samples, eval_samples)``.
    """
    from datasets import load_dataset

    train_ds = load_dataset(cfg.dataset_name, split="train")
    val_ds = load_dataset(cfg.dataset_name, split="validation")

    def parse_split(hf_ds, n_samples, split_name="train"):
        # Visit rows in a seed-determined random order and stop once enough are collected.
        order = list(range(len(hf_ds)))
        random.Random(cfg.seed).shuffle(order)
        collected = []
        for row_idx in order:
            example = hf_ds[row_idx]
            verdict = example["label"]
            if verdict not in LABEL_MAP:
                continue
            # The third field of each evidence entry is used as the sentence text; empty ones are dropped.
            sentences = [str(ev[2]) for ev in example["evidence"] if len(ev) >= 3 and ev[2]]
            collected.append({
                "evidence": " ".join(sentences) if sentences else "No evidence available.",
                "claim": str(example["claim"]),
                "label": LABEL_MAP[verdict],
                "label_str": verdict,
            })
            if len(collected) >= n_samples:
                break
        label_counts = dict(Counter(s["label_str"] for s in collected))
        print(f"{split_name}: {len(collected)} samples loaded | labels: {label_counts}")
        return collected

    train_samples = parse_split(train_ds, cfg.num_train_samples, "train")
    eval_samples = parse_split(val_ds, cfg.num_eval_samples, "eval")
    return train_samples, eval_samples


def make_smoke_data(n_train: int = 64, n_eval: int = 32, seed: int = 42):
    """Build tiny synthetic train/eval sample lists for a quick end-to-end check.

    Label, evidence and claim are drawn independently from fixed templates with
    a private ``random.Random(seed)``, so labels carry no signal.

    Returns:
        ``(train_samples, eval_samples)`` with the same dict schema as
        ``load_fever_data``.
    """
    rng = random.Random(seed)
    labels = list(LABEL_MAP.keys())
    evidence_templates = [
        "The study showed that X is true in most observed cases.",
        "Research indicates that Y is false and has been refuted.",
        "No significant evidence was found in the literature.",
        "Experiments confirmed that Z leads to positive outcomes.",
        "Scientists disputed the claim about W being correct.",
    ]
    claim_templates = [
        "X is commonly found to be true.",
        "Y has been proven false by scientists.",
        "There is no clear evidence about this topic.",
        "Z leads to positive results overall.",
        "W is confirmed by the scientific community.",
    ]

    def draw_sample():
        # Draw order (label, evidence, claim) determines the generated data for a given seed.
        verdict = rng.choice(labels)
        evidence = rng.choice(evidence_templates)
        claim = rng.choice(claim_templates)
        return {
            "evidence": evidence,
            "claim": claim,
            "label": LABEL_MAP[verdict],
            "label_str": verdict,
        }

    train_samples = [draw_sample() for _ in range(n_train)]
    eval_samples = [draw_sample() for _ in range(n_eval)]
    return train_samples, eval_samples


class FEVERTokenDataset(Dataset):
    """Pre-tokenized FEVER samples laid out as
    ``[BOS] evidence [SEP] claim [LABELSEP] label [EOS]`` and right-padded to ``max_seq_len``.

    Every record keeps the padded ``input_ids`` / ``attention_mask`` tensors,
    the positions of ``[SEP]`` and ``[LABELSEP]``, and the raw sample fields.
    """

    def __init__(self, samples: List[dict], max_seq_len: int = 256, model_name: str = "gpt2"):
        self.max_seq_len = max_seq_len
        self.tokenizer = get_tokenizer(model_name)
        tok = self.tokenizer
        bos = tok.bos_token_id
        sep = tok.convert_tokens_to_ids("[SEP]")
        labelsep = tok.convert_tokens_to_ids("[LABELSEP]")
        eos = tok.eos_token_id
        pad = tok.pad_token_id
        n_special = 5  # [BOS], [SEP], [LABELSEP], label token, [EOS]

        def encode(sample):
            evidence_ids = tok.encode(sample["evidence"], add_special_tokens=False)
            claim_ids = tok.encode(sample["claim"], add_special_tokens=False)
            label_id = tok.convert_tokens_to_ids(LABEL_TOKENS[sample["label"]])
            # The claim is kept whole where possible; evidence gets the remainder, but at least 4 tokens.
            evidence_budget = max(4, max_seq_len - n_special - len(claim_ids))
            claim_ids = claim_ids[:max_seq_len - n_special - evidence_budget]
            evidence_ids = evidence_ids[:evidence_budget]
            tokens = [bos, *evidence_ids, sep, *claim_ids, labelsep, label_id, eos]
            n_real = len(tokens)
            n_pad = max(0, max_seq_len - n_real)
            sep_at = 1 + len(evidence_ids)
            return {
                "input_ids": torch.tensor((tokens + [pad] * n_pad)[:max_seq_len], dtype=torch.long),
                "attention_mask": torch.tensor(([1] * n_real + [0] * n_pad)[:max_seq_len], dtype=torch.long),
                "sep_pos": sep_at,
                "labelsep_pos": sep_at + 1 + len(claim_ids),
                "label": sample["label"],
                "label_str": sample["label_str"],
                "evidence": sample["evidence"],
                "claim": sample["claim"],
            }

        self.records = [encode(sample) for sample in samples]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return self.records[idx]


def collate_fn(batch):
    """Stack a list of dataset records into one batch dict.

    Tensors are stacked along a new first dimension, labels become a long
    tensor, and the separator positions stay plain Python lists.
    """
    def column(key):
        return [record[key] for record in batch]

    collated = {
        "input_ids": torch.stack(column("input_ids")),
        "attention_mask": torch.stack(column("attention_mask")),
        "labels": torch.tensor(column("label"), dtype=torch.long),
    }
    collated["sep_pos"] = column("sep_pos")
    collated["labelsep_pos"] = column("labelsep_pos")
    return collated


class DecoderBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then a GELU feed-forward net,
    each added back onto the residual stream."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        feed_forward_layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)

    def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
        """Apply the block to ``x``; both masks are forwarded unchanged to the attention layer."""
        normed = self.ln1(x)
        attended = self.attn(
            normed, normed, normed,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=False,
        )[0]
        residual = x + attended
        return residual + self.ff(self.ln2(residual))


class ScratchTransformerWithConsistencyHead(nn.Module):
    """Causal decoder-only transformer with two output heads.

    * ``lm_head`` (weight-tied to the token embedding) predicts the next token;
      ``forward`` evaluates it at the ``[LABELSEP]`` position only.
    * ``consistency_head`` classifies a mean-pooled span of hidden states into
      ``NUM_CLASSES`` verdicts; the span is selected by ``pooling_mode``.
    """

    def __init__(self, vocab_size: int, cfg: ScratchTransformerConfig):
        super().__init__()
        self.cfg = cfg
        self.token_emb = nn.Embedding(vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([
            DecoderBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.dropout) for _ in range(cfg.n_layers)
        ])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, vocab_size, bias=False)
        self.consistency_head = nn.Linear(cfg.d_model, NUM_CLASSES)
        # Tie the output projection to the input embedding before initialising.
        self.lm_head.weight = self.token_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) for Linear/Embedding weights, zeros for biases, identity for LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _causal_mask(self, T: int, device: torch.device):
        """Return a (T, T) boolean mask that is True strictly above the diagonal.

        True means "may not attend". A boolean mask is used so its type matches
        the boolean key-padding mask; mixing a float ``attn_mask`` with a bool
        ``key_padding_mask`` is deprecated in ``nn.MultiheadAttention``.
        """
        return torch.triu(torch.ones((T, T), dtype=torch.bool, device=device), diagonal=1)

    def _encode(self, input_ids, attention_mask=None):
        """Run embeddings, all decoder blocks and the final LayerNorm; returns (B, T, d_model)."""
        B, T = input_ids.shape
        device = input_ids.device
        pos = torch.arange(T, device=device).unsqueeze(0).expand(B, T)
        x = self.drop(self.token_emb(input_ids) + self.pos_emb(pos))
        key_padding_mask = None if attention_mask is None else (attention_mask == 0)
        attn_mask = self._causal_mask(T, device)
        for block in self.blocks:
            x = block(x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return self.ln_f(x)

    def _strict_evidence_hidden_states(self, hidden_states, sep_pos):
        """Return a copy of ``hidden_states`` with every position from ``[SEP]`` onward zeroed."""
        masked = hidden_states.clone()
        for i in range(masked.size(0)):
            if sep_pos[i] < masked.size(1):
                masked[i, sep_pos[i]:, :] = 0.0
        return masked

    def _pool(self, hidden_states, pooling_mode, sep_pos, labelsep_pos, attention_mask):
        """Mean-pool one span of hidden states per example; returns (B, d_model).

        Spans (end exclusive):
          * evidence modes: tokens between ``[BOS]`` and ``[SEP]``;
          * ``claim_only_pooling``: tokens between ``[SEP]`` and ``[LABELSEP]``;
          * ``full_sequence_pooling``: attended tokens up to and including
            ``[LABELSEP]``;
          * any other mode: all ``T`` positions.
        An empty span is widened to one position so the mean is defined.
        """
        B, T, _ = hidden_states.shape
        pooled = []
        for i in range(B):
            if pooling_mode in {"evidence_only_pooling", "evidence_only_strict", "evidence_only_random_labels"}:
                start, end = 1, max(2, sep_pos[i])
            elif pooling_mode == "claim_only_pooling":
                start, end = sep_pos[i] + 1, max(sep_pos[i] + 2, labelsep_pos[i])
            elif pooling_mode == "full_sequence_pooling":
                positions = attention_mask[i].nonzero(as_tuple=True)[0]
                start, end = (positions[0].item(), positions[-1].item() + 1) if len(positions) > 0 else (0, T)
                if labelsep_pos is not None:
                    # The label token and [EOS] sit after [LABELSEP]; pooling them
                    # would hand the answer to the consistency head.
                    end = min(end, labelsep_pos[i] + 1)
            else:
                start, end = 0, T
            start = max(0, min(start, T - 1))
            end = max(start + 1, min(end, T))
            pooled.append(hidden_states[i, start:end, :].mean(dim=0))
        return torch.stack(pooled, dim=0)

    def forward(self, input_ids, attention_mask, pooling_mode="evidence_only_pooling", sep_pos=None, labelsep_pos=None):
        """Return ``(lm_logits, cons_logits)``.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: (B, T), 0 at padding positions.
            pooling_mode: variant name selecting the span pooled for the consistency head.
            sep_pos: per-example index of ``[SEP]``; needed by the evidence and claim modes.
            labelsep_pos: per-example index of ``[LABELSEP]``; when None the
                second-to-last position is used for the LM logits.

        Returns:
            lm_logits: (B, vocab) next-token logits at the ``[LABELSEP]`` position.
            cons_logits: (B, NUM_CLASSES) consistency-head logits.
        """
        B, T = input_ids.shape
        device = input_ids.device
        hidden_states = self._encode(input_ids, attention_mask)
        if labelsep_pos is not None:
            ls_idx = torch.tensor(labelsep_pos, dtype=torch.long, device=device).clamp(0, T - 2)
            h_at_ls = hidden_states[torch.arange(B, device=device), ls_idx, :]
        else:
            h_at_ls = hidden_states[:, -2, :]
        lm_logits = self.lm_head(h_at_ls)
        if pooling_mode in STRICT_EVIDENCE_VARIANTS:
            cons_hidden = self._strict_evidence_hidden_states(hidden_states, sep_pos)
        else:
            cons_hidden = hidden_states
        pooled = self._pool(cons_hidden, pooling_mode, sep_pos, labelsep_pos, attention_mask)
        cons_logits = self.consistency_head(pooled)
        return lm_logits, cons_logits

    def get_full_lm_logits(self, input_ids, attention_mask=None):
        """Return next-token logits for every position, shape (B, T, vocab)."""
        return self.lm_head(self._encode(input_ids, attention_mask))


def maybe_randomize_labels(labels, pooling_mode, seed_offset=0):
    """Return the consistency-head targets for this batch.

    For every mode except ``"evidence_only_random_labels"`` the labels are
    returned unchanged. For that mode each label is rotated by a non-zero
    shift derived from ``seed_offset``, so no example keeps its true class.
    """
    if pooling_mode == "evidence_only_random_labels":
        # shift lies in 1 .. NUM_CLASSES - 1, so the rotation is never the identity
        shift = 1 + seed_offset % (NUM_CLASSES - 1)
        return (labels + shift) % NUM_CLASSES
    return labels


def build_optimizer(model, cfg: ScratchTransformerConfig):
    """Create AdamW with two parameter groups.

    Trainable parameters that are 1-D or whose name contains ``"bias"`` or
    ``"ln"`` get no weight decay; all others use ``cfg.weight_decay``.
    """
    with_decay, without_decay = [], []
    trainable = ((name, param) for name, param in model.named_parameters() if param.requires_grad)
    for name, param in trainable:
        exempt = param.ndim < 2 or "bias" in name or "ln" in name
        (without_decay if exempt else with_decay).append(param)
    param_groups = [
        {"params": with_decay, "weight_decay": cfg.weight_decay},
        {"params": without_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(param_groups, lr=cfg.lr, betas=(0.9, 0.95))


def build_scheduler(optimizer, total_steps: int, warmup_steps: int):
    """Return a LambdaLR with linear warm-up followed by cosine decay floored at 10% of the base LR."""
    warmup_span = float(max(1, warmup_steps))
    decay_span = float(max(1, total_steps - warmup_steps))

    def lr_lambda(step):
        if step >= warmup_steps:
            progress = float(step - warmup_steps) / decay_span
            cosine = 0.5 * (1.0 + np.cos(np.pi * progress))
            return max(0.1, cosine)
        return float(step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def train_epoch(model, loader, optimizer, scheduler, device, pooling_mode, consistency_loss_weight, grad_clip=1.0):
    """Train for one pass over ``loader``.

    The LM loss is the cross-entropy of predicting the label token at the
    ``[LABELSEP]`` position. The consistency loss is always computed for
    reporting, but only added to the optimised loss when ``pooling_mode`` is
    not ``"no_consistency_loss"``.

    Returns:
        ``(mean_lm_loss, mean_consistency_loss)`` over the batches seen.
    """
    model.train()
    lm_sum, cons_sum, steps = 0.0, 0.0, 0
    train_consistency = pooling_mode != "no_consistency_loss"
    for step, batch in enumerate(loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        sep_pos, labelsep_pos = batch["sep_pos"], batch["labelsep_pos"]
        B, T = input_ids.shape

        optimizer.zero_grad(set_to_none=True)
        lm_logits, cons_logits = model(
            input_ids,
            attention_mask,
            pooling_mode=pooling_mode,
            sep_pos=sep_pos,
            labelsep_pos=labelsep_pos,
        )

        # The LM target is the token that follows [LABELSEP] in the input.
        ls_idx = torch.tensor(labelsep_pos, dtype=torch.long, device=device).clamp(0, T - 2)
        rows = torch.arange(B, device=device)
        label_token_targets = input_ids[rows, (ls_idx + 1).clamp(0, T - 1)]
        lm_loss = F.cross_entropy(lm_logits, label_token_targets)

        cons_targets = maybe_randomize_labels(labels, pooling_mode, step)
        cons_loss = F.cross_entropy(cons_logits, cons_targets)

        if train_consistency:
            loss = lm_loss + consistency_loss_weight * cons_loss
        else:
            loss = lm_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()

        lm_sum += lm_loss.item()
        cons_sum += cons_loss.item()
        steps += 1
    denom = max(1, steps)
    return lm_sum / denom, cons_sum / denom


@torch.no_grad()
def evaluate_classification(model, loader, device, pooling_mode):
    """Return the consistency head's accuracy over ``loader`` (0.0 when the loader is empty)."""
    model.eval()
    hits, seen = 0, 0
    for batch in loader:
        gold = batch["labels"].to(device)
        _, cons_logits = model(
            batch["input_ids"].to(device),
            batch["attention_mask"].to(device),
            pooling_mode=pooling_mode,
            sep_pos=batch["sep_pos"],
            labelsep_pos=batch["labelsep_pos"],
        )
        predicted = cons_logits.argmax(dim=-1)
        hits += (predicted == gold).sum().item()
        seen += gold.size(0)
    return hits / max(1, seen)


@torch.no_grad()
def evaluate_generation_accuracy(model, eval_samples, device, max_seq_len=256, n_samples=500, model_name="gpt2"):
    """Accuracy of the LM head at generating the label token after ``[LABELSEP]``.

    For each of the first ``n_samples`` samples the prompt
    ``[BOS] evidence [SEP] claim [LABELSEP]`` is built and the arg-max next
    token over the full vocabulary is compared with the gold label token.

    Args:
        model: object exposing ``eval()`` and ``get_full_lm_logits(ids, mask)``.
        eval_samples: list of sample dicts with ``evidence``, ``claim``, ``label``.
        device: torch device for the prompt tensors.
        max_seq_len: sequence length the training examples were padded to.
        n_samples: maximum number of samples to score.
        model_name: tokenizer name passed to ``get_tokenizer``.

    Returns:
        Fraction of scored samples whose predicted token is the gold label token.
    """
    model.eval()
    tokenizer = get_tokenizer(model_name)
    bos_id = tokenizer.bos_token_id
    sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
    labelsep_id = tokenizer.convert_tokens_to_ids("[LABELSEP]")
    label_tok_ids = [tokenizer.convert_tokens_to_ids(t) for t in LABEL_TOKENS]
    samples_to_use = eval_samples[:n_samples]
    # Training sequences reserve five special tokens ([BOS], [SEP], [LABELSEP],
    # label, [EOS]). Use the same budget here so long examples are truncated
    # exactly as they were during training, rather than keeping two extra
    # evidence tokens the model never saw at those positions.
    overhead = 5
    correct = 0
    for s in samples_to_use:
        ev_ids = tokenizer.encode(s["evidence"], add_special_tokens=False)
        cl_ids = tokenizer.encode(s["claim"], add_special_tokens=False)
        label_tok_id = label_tok_ids[s["label"]]
        ev_budget = max(4, max_seq_len - overhead - len(cl_ids))
        cl_ids = cl_ids[:max_seq_len - overhead - ev_budget]
        ev_ids = ev_ids[:ev_budget]
        prompt_ids = [bos_id] + ev_ids + [sep_id] + cl_ids + [labelsep_id]
        prompt_t = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        attn_t = torch.ones_like(prompt_t)
        full_logits = model.get_full_lm_logits(prompt_t, attn_t)
        # Logits at the last prompt position ([LABELSEP]) predict the label token.
        pred_id = full_logits[0, -1, :].argmax().item()
        if pred_id == label_tok_id:
            correct += 1
    return correct / max(1, len(samples_to_use))


def _claim_signature(text: str):
    toks = re.findall(r"[a-z0-9]+", text.lower())
    return {t for t in toks if t not in STOPWORDS and len(t) > 2}


def build_matched_claim_candidates(records, max_per_item=32):
    """For each record, list other records with a similar claim but a different label.

    Claims are compared by the Jaccard similarity of their token signatures.
    An inverted index over at most 12 signature tokens per claim limits the
    pairs that are scored.

    Args:
        records: sequence of dicts with ``claim`` and ``label`` keys.
        max_per_item: maximum number of candidates kept per record.

    Returns:
        Dict mapping each record index to a list of candidate indices, ordered
        by descending similarity (ties broken by descending index).
    """
    signatures = [_claim_signature(r["claim"]) for r in records]
    # sorted() makes the choice of index tokens deterministic. Iterating a set
    # of strings depends on the per-process hash seed, so for claims with more
    # than 12 signature tokens `list(sig)[:12]` could pick different tokens,
    # and therefore different candidates, on every run.
    index_tokens = [sorted(sig)[:12] for sig in signatures]

    buckets = defaultdict(list)
    for i, toks in enumerate(index_tokens):
        for tok in toks:
            buckets[tok].append(i)

    candidates = {}
    for i, rec in enumerate(records):
        sig = signatures[i]
        pooled = set()
        for tok in index_tokens[i]:
            pooled.update(buckets[tok])
        scored = []
        for j in pooled:
            if j == i or records[j]["label"] == rec["label"]:
                continue
            other_sig = signatures[j]
            if not sig or not other_sig:
                continue
            union = len(sig | other_sig)
            score = len(sig & other_sig) / union if union else 0.0
            if score > 0:
                scored.append((score, j))
        scored.sort(reverse=True)
        candidates[i] = [j for _, j in scored[:max_per_item]]
    return candidates


@torch.no_grad()
def evaluate_counterfactual_swap(model, eval_dataset, device, pooling_mode, n_pairs=500, seed=42, model_name="gpt2", matched_claim_only=False):
    """Measure whether predictions follow the evidence when it is swapped.

    For up to ``n_pairs`` randomly chosen records, the claim is kept but the
    evidence is replaced by that of a record with a different label. Both the
    consistency head and the LM head (arg-max token after ``[LABELSEP]``) are
    then checked against the swapped-in label and against the original label.

    Args:
        model: the trained model.
        eval_dataset: dataset exposing ``records`` and ``max_seq_len``.
        device: torch device for the constructed tensors.
        pooling_mode: variant name forwarded to the model.
        n_pairs: maximum number of records to evaluate.
        seed: seed for the private RNG that picks records and swap partners.
        model_name: tokenizer name passed to ``get_tokenizer``.
        matched_claim_only: if True, swap partners come from
            ``build_matched_claim_candidates`` and records without any
            candidate are skipped.

    Returns:
        ``(cls_follows_swap, cls_follows_orig, gen_follows_swap, gen_follows_orig)``
        as fractions of the records actually evaluated.
    """
    model.eval()
    tokenizer = get_tokenizer(model_name)
    rng = random.Random(seed)
    records = eval_dataset.records
    n = min(n_pairs, len(records))
    indices = list(range(len(records)))
    rng.shuffle(indices)
    eval_indices = indices[:n]
    label_to_idxs = {label: [] for label in range(NUM_CLASSES)}
    for i, r in enumerate(records):
        label_to_idxs[r["label"]].append(i)
    matched_candidates = build_matched_claim_candidates(records) if matched_claim_only else None
    bos_id = tokenizer.bos_token_id
    sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
    labelsep_id = tokenizer.convert_tokens_to_ids("[LABELSEP]")
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id
    label_tok_ids = [tokenizer.convert_tokens_to_ids(t) for t in LABEL_TOKENS]
    T = eval_dataset.max_seq_len
    overhead = 5  # [BOS], [SEP], [LABELSEP], label token, [EOS]
    cls_follows_swap = cls_follows_orig = gen_follows_swap = gen_follows_orig = n_valid = 0
    for idx in eval_indices:
        orig = records[idx]
        orig_label = orig["label"]
        if matched_claim_only:
            candidates = matched_candidates.get(idx, [])
            if not candidates:
                continue
            swap_idx = rng.choice(candidates)
            swap_rec = records[swap_idx]
            swap_label = swap_rec["label"]
        else:
            other_labels = [l for l in range(NUM_CLASSES) if l != orig_label]
            swap_label = rng.choice(other_labels)
            candidates = label_to_idxs[swap_label]
            if not candidates:
                continue
            swap_rec = records[rng.choice(candidates)]
        ev_ids = tokenizer.encode(swap_rec["evidence"], add_special_tokens=False)
        cl_ids = tokenizer.encode(orig["claim"], add_special_tokens=False)
        # Truncate exactly as FEVERTokenDataset does: the claim is kept whole
        # where possible and evidence gets the remainder, but at least 4 tokens.
        # Cutting an over-long claim to 4 tokens instead would evaluate on a
        # layout the model was never trained on.
        ev_budget = max(4, T - overhead - len(cl_ids))
        cl_ids = cl_ids[:T - overhead - ev_budget]
        ev_ids = ev_ids[:ev_budget]
        new_label_tok_id = label_tok_ids[swap_label]
        ids = [bos_id] + ev_ids + [sep_id] + cl_ids + [labelsep_id] + [new_label_tok_id] + [eos_id]
        sep_pos_new = 1 + len(ev_ids)
        labelsep_pos_new = sep_pos_new + 1 + len(cl_ids)
        pad_len = max(0, T - len(ids))
        attn = ([1] * len(ids) + [0] * pad_len)[:T]
        ids_pad = (ids + [pad_id] * pad_len)[:T]
        swapped_ids = torch.tensor([ids_pad], dtype=torch.long, device=device)
        swapped_mask = torch.tensor([attn], dtype=torch.long, device=device)
        _, cons_logits = model(swapped_ids, swapped_mask, pooling_mode=pooling_mode, sep_pos=[sep_pos_new], labelsep_pos=[labelsep_pos_new])
        pred_cls = cons_logits[0].argmax().item()
        # Generation check: unpadded prompt ending at [LABELSEP], no label token in the input.
        prompt_ids = [bos_id] + ev_ids + [sep_id] + cl_ids + [labelsep_id]
        prompt_t = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        attn_prompt = torch.ones_like(prompt_t)
        full_logits = model.get_full_lm_logits(prompt_t, attn_prompt)
        pred_gen_tok = full_logits[0, -1, :].argmax().item()
        # None when the arg-max token is not one of the label tokens.
        pred_gen_label = next((li for li, ltid in enumerate(label_tok_ids) if pred_gen_tok == ltid), None)
        if pred_cls == swap_label:
            cls_follows_swap += 1
        if pred_cls == orig_label:
            cls_follows_orig += 1
        if pred_gen_label is not None:
            if pred_gen_label == swap_label:
                gen_follows_swap += 1
            if pred_gen_label == orig_label:
                gen_follows_orig += 1
        n_valid += 1
    d = max(1, n_valid)
    return cls_follows_swap / d, cls_follows_orig / d, gen_follows_swap / d, gen_follows_orig / d


@torch.no_grad()
def evaluate_shuffled_pairing(model, eval_dataset, device, pooling_mode, n_pairs=500, seed=42, model_name="gpt2"):
    """Accuracy against the claim's own label when the evidence comes from another record.

    Up to ``n_pairs`` records are sampled; each claim is paired with the
    evidence of the record ``max(1, n // 3)`` places further along the sampled
    order (wrapping around). Both heads are scored against the label of the
    record that supplied the claim.

    Args:
        model: the trained model.
        eval_dataset: dataset exposing ``records`` and ``max_seq_len``.
        device: torch device for the constructed tensors.
        pooling_mode: variant name forwarded to the model.
        n_pairs: maximum number of claim/evidence pairs.
        seed: base seed; the private RNG uses ``seed + 1``.
        model_name: tokenizer name passed to ``get_tokenizer``.

    Returns:
        ``(cls_accuracy, gen_accuracy)`` over the pairs evaluated.
    """
    model.eval()
    tokenizer = get_tokenizer(model_name)
    rng = random.Random(seed + 1)
    records = eval_dataset.records
    n = min(n_pairs, len(records))
    indices = list(range(len(records)))
    rng.shuffle(indices)
    eval_indices = indices[:n]
    bos_id = tokenizer.bos_token_id
    sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
    labelsep_id = tokenizer.convert_tokens_to_ids("[LABELSEP]")
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id
    label_tok_ids = [tokenizer.convert_tokens_to_ids(t) for t in LABEL_TOKENS]
    T = eval_dataset.max_seq_len
    overhead = 5  # [BOS], [SEP], [LABELSEP], label token, [EOS]
    offset = max(1, n // 3)
    shuffled_pairs = [(eval_indices[i], eval_indices[(i + offset) % n]) for i in range(n)]
    cls_correct = gen_correct = n_valid = 0
    for orig_idx, ev_idx in shuffled_pairs:
        if orig_idx == ev_idx:
            # Only possible when a single record was sampled; pairing it with
            # its own evidence would not be a shuffled pair.
            continue
        orig = records[orig_idx]
        ev_rec = records[ev_idx]
        orig_label = orig["label"]
        ev_ids = tokenizer.encode(ev_rec["evidence"], add_special_tokens=False)
        cl_ids = tokenizer.encode(orig["claim"], add_special_tokens=False)
        # Truncate exactly as FEVERTokenDataset does: the claim is kept whole
        # where possible and evidence gets the remainder, but at least 4 tokens.
        ev_budget = max(4, T - overhead - len(cl_ids))
        cl_ids = cl_ids[:T - overhead - ev_budget]
        ev_ids = ev_ids[:ev_budget]
        label_tok_id = label_tok_ids[orig_label]
        ids = [bos_id] + ev_ids + [sep_id] + cl_ids + [labelsep_id] + [label_tok_id] + [eos_id]
        sep_pos_new = 1 + len(ev_ids)
        labelsep_pos_new = sep_pos_new + 1 + len(cl_ids)
        pad_len = max(0, T - len(ids))
        attn = ([1] * len(ids) + [0] * pad_len)[:T]
        ids_pad = (ids + [pad_id] * pad_len)[:T]
        shuf_ids = torch.tensor([ids_pad], dtype=torch.long, device=device)
        shuf_mask = torch.tensor([attn], dtype=torch.long, device=device)
        _, cons_logits = model(shuf_ids, shuf_mask, pooling_mode=pooling_mode, sep_pos=[sep_pos_new], labelsep_pos=[labelsep_pos_new])
        pred_cls = cons_logits[0].argmax().item()
        # Generation check: unpadded prompt ending at [LABELSEP], no label token in the input.
        prompt_ids = [bos_id] + ev_ids + [sep_id] + cl_ids + [labelsep_id]
        prompt_t = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        attn_prompt = torch.ones_like(prompt_t)
        full_logits = model.get_full_lm_logits(prompt_t, attn_prompt)
        pred_gen_tok = full_logits[0, -1, :].argmax().item()
        # None when the arg-max token is not one of the label tokens.
        pred_gen_label = next((li for li, ltid in enumerate(label_tok_ids) if pred_gen_tok == ltid), None)
        if pred_cls == orig_label:
            cls_correct += 1
        if pred_gen_label == orig_label:
            gen_correct += 1
        n_valid += 1
    d = max(1, n_valid)
    return cls_correct / d, gen_correct / d


def run_experiment(cfg: ScratchTransformerConfig):
    """Train and evaluate one model per variant in ``cfg.pooling_modes``.

    For every variant a fresh model is trained for ``cfg.num_epochs`` epochs
    and then scored on classification, label generation, counterfactual
    evidence swaps (random and claim-matched) and shuffled claim/evidence
    pairs. All rows are written to ``cfg.results_path`` as CSV.

    Returns:
        A ``pandas.DataFrame`` with one row per variant.
    """
    set_seed(cfg.seed)
    device = resolve_device(cfg)
    print(f"Using device: {device}")
    if cfg.smoke_test:
        train_samples, eval_samples = make_smoke_data(seed=cfg.seed)
    else:
        train_samples, eval_samples = load_fever_data(cfg)
    tokenizer = get_tokenizer(cfg.model_name)
    train_dataset = FEVERTokenDataset(train_samples, max_seq_len=cfg.max_seq_len, model_name=cfg.model_name)
    eval_dataset = FEVERTokenDataset(eval_samples, max_seq_len=cfg.max_seq_len, model_name=cfg.model_name)
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    eval_loader = DataLoader(eval_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)
    results = []
    total_steps = len(train_loader) * cfg.num_epochs
    warmup_steps = min(cfg.warmup_steps, max(1, total_steps // 5))
    n_eval_pairs = min(500, len(eval_samples))

    for variant in cfg.pooling_modes:
        print(f"\n=== Variant: {variant} ===")
        # Re-seed per variant so weight initialisation, dropout and batch order
        # are the same for every variant and do not depend on which variants
        # ran before this one.
        set_seed(cfg.seed)
        model = ScratchTransformerWithConsistencyHead(len(tokenizer), cfg).to(device)
        optimizer = build_optimizer(model, cfg)
        scheduler = build_scheduler(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
        # NaN rather than None so the rounding below still works with zero epochs.
        final_lm_loss = final_cons_loss = float("nan")
        train_start = time.time()
        for epoch in range(cfg.num_epochs):
            final_lm_loss, final_cons_loss = train_epoch(
                model, train_loader, optimizer, scheduler, device,
                pooling_mode=variant,
                consistency_loss_weight=cfg.consistency_loss_weight,
                grad_clip=cfg.grad_clip,
            )
            print(f"epoch {epoch+1}/{cfg.num_epochs} | lm={final_lm_loss:.4f} | cons={final_cons_loss:.4f}")
        # Stop the clock before evaluation: the column reports training time only.
        train_elapsed = time.time() - train_start

        cls_claim_acc = evaluate_classification(model, eval_loader, device, variant)
        gen_claim_acc = evaluate_generation_accuracy(model, eval_samples, device, cfg.max_seq_len, n_eval_pairs, cfg.model_name)
        cf_cls_swap, cf_cls_orig, cf_gen_swap, cf_gen_orig = evaluate_counterfactual_swap(
            model, eval_dataset, device, variant, n_pairs=n_eval_pairs, seed=cfg.seed, model_name=cfg.model_name
        )
        m_cf_cls_swap, m_cf_cls_orig, _, _ = evaluate_counterfactual_swap(
            model, eval_dataset, device, variant, n_pairs=n_eval_pairs, seed=cfg.seed, model_name=cfg.model_name, matched_claim_only=True
        )
        shuffled_cls_acc, shuffled_gen_acc = evaluate_shuffled_pairing(
            model, eval_dataset, device, variant, n_pairs=n_eval_pairs, seed=cfg.seed, model_name=cfg.model_name
        )
        n_params = sum(p.numel() for p in model.parameters())
        results.append({
            "variant": variant,
            "params": n_params,
            "train_minutes": train_elapsed / 60.0,
            "final_lm_loss": round(final_lm_loss, 4),
            "final_cons_loss": round(final_cons_loss, 4),
            "gen_claim_acc": round(gen_claim_acc, 4),
            "cls_claim_acc": round(cls_claim_acc, 4),
            "cfact_gen_follows_swap": round(cf_gen_swap, 4),
            "cfact_gen_follows_orig": round(cf_gen_orig, 4),
            "cfact_cls_follows_swap": round(cf_cls_swap, 4),
            "cfact_cls_follows_orig": round(cf_cls_orig, 4),
            "matched_cfact_cls_follows_swap": round(m_cf_cls_swap, 4),
            "matched_cfact_cls_follows_orig": round(m_cf_cls_orig, 4),
            "shuffled_gen_acc": round(shuffled_gen_acc, 4),
            "shuffled_cls_acc": round(shuffled_cls_acc, 4),
        })
    df = pd.DataFrame(results)
    df.to_csv(cfg.results_path, index=False)
    return df


def format_results_markdown(df: pd.DataFrame, cfg: ScratchTransformerConfig) -> str:
    """Render the results frame as a Markdown report.

    The report has a setup section taken from ``cfg`` followed by a table with
    a fixed column order. Float cells are printed with four decimals; columns
    missing from ``df`` are rendered as empty cells.
    """
    cols = [
        "variant", "params", "train_minutes", "final_lm_loss", "final_cons_loss", "gen_claim_acc", "cls_claim_acc",
        "cfact_cls_follows_swap", "cfact_cls_follows_orig", "matched_cfact_cls_follows_swap", "matched_cfact_cls_follows_orig",
        "cfact_gen_follows_swap", "cfact_gen_follows_orig", "shuffled_cls_acc", "shuffled_gen_acc",
    ]

    def render(value):
        return f"{value:.4f}" if isinstance(value, float) else str(value)

    header = [
        "# FEVER From-Scratch Transformer Claim-Consistency Results\n",
        "## Setup\n",
        f"- Dataset: `{cfg.dataset_name}`",
        f"- Train samples: {cfg.num_train_samples:,} | Eval samples: {cfg.num_eval_samples:,}",
        f"- Model: d_model={cfg.d_model}, layers={cfg.n_layers}, heads={cfg.n_heads}, d_ff={cfg.d_ff}",
        f"- Epochs: {cfg.num_epochs} | Batch size: {cfg.batch_size} | LR: {cfg.lr}",
        f"- Consistency loss weight: {cfg.consistency_loss_weight}",
        "",
        "## Results\n",
        "| " + " | ".join(cols) + " |",
        "|" + "|".join(["---"] * len(cols)) + "|",
    ]
    table_rows = [
        "| " + " | ".join(render(row.get(col, "")) for col in cols) + " |"
        for _, row in df.iterrows()
    ]
    return "\n".join(header + table_rows)


if __name__ == "__main__":
    # Command-line entry point: parse options, run every requested variant,
    # then write the results as CSV (inside run_experiment) and as Markdown.
    parser = argparse.ArgumentParser(description="FEVER from-scratch transformer claim-consistency experiment")
    parser.add_argument("--train-samples", type=int, default=50_000)
    parser.add_argument("--eval-samples", type=int, default=5_000)
    parser.add_argument("--max-seq-len", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--consistency-loss-weight", type=float, default=0.5)
    parser.add_argument("--warmup-steps", type=int, default=500)
    parser.add_argument("--d-model", type=int, default=256)
    parser.add_argument("--n-layers", type=int, default=4)
    parser.add_argument("--n-heads", type=int, default=8)
    parser.add_argument("--d-ff", type=int, default=1024)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--variants", nargs="+", default=[
        "no_consistency_loss", "evidence_only_pooling", "evidence_only_strict",
        "full_sequence_pooling", "claim_only_pooling", "evidence_only_random_labels"
    ])
    parser.add_argument("--output-csv", type=str, default="results_fever_scratch_transformer.csv")
    parser.add_argument("--require-gpu", action="store_true")
    parser.add_argument("--smoke-test", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # --smoke-test overrides the data size, epoch count, batch size and variant
    # list with small fixed values.
    cfg = ScratchTransformerConfig(
        num_train_samples=64 if args.smoke_test else args.train_samples,
        num_eval_samples=32 if args.smoke_test else args.eval_samples,
        max_seq_len=args.max_seq_len,
        num_epochs=1 if args.smoke_test else args.epochs,
        batch_size=min(8, args.batch_size) if args.smoke_test else args.batch_size,
        lr=args.lr,
        weight_decay=args.weight_decay,
        consistency_loss_weight=args.consistency_loss_weight,
        warmup_steps=args.warmup_steps,
        d_model=args.d_model,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        d_ff=args.d_ff,
        dropout=args.dropout,
        pooling_modes=tuple(["evidence_only_pooling"] if args.smoke_test else args.variants),
        results_path=args.output_csv,
        smoke_test=args.smoke_test,
        require_gpu=args.require_gpu,
        seed=args.seed,
    )

    df = run_experiment(cfg)
    print(df.to_string(index=False))
    # Derive the Markdown path from the file extension only. A plain
    # str.replace('.csv', '.md') would leave a path without a ".csv" suffix
    # unchanged, so the report would overwrite the CSV, and would also rewrite
    # ".csv" appearing in a directory name.
    csv_root, csv_ext = os.path.splitext(args.output_csv)
    md_path = (csv_root if csv_ext.lower() == ".csv" else args.output_csv) + ".md"
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(format_results_markdown(df, cfg))
    print(f"Saved results to {args.output_csv} and {md_path}")
