"""
trainer.py — Training loop for the consistency loss experiment.

Original variants (v1):
  1. consistency_loss        : LM loss + consistency loss on explanation pooling
  2. no_consistency_loss     : LM loss only (no consistency head gradient)
  3. claim_only_pooling      : LM loss + consistency loss on CLAIM token pooling (negative ctrl)
  4. random_label_consistency: LM loss + consistency loss with shuffled labels (negative ctrl)

New stronger ablation variants (v2):
  5. no_claim_to_claim_attention  : like consistency_loss but claim tokens cannot attend other claims
  6. claims_from_explanation_only : claim tokens can ONLY attend explanation tokens (not code)
  7. surface_bottleneck_consistency: consistency via LM logit softmax distributions, not hidden states
  8. surface_bottleneck_no_expl_lm: surface bottleneck + no LM loss on mismatched explanation tokens

Training config:
  Full:  20 epochs, batch_size=32, lr=5e-5
  Smoke: configurable (TrainConfig defaults: 5 epochs, batch_size=8)
"""

import os
import copy
import time
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple
import numpy as np

from dataset import (
    Example, build_dataset, split_dataset, build_tokenizer,
    make_target_sequence, SimpleTokenizer,
    CLAIM_OPEN, CLAIM_CLOSE, SEP_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN
)
from model import (
    ConsistencyTransformer, TransformerConfig,
    compute_lm_loss, compute_lm_loss_masked,
    compute_consistency_loss, total_loss,
    build_explanation_claim_mask,
    build_no_claim_to_claim_mask,
    build_claims_from_explanation_only_mask,
)


# ──────────────────────────────────────────────────────────────────────────────
# 1. Training Config
# ──────────────────────────────────────────────────────────────────────────────

@dataclass
class TrainConfig:
    """
    Hyper-parameters for one experiment run.

    The first group of fields describes the full-size run; the ``smoke_*``
    fields describe a reduced run and only take effect when ``smoke`` is True.
    Use :meth:`effective` to obtain the values that actually apply.
    """
    # ── Full-run settings ────────────────────────────────────────────────────
    n_examples:    int   = 3000
    val_size:      int   = 500
    n_epochs:      int   = 20
    batch_size:    int   = 32
    lr:            float = 5e-5
    lambda_consistency: float = 1.0
    seed:          int   = 42
    checkpoint_every: int = 5
    model_size:    str   = "small"   # "smoke" | "small" | "gpt2_small"
    max_seq_len:   int   = 256
    output_dir:    str   = "outputs"
    # ── Smoke-run overrides (set by CLI) ─────────────────────────────────────
    smoke:         bool  = False
    smoke_n:       int   = 300
    smoke_epochs:  int   = 5
    smoke_batch:   int   = 8
    smoke_model:   str   = "smoke"
    smoke_seq_len: int   = 96

    def effective(self):
        """
        Resolve the size-related settings for this run.

        Returns a dict with the keys ``n_examples``, ``val_size``,
        ``n_epochs``, ``batch_size``, ``model_size`` and ``max_seq_len``.
        In smoke mode the values come from the ``smoke_*`` fields and the
        validation size is a third of ``smoke_n``, capped at 100; otherwise
        the full-run fields are returned unchanged.
        """
        if not self.smoke:
            return {
                "n_examples":  self.n_examples,
                "val_size":    self.val_size,
                "n_epochs":    self.n_epochs,
                "batch_size":  self.batch_size,
                "model_size":  self.model_size,
                "max_seq_len": self.max_seq_len,
            }

        smoke_val_size = min(100, self.smoke_n // 3)
        return {
            "n_examples":  self.smoke_n,
            "val_size":    smoke_val_size,
            "n_epochs":    self.smoke_epochs,
            "batch_size":  self.smoke_batch,
            "model_size":  self.smoke_model,
            "max_seq_len": self.smoke_seq_len,
        }


# ──────────────────────────────────────────────────────────────────────────────
# 2. PyTorch Dataset
# ──────────────────────────────────────────────────────────────────────────────

class CodeExplanationDataset(Dataset):
    """
    Pre-tokenized, fixed-length view over a list of examples.

    Each example is encoded once at construction time. Every item is a dict
    holding ``max_len`` token ids (truncated or right-padded with the
    tokenizer's pad id), two boolean masks of the same length marking
    explanation and claim positions, and the three classification labels.
    """

    def __init__(
        self,
        examples: List[Example],
        tokenizer: SimpleTokenizer,
        max_len: int = 256,
    ):
        self.examples = examples
        self.tok = tokenizer
        self.max_len = max_len
        # Tokenize eagerly so __getitem__ is a plain list lookup.
        self.data = [self._process(ex) for ex in examples]

    def _process(self, ex: Example) -> dict:
        """Encode one example into id / mask / label tensors."""
        token_ids = self.tok.encode(make_target_sequence(ex))

        sep_id = self.tok.sep_id
        open_id = self.tok.claim_open_id
        close_id = self.tok.claim_close_id

        # Index of the first <sep>; falls back to 0 when the token is absent.
        sep_at = token_ids.index(sep_id) if sep_id in token_ids else 0
        # Index of the first claim-open token, or the sequence length if none.
        first_open = next(
            (pos for pos, tid in enumerate(token_ids) if tid == open_id),
            len(token_ids),
        )

        # Explanation = strictly between <sep> and the first claim-open token.
        is_expl = [sep_at < pos < first_open for pos in range(len(token_ids))]

        # Claim = every token from a claim-open through its claim-close,
        # both delimiters included.
        is_claim = []
        inside = False
        for tid in token_ids:
            if tid == open_id:
                inside = True
            is_claim.append(inside)
            if tid == close_id:
                inside = False

        def fit(values, filler):
            # Clip to max_len, then right-pad up to max_len with `filler`.
            clipped = values[:self.max_len]
            return clipped + [filler] * (self.max_len - len(clipped))

        return {
            "input_ids":    torch.tensor(fit(token_ids, self.tok.pad_id), dtype=torch.long),
            "expl_mask":    torch.tensor(fit(is_expl, False),  dtype=torch.bool),
            "claim_mask":   torch.tensor(fit(is_claim, False), dtype=torch.bool),
            "time_label":   torch.tensor(ex.time_complexity_idx,  dtype=torch.long),
            "space_label":  torch.tensor(ex.space_complexity_idx, dtype=torch.long),
            "correct_label": torch.tensor(ex.correctness,          dtype=torch.long),
        }

    def __len__(self):
        """Number of pre-processed items."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the pre-processed tensor dict at position ``idx``."""
        return self.data[idx]


# ──────────────────────────────────────────────────────────────────────────────
# 3. Variant specification
# ──────────────────────────────────────────────────────────────────────────────

# Baseline (v1) variants — left untouched so newer runs remain comparable.
VARIANTS_V1 = [
    "consistency_loss",
    "no_consistency_loss",
    "claim_only_pooling",
    "random_label_consistency",
]

# Stronger ablations added in v2.
VARIANTS_V2 = [
    "no_claim_to_claim_attention",
    "claims_from_explanation_only",
    "surface_bottleneck_consistency",
    "surface_bottleneck_no_expl_lm",
]

# Default run list (v1 first, then v2) when no --variants flag is passed.
VARIANTS = [*VARIANTS_V1, *VARIANTS_V2]

# Variants whose model is built with surface bottleneck heads
# (SurfaceBottleneckHeads).
SURFACE_VARIANTS = {
    "surface_bottleneck_consistency",
    "surface_bottleneck_no_expl_lm",
}

# Variants that switch off the LM loss on explanation token positions.
NO_EXPL_LM_VARIANTS = {
    "surface_bottleneck_no_expl_lm",
}


def get_variant_flags(variant: str) -> dict:
    """
    Translate a variant name into the switches the training loop reads.

    Returned keys:
      use_consistency   : False only for "no_consistency_loss"
      pool_claims       : True only for "claim_only_pooling" (negative ctrl)
      random_labels     : True only for "random_label_consistency" (negative ctrl)
      use_surface_heads : True for the variants in SURFACE_VARIANTS
      mask_type         : which attention mask to build
                            "no_c2c"           — for "no_claim_to_claim_attention"
                            "claims_expl_only" — for "claims_from_explanation_only"
                            "base"             — for every other name
      mask_expl_lm_loss : True for the variants in NO_EXPL_LM_VARIANTS

    The name is not validated: an unrecognised variant yields the same flags
    as "consistency_loss".
    """
    if variant == "no_claim_to_claim_attention":
        mask_type = "no_c2c"
    elif variant == "claims_from_explanation_only":
        mask_type = "claims_expl_only"
    else:
        mask_type = "base"

    flags = {}
    flags["use_consistency"] = variant != "no_consistency_loss"
    flags["pool_claims"] = variant == "claim_only_pooling"
    flags["random_labels"] = variant == "random_label_consistency"
    flags["use_surface_heads"] = variant in SURFACE_VARIANTS
    flags["mask_type"] = mask_type
    flags["mask_expl_lm_loss"] = variant in NO_EXPL_LM_VARIANTS
    return flags


# ──────────────────────────────────────────────────────────────────────────────
# 4. Single-epoch train / eval
# ──────────────────────────────────────────────────────────────────────────────

def run_epoch(
    model: ConsistencyTransformer,
    loader: DataLoader,
    optimizer: Optional[torch.optim.Optimizer],
    cfg: TrainConfig,
    variant_flags: dict,
    device: torch.device,
    train: bool = True,
    rng: Optional[random.Random] = None,
) -> dict:
    """
    Run one pass over ``loader`` and return averaged losses and accuracies.

    Args:
        model: network returning ``(lm_logits, time_logits, space_logits,
            correctness_logits)``.
        loader: yields batches shaped like ``CodeExplanationDataset`` items.
        optimizer: required when ``train`` is True; ignored otherwise.
        cfg: supplies ``lambda_consistency`` (and ``pad_id`` if it has one).
        variant_flags: output of ``get_variant_flags``.
        device: device the batch tensors are moved to.
        train: True → gradients + optimizer steps; False → evaluation under
            ``torch.no_grad()``.
        rng: accepted for interface compatibility; currently unused — the
            label shuffle draws from torch's global RNG.

    Returns:
        Dict with per-batch means ``lm_loss``, ``cons_loss``, ``total_loss``;
        example-level ``time_acc``, ``space_acc``, ``correct_acc``; and
        ``coupling_strength`` (the mean of the three accuracies).

    Raises:
        ValueError: if ``train`` is True and no optimizer was given.
    """
    if train and optimizer is None:
        # Fail before any forward pass instead of with an AttributeError
        # on the first optimizer call.
        raise ValueError("run_epoch(train=True) requires an optimizer")

    model.train(train)
    total_lm   = 0.0
    total_cons = 0.0
    total_main = 0.0
    n_batches  = 0

    # Accumulate predictions for classifier accuracy
    all_time_preds, all_time_labels   = [], []
    all_space_preds, all_space_labels = [], []
    all_corr_preds, all_corr_labels   = [], []

    ctx = torch.enable_grad() if train else torch.no_grad()

    use_consistency   = variant_flags["use_consistency"]
    pool_claims       = variant_flags["pool_claims"]
    random_labels     = variant_flags["random_labels"]
    mask_type         = variant_flags.get("mask_type", "base")
    mask_expl_lm_loss = variant_flags.get("mask_expl_lm_loss", False)

    # Resolve the padding id handed to the LM loss. TrainConfig does not
    # define pad_id, so looking only at cfg always produced 0 even when the
    # tokenizer pads with a different id. Prefer cfg.pad_id when present,
    # then the model's own config (presumably built with the tokenizer's
    # pad id — verify against the model module), and only then fall back to 0.
    pad_id = getattr(cfg, "pad_id", None)
    if pad_id is None:
        pad_id = getattr(getattr(model, "cfg", None), "pad_id", None)
    if pad_id is None:
        pad_id = 0

    with ctx:
        for batch in loader:
            input_ids    = batch["input_ids"].to(device)
            expl_mask    = batch["expl_mask"].to(device)
            claim_mask   = batch["claim_mask"].to(device)
            time_labels  = batch["time_label"].to(device)
            space_labels = batch["space_label"].to(device)
            corr_labels  = batch["correct_label"].to(device)

            # Random label shuffle (negative control): one permutation is
            # applied to all three label tensors, training only.
            if random_labels and train:
                B = time_labels.shape[0]
                perm = torch.randperm(B, device=device)
                time_labels  = time_labels[perm]
                space_labels = space_labels[perm]
                corr_labels  = corr_labels[perm]

            # ── Build attention mask based on variant ────────────────────────────
            if mask_type == "no_c2c":
                # Base expl→claim block + claim→claim block
                base_mask = build_explanation_claim_mask(expl_mask, claim_mask)
                c2c_mask  = build_no_claim_to_claim_mask(claim_mask)
                extra_mask = base_mask + c2c_mask
            elif mask_type == "claims_expl_only":
                # Strict: claims only attend explanation tokens
                extra_mask = build_claims_from_explanation_only_mask(expl_mask, claim_mask)
            else:
                # Original: block expl→claim (and causal handles the rest)
                extra_mask = build_explanation_claim_mask(expl_mask, claim_mask)

            lm_logits, t_logits, s_logits, c_logits = model(
                input_ids,
                explanation_mask=expl_mask,
                claim_mask=claim_mask,
                pool_claims=pool_claims,
                extra_attn_mask=extra_mask,
            )

            # ── LM loss (optionally masked) ────────────────────────────────────
            if mask_expl_lm_loss:
                # Disable LM loss at explanation positions (they are mismatched)
                # Keep LM loss at: code, SEP, claim, EOS, PAD (pad handled by ignore_index)
                lm_active_mask = ~expl_mask  # (B, T): True where we DO compute LM loss
                lm_l = compute_lm_loss_masked(
                    lm_logits, input_ids, pad_id, lm_loss_mask=lm_active_mask
                )
            else:
                lm_l = compute_lm_loss(lm_logits, input_ids, pad_id)

            # ── Consistency loss ─────────────────────────────────────────────
            if use_consistency:
                cons_l = compute_consistency_loss(
                    t_logits, s_logits, c_logits,
                    time_labels, space_labels, corr_labels
                )
                main_l = total_loss(lm_l, cons_l, cfg.lambda_consistency)
            else:
                # Report a zero consistency loss; only the LM loss is optimized.
                cons_l = torch.tensor(0.0, device=device)
                main_l = lm_l

            if train:
                optimizer.zero_grad()
                main_l.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            total_lm   += lm_l.item()
            total_cons += cons_l.item()
            total_main += main_l.item()
            n_batches  += 1

            # Classifier accuracy (bookkeeping only, never part of the graph)
            with torch.no_grad():
                all_time_preds.extend(t_logits.argmax(-1).cpu().tolist())
                all_time_labels.extend(time_labels.cpu().tolist())
                all_space_preds.extend(s_logits.argmax(-1).cpu().tolist())
                all_space_labels.extend(space_labels.cpu().tolist())
                all_corr_preds.extend(c_logits.argmax(-1).cpu().tolist())
                all_corr_labels.extend(corr_labels.cpu().tolist())

    def accuracy(preds, labels):
        if not preds:
            return 0.0
        return sum(p == l for p, l in zip(preds, labels)) / len(preds)

    # Compute each accuracy once and reuse it for coupling_strength.
    time_acc  = accuracy(all_time_preds, all_time_labels)
    space_acc = accuracy(all_space_preds, all_space_labels)
    corr_acc  = accuracy(all_corr_preds, all_corr_labels)
    denom = max(n_batches, 1)  # avoid dividing by zero on an empty loader

    return {
        "lm_loss":      total_lm / denom,
        "cons_loss":    total_cons / denom,
        "total_loss":   total_main / denom,
        "time_acc":     time_acc,
        "space_acc":    space_acc,
        "correct_acc":  corr_acc,
        "coupling_strength": (time_acc + space_acc + corr_acc) / 3.0,
    }


# ──────────────────────────────────────────────────────────────────────────────
# 5. BLEU / ROUGE metrics (explanation correctness)
# ──────────────────────────────────────────────────────────────────────────────

def simple_bleu1(hypothesis: str, reference: str) -> float:
    """
    Unigram BLEU proxy: clipped unigram precision times a brevity penalty.

    Both strings are lower-cased and split on whitespace. Each hypothesis
    token is credited at most as many times as it occurs in the reference
    (BLEU's "modified" precision), so a degenerate output that repeats one
    reference word no longer scores 1.0. The brevity penalty is the simple
    ratio ``min(1, len(hyp) / len(ref))`` rather than BLEU's exponential form.

    Returns 0.0 for an empty hypothesis.
    """
    from collections import Counter

    hyp_tokens = hypothesis.lower().split()
    if not hyp_tokens:
        return 0.0
    ref_tokens = reference.lower().split()

    ref_counts = Counter(ref_tokens)
    # Clip every hypothesis token count by its count in the reference.
    matches = sum(
        min(count, ref_counts[token])
        for token, count in Counter(hyp_tokens).items()
    )
    precision = matches / len(hyp_tokens)
    # brevity penalty (max(..., 1) guards against an empty reference)
    bp = min(1.0, len(hyp_tokens) / max(len(ref_tokens), 1))
    return precision * bp


def rouge_l_score(hypothesis: str, reference: str) -> float:
    """
    Simplified ROUGE-L: F1 of the longest common subsequence (LCS).

    Both strings are lower-cased and split on whitespace. Precision is
    LCS / len(hypothesis), recall is LCS / len(reference). Returns 0.0 when
    either side has no tokens or when the two share no token.
    """
    hyp_words = hypothesis.lower().split()
    ref_words = reference.lower().split()
    if not (hyp_words and ref_words):
        return 0.0

    # LCS length via dynamic programming, keeping only the previous row
    # (two rows of len(ref_words) + 1 cells are alive at any time).
    width = len(ref_words)
    above = [0] * (width + 1)
    for hyp_word in hyp_words:
        row = [0]
        for col, ref_word in enumerate(ref_words, start=1):
            if hyp_word == ref_word:
                row.append(above[col - 1] + 1)
            else:
                row.append(max(above[col], row[col - 1]))
        above = row
    lcs = above[width]

    precision = lcs / len(hyp_words)
    recall    = lcs / len(ref_words)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# ──────────────────────────────────────────────────────────────────────────────
# 6. Generation (greedy decode)
# ──────────────────────────────────────────────────────────────────────────────

@torch.no_grad()
def greedy_decode(
    model: ConsistencyTransformer,
    tokenizer: SimpleTokenizer,
    code_snippet: str,
    max_new_tokens: int = 60,
    device: torch.device = torch.device("cpu"),
) -> str:
    """
    Continue ``<bos> code <sep>`` greedily and return the decoded continuation.

    The model is put in eval mode (and left there). At each step the whole
    sequence so far is fed through the model and the arg-max token at the
    last position is appended. Decoding stops on the EOS token, after
    ``max_new_tokens`` tokens, or once the sequence reaches
    ``model.cfg.max_seq_len``. Only the newly generated ids — without the
    prompt and without EOS — are decoded.
    """
    model.eval()
    prompt_ids = tokenizer.encode(f"{BOS_TOKEN} {code_snippet.strip()} {SEP_TOKEN}")
    context = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)

    new_ids = []
    length_cap = model.cfg.max_seq_len

    steps_taken = 0
    while steps_taken < max_new_tokens and context.shape[1] < length_cap:
        steps_taken += 1
        lm_logits, _, _, _ = model(context)
        best_id = lm_logits[0, -1, :].argmax(-1).item()
        if best_id == tokenizer.eos_id:
            break
        new_ids.append(best_id)
        next_column = torch.tensor([[best_id]], device=device)
        context = torch.cat([context, next_column], dim=1)

    return tokenizer.decode(new_ids)


# ──────────────────────────────────────────────────────────────────────────────
# 7. Counterfactual swap metric
# ──────────────────────────────────────────────────────────────────────────────

@torch.no_grad()
def counterfactual_swap_influence(
    model: ConsistencyTransformer,
    examples: List[Example],
    tokenizer: SimpleTokenizer,
    dataset: CodeExplanationDataset,
    n_pairs: int = 50,
    device: torch.device = torch.device("cpu"),
) -> float:
    """
    Simplified proxy for counterfactual swap influence on the time head.

    Up to ``n_pairs`` index pairs (A, B) with different time-complexity
    labels are sampled with a fixed seed. Only A is run through the model;
    B contributes its label alone. For each pair the time prediction for A
    is compared with A's own label and with B's label.

    Args:
        model: network whose second output is the time-complexity logits.
        examples: source of the ground-truth ``time_complexity_idx`` labels.
        tokenizer: unused; kept for interface compatibility.
        dataset: must be index-aligned with ``examples``.
        n_pairs: maximum number of pairs to evaluate.
        device: device the inputs are moved to.

    Returns:
        ``(#pred == label_A - #pred == label_B) / #pairs``, a value in
        [-1, 1]. Returns 0.0 — the neutral value of that difference — when
        no pair with differing labels could be sampled.
    """
    model.eval()

    # Rejection-sample index pairs with different time-complexity labels;
    # the attempt budget keeps this finite when labels are (nearly) uniform.
    n = len(examples)
    rng = random.Random(99)
    pairs = []
    attempts = 0
    while len(pairs) < n_pairs and attempts < n * 10:
        i = rng.randint(0, n - 1)
        j = rng.randint(0, n - 1)
        if i != j and examples[i].time_complexity_idx != examples[j].time_complexity_idx:
            pairs.append((i, j))
        attempts += 1

    if not pairs:
        # No evidence either way: report "no influence" rather than an
        # arbitrary mid-scale number that would read as a real measurement.
        return 0.0

    correct_own = 0
    correct_swapped = 0

    for i_idx, j_idx in pairs:
        item_i = dataset[i_idx]
        input_ids  = item_i["input_ids"].unsqueeze(0).to(device)
        expl_mask  = item_i["expl_mask"].unsqueeze(0).to(device)
        claim_mask = item_i["claim_mask"].unsqueeze(0).to(device)
        extra_mask = build_explanation_claim_mask(expl_mask, claim_mask)

        _, t_logits, _, _ = model(input_ids, expl_mask, claim_mask, extra_attn_mask=extra_mask)
        pred = t_logits.argmax(-1).item()

        if pred == examples[i_idx].time_complexity_idx:
            correct_own += 1
        if pred == examples[j_idx].time_complexity_idx:
            correct_swapped += 1

    # How much more often the prediction matches A's own label than B's.
    return float((correct_own - correct_swapped) / len(pairs))


# ──────────────────────────────────────────────────────────────────────────────
# 8. Claim emission accuracy
# ──────────────────────────────────────────────────────────────────────────────

@torch.no_grad()
def claim_emission_accuracy(
    model: ConsistencyTransformer,
    dataset: CodeExplanationDataset,
    tokenizer: SimpleTokenizer,
    device: torch.device,
    n_samples: int = 50,
) -> float:
    """
    Fraction of ground-truth claim strings found in greedy generations.

    For the first ``min(n_samples, len(dataset))`` examples, the code snippet
    is decoded greedily (up to 80 new tokens) and the output is searched for
    the three substrings ``time_complexity=…``, ``space_complexity=…`` and
    ``correctness=…`` built from the example's own fields. Returns
    hits / checks, or 0.0 when nothing was checked.
    """
    model.eval()
    sample_count = min(n_samples, len(dataset))
    hits = 0
    checks = 0

    for position in range(sample_count):
        ex = dataset.examples[position]
        generated = greedy_decode(model, tokenizer, ex.code_snippet,
                                  max_new_tokens=80, device=device)
        # The claim strings the generation is expected to contain verbatim.
        expected_claims = (
            f"time_complexity={ex.time_complexity}",
            f"space_complexity={ex.space_complexity}",
            f"correctness={ex.correctness}",
        )
        hits += sum(1 for claim in expected_claims if claim in generated)
        checks += len(expected_claims)

    return hits / max(checks, 1)


# ──────────────────────────────────────────────────────────────────────────────
# 9. Full validation pass
# ──────────────────────────────────────────────────────────────────────────────

def validate(
    model: ConsistencyTransformer,
    val_dataset: CodeExplanationDataset,
    val_loader: DataLoader,
    cfg: TrainConfig,
    variant_flags: dict,
    device: torch.device,
    n_gen_samples: int = 30,
    n_swap_pairs: int = 30,
) -> dict:
    """
    Full validation pass.

    Starts from the loss / accuracy dict of a non-training ``run_epoch`` and
    adds four keys: ``bleu1`` and ``rouge_l`` (means over greedy generations
    for the first ``n_gen_samples`` examples, scored against
    ``true_explanation``), ``swap_influence`` and ``claim_accuracy``.
    """
    # Classifier + LM metrics
    metrics = run_epoch(
        model, val_loader, optimizer=None, cfg=cfg,
        variant_flags=variant_flags, device=device, train=False
    )

    tokenizer, examples = val_dataset.tok, val_dataset.examples
    n_gen = min(n_gen_samples, len(examples))

    # Explanation quality on a prefix of the validation set: one
    # (BLEU-1, ROUGE-L) pair per generated sample.
    scored = []
    for position in range(n_gen):
        ex = examples[position]
        text = greedy_decode(model, tokenizer, ex.code_snippet,
                             max_new_tokens=60, device=device)
        scored.append((
            simple_bleu1(text, ex.true_explanation),
            rouge_l_score(text, ex.true_explanation),
        ))

    if scored:
        metrics["bleu1"]   = float(np.mean([bleu for bleu, _ in scored]))
        metrics["rouge_l"] = float(np.mean([rouge for _, rouge in scored]))
    else:
        metrics["bleu1"]   = 0.0
        metrics["rouge_l"] = 0.0

    # Counterfactual swap influence
    metrics["swap_influence"] = counterfactual_swap_influence(
        model, examples, tokenizer, val_dataset,
        n_pairs=n_swap_pairs, device=device
    )

    # Claim emission accuracy (same sample budget as the generation metrics)
    metrics["claim_accuracy"] = claim_emission_accuracy(
        model, val_dataset, tokenizer, device, n_samples=n_gen
    )

    return metrics


# ──────────────────────────────────────────────────────────────────────────────
# 10. Main training loop for one variant
# ──────────────────────────────────────────────────────────────────────────────

def train_variant(
    variant: str,
    train_examples: List[Example],
    val_examples:   List[Example],
    tokenizer: SimpleTokenizer,
    cfg: TrainConfig,
    device: torch.device,
    eff: dict,
) -> List[dict]:
    """
    Train one variant from scratch and return its per-epoch metric records.

    Args:
        variant: variant name, interpreted by ``get_variant_flags``.
        train_examples / val_examples: the two data splits.
        tokenizer: supplies vocab size and pad id for the model config.
        cfg: supplies ``seed``, ``lr``, ``output_dir``, ``checkpoint_every``
            (and whatever ``run_epoch`` reads).
        device: device the model is placed on.
        eff: resolved size settings, as returned by ``TrainConfig.effective``.

    Returns:
        One dict per epoch with ``variant``, ``epoch``, ``elapsed`` and every
        train / validation metric prefixed with ``train_`` / ``val_``.

    Side effects:
        Seeds the torch, numpy and ``random`` global RNGs, prints progress,
        and writes checkpoints under ``<output_dir>/checkpoints/<variant>/``
        every ``cfg.checkpoint_every`` epochs and after the final epoch.
        A ``checkpoint_every`` of 0 (or None) keeps only the final checkpoint.
    """

    print(f"\n{'='*60}")
    print(f"  Training variant: {variant}")
    print(f"  Epochs: {eff['n_epochs']}  Batch: {eff['batch_size']}  Model: {eff['model_size']}")
    print(f"{'='*60}")

    # Same seed for every variant so runs differ only by the variant itself.
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    variant_flags = get_variant_flags(variant)

    # Build datasets
    train_ds = CodeExplanationDataset(train_examples, tokenizer, eff["max_seq_len"])
    val_ds   = CodeExplanationDataset(val_examples,   tokenizer, eff["max_seq_len"])

    train_loader = DataLoader(train_ds, batch_size=eff["batch_size"], shuffle=True,  drop_last=False)
    val_loader   = DataLoader(val_ds,   batch_size=eff["batch_size"], shuffle=False, drop_last=False)

    # Build model
    model_size = eff["model_size"]
    if model_size == "small":
        model_cfg = TransformerConfig.small(vocab_size=tokenizer.vocab_size, pad_id=tokenizer.pad_id)
    elif model_size == "gpt2_small":
        model_cfg = TransformerConfig.gpt2_small(vocab_size=tokenizer.vocab_size, pad_id=tokenizer.pad_id)
    else:
        if model_size != "smoke":
            # Keep the historical fallback, but make it visible in the log.
            print(f"  [Warning: unknown model_size {model_size!r}; falling back to 'smoke']")
        model_cfg = TransformerConfig.smoke(vocab_size=tokenizer.vocab_size, pad_id=tokenizer.pad_id)

    model_cfg.max_seq_len = eff["max_seq_len"]
    use_surface = variant_flags.get("use_surface_heads", False)
    model = ConsistencyTransformer(model_cfg, use_surface_heads=use_surface).to(device)
    print(f"  Parameters: {model.num_parameters():,}")
    if use_surface:
        print("  [Surface bottleneck: consistency via LM logit distributions]")
    if variant_flags.get("mask_expl_lm_loss"):
        print("  [LM loss masked: explanation token positions excluded]")
    if variant_flags.get("mask_type") == "no_c2c":
        print("  [Attention mask: claim→claim attention blocked]")
    elif variant_flags.get("mask_type") == "claims_expl_only":
        print("  [Attention mask: claims attend explanation tokens only]")

    optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=0.01)

    # Checkpoint dir
    ckpt_dir = os.path.join(cfg.output_dir, "checkpoints", variant)
    os.makedirs(ckpt_dir, exist_ok=True)

    epoch_records = []
    n_epochs = eff["n_epochs"]

    for epoch in range(1, n_epochs + 1):
        t0 = time.time()

        train_metrics = run_epoch(
            model, train_loader, optimizer, cfg,
            variant_flags, device, train=True
        )
        val_metrics = validate(
            model, val_ds, val_loader, cfg,
            variant_flags, device,
            n_gen_samples=min(20, len(val_examples)),
            n_swap_pairs=min(20, len(val_examples) // 2),
        )

        # Elapsed time covers both the training pass and the validation pass.
        elapsed = time.time() - t0
        record = {
            "variant": variant,
            "epoch":   epoch,
            "elapsed": elapsed,
            **{f"train_{k}": v for k, v in train_metrics.items()},
            **{f"val_{k}":   v for k, v in val_metrics.items()},
        }
        epoch_records.append(record)

        print(
            f"  Epoch {epoch:02d}/{n_epochs} | "
            f"train_loss={train_metrics['total_loss']:.4f} | "
            f"val_coupling={val_metrics['coupling_strength']:.3f} | "
            f"val_bleu={val_metrics['bleu1']:.3f} | "
            f"val_swap={val_metrics['swap_influence']:.3f} | "
            f"val_claim_acc={val_metrics['claim_accuracy']:.3f} | "
            f"{elapsed:.1f}s"
        )

        # Save a checkpoint every checkpoint_every epochs and always after the
        # last one. The truthiness test avoids a ZeroDivisionError when the
        # interval is 0 (or None), which then means "final checkpoint only".
        is_periodic = bool(cfg.checkpoint_every) and epoch % cfg.checkpoint_every == 0
        if is_periodic or epoch == n_epochs:
            ckpt_path = os.path.join(ckpt_dir, f"epoch_{epoch:03d}.pt")
            torch.save({
                "epoch":           epoch,
                "model_state":     model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "metrics":         record,
                "model_cfg":       model_cfg,
                "use_surface_heads": use_surface,
            }, ckpt_path)
            print(f"  Saved checkpoint: {ckpt_path}")

    return epoch_records


# ──────────────────────────────────────────────────────────────────────────────
# 11. Qualitative examples collection
# ──────────────────────────────────────────────────────────────────────────────

def collect_qualitative_examples(
    epoch_1_gens: List[dict],
    epoch_final_gens: List[dict],
) -> List[dict]:
    """
    Pair up epoch-1 and final-epoch generations of the same examples.

    Records are matched on ``"idx"``; epoch-1 records without a final-epoch
    counterpart are dropped. Each output dict has the keys idx, code,
    true_expl, mismatched_expl, epoch_1_gen, epoch_final_gen,
    time_complexity, space_complexity and correctness — everything except
    ``epoch_final_gen`` is taken from the epoch-1 record. Output follows the
    order of ``epoch_1_gens`` and is capped at 15 entries.
    """
    finals = {entry["idx"]: entry for entry in epoch_final_gens}

    merged = [
        {
            "idx":             first["idx"],
            "code":            first["code"],
            "true_expl":       first["true_expl"],
            "mismatched_expl": first["mismatched_expl"],
            "epoch_1_gen":     first["gen"],
            "epoch_final_gen": finals[first["idx"]]["gen"],
            "time_complexity": first["time_complexity"],
            "space_complexity": first["space_complexity"],
            "correctness":     first["correctness"],
        }
        for first in epoch_1_gens
        if first["idx"] in finals
    ]
    return merged[:15]


def generate_epoch_samples(
    model: ConsistencyTransformer,
    examples: List[Example],
    tokenizer: SimpleTokenizer,
    device: torch.device,
    n: int = 15,
) -> List[dict]:
    """
    Greedily decode the first ``n`` examples and bundle each generation with
    its example's reference fields.

    Each returned dict has the keys idx, code (snippet cut to 200
    characters), true_expl, mismatched_expl, gen, time_complexity,
    space_complexity and correctness.
    """
    def describe(ex):
        # Decode first, then attach the example's metadata.
        generated = greedy_decode(model, tokenizer, ex.code_snippet,
                                  max_new_tokens=60, device=device)
        return {
            "idx":             ex.idx,
            "code":            ex.code_snippet[:200],
            "true_expl":       ex.true_explanation,
            "mismatched_expl": ex.mismatched_explanation,
            "gen":             generated,
            "time_complexity": ex.time_complexity,
            "space_complexity": ex.space_complexity,
            "correctness":     ex.correctness,
        }

    return [describe(ex) for ex in examples[:n]]


if __name__ == "__main__":
    # Minimal end-to-end smoke run: 100 examples, one variant, 2 epochs on CPU.
    from dataset import build_dataset, split_dataset, build_tokenizer

    all_examples = build_dataset(n=100, seed=42)
    train_split, val_split = split_dataset(all_examples, val_size=20)
    # The tokenizer is built from every example, not just the training split.
    smoke_tokenizer = build_tokenizer(all_examples)
    print(f"Vocab size: {smoke_tokenizer.vocab_size}")

    smoke_cfg = TrainConfig(smoke=True, smoke_n=100, smoke_epochs=2, smoke_batch=8)
    history = train_variant(
        "consistency_loss",
        train_split,
        val_split,
        smoke_tokenizer,
        smoke_cfg,
        torch.device("cpu"),
        smoke_cfg.effective(),
    )
    print(f"Done. {len(history)} epoch records.")
