"""
claim_consistency_experiment.py
================================
Minimal reproducible experiment for validating claim-consistency coupling
in a small decoder-only transformer.

Structure
---------
1. Config dataclass
2. Synthetic dataset generator
3. Small GPT-2-style decoder-only Transformer
4. Training loop (4 variants)
5. Evaluation: claim accuracy, counterfactual swap, shuffled-pairing control
6. Results table

Run as a script for a quick smoke test:
    python claim_consistency_experiment.py
"""

from __future__ import annotations

import copy
import math
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# ---------------------------------------------------------------------------
# 1. Config
# ---------------------------------------------------------------------------

@dataclass
class ExperimentConfig:
    """Hyper-parameters for the claim-consistency experiment.

    One instance is passed unchanged to the vocabulary/dataset builders, the
    model constructor, the training loop and the evaluation helpers.
    """

    # Dataset
    num_latent_states: int = 8          # _build_vocabulary rejects values above 8
    num_rationale_templates: int = 4    # rationale templates generated per latent state
    vocab_size: int = 128               # rows of the token embedding / width of the LM head
    prompt_len: int = 4                 # random prompt tokens placed before the rationale
    rationale_len: int = 8              # tokens in each rationale template
    claim_len: int = 2                  # tokens generated at eval; the vocabulary assigns 2 claim tokens per state
    num_train_samples: int = 512
    num_eval_samples: int = 128
    num_shuffled_samples: int = 128     # size of the shuffled-pairing control set
    seed: int = 42                      # base seed for dataset generation and training

    # Model
    d_model: int = 64                   # hidden width; must be divisible by n_heads
    n_heads: int = 4
    n_layers: int = 2
    d_ff: int = 128                     # hidden width of the feed-forward sub-layer
    dropout: float = 0.1
    max_seq_len: int = 64               # rows of the position-embedding table

    # Training
    batch_size: int = 32
    num_epochs: int = 5                 # small for smoke test
    lr: float = 3e-4
    consistency_loss_weight: float = 0.5  # multiplier on the consistency loss in the total loss
    # One model is trained per entry, in this order, by run_experiment.
    pooling_modes: Tuple[str, ...] = (
        "no_consistency_loss",
        "rationale_only",
        "full_sequence",
        "earlier_token_only",
    )

    # Misc
    device: str = "cpu"
    results_path: str = "results_comparison.csv"  # CSV written by run_experiment


# ---------------------------------------------------------------------------
# 2. Synthetic dataset
# ---------------------------------------------------------------------------

# Token-id layout of the tiny synthetic vocabulary:
#   0          = PAD
#   1          = BOS
#   2..9       = prompt tokens (generic)
#   10..79     = rationale tokens (one disjoint pool per latent state, at most
#                10 ids each; pools shrink so that every pool ends below 80)
#   80..95     = claim label tokens (2 per latent state, up to 8 latent states)
#   96..127    = misc / separator tokens

_SEP_TOKEN = 96     # separator between prompt, rationale, claim sections
_BOS_TOKEN = 1
_PAD_TOKEN = 0


def _build_vocabulary(cfg: ExperimentConfig):
    """Return the rationale templates and claim tokens of every latent state.

    Each latent state ``s`` owns a contiguous pool of rationale token ids and
    two claim token ids.  Rationale templates are sampled (with replacement)
    from the state's pool using a generator seeded with ``s``, so the result
    depends only on ``cfg`` and not on any global random state.

    The pools of different states never overlap, and no rationale id ever
    falls into the claim range: the per-state pool holds 10 ids when they all
    fit below the claim range and is narrowed otherwise (8 ids for 8 states).

    Args:
        cfg: reads ``num_latent_states``, ``num_rationale_templates`` and
            ``rationale_len``.

    Returns:
        ``{state: {"rationale_templates": list of token-id lists,
        "claim_tokens": list of 2 token ids}}``.

    Raises:
        ValueError: if the claim tokens of ``cfg.num_latent_states`` states do
            not fit between the claim base id and the separator token.
    """
    rationale_base = 10
    claim_base = 80
    max_rationale_pool = 10          # preferred pool size per latent state
    tokens_per_state_claim = 2       # deterministic claim tokens
    n_states = cfg.num_latent_states

    # Claim ids must stay below the separator token.
    max_states = (_SEP_TOKEN - claim_base) // tokens_per_state_claim
    if n_states > max_states:
        raise ValueError(f"Extend vocab if > {max_states} latent states")

    # Rationale pools must stay below the claim range; 10 ids per state only
    # fit for up to 7 states, so shrink the pool instead of spilling over.
    pool_size = min(
        max_rationale_pool,
        (claim_base - rationale_base) // max(n_states, 1),
    )

    vocab = {}
    for s in range(n_states):
        pool_start = rationale_base + s * pool_size
        rationale_tokens = list(range(pool_start, pool_start + pool_size))
        claim_start = claim_base + s * tokens_per_state_claim
        claim_tokens = list(range(claim_start, claim_start + tokens_per_state_claim))

        rng = np.random.default_rng(s)  # deterministic per state
        templates = [
            rng.choice(rationale_tokens, size=cfg.rationale_len, replace=True).tolist()
            for _ in range(cfg.num_rationale_templates)
        ]
        vocab[s] = {
            "rationale_templates": templates,
            "claim_tokens": claim_tokens,
        }
    return vocab


def _make_prompt_tokens(cfg: ExperimentConfig, rng: np.random.Generator) -> List[int]:
    """Draw ``cfg.prompt_len`` prompt token ids uniformly from 2..9 using ``rng``."""
    drawn = rng.integers(low=2, high=10, size=cfg.prompt_len)  # high is exclusive
    return [int(token) for token in drawn]


def make_sample(
    cfg: ExperimentConfig,
    vocab: dict,
    latent_state: int,
    rng: np.random.Generator,
    shuffled: bool = False,
    swap_rationale_from: Optional[int] = None,
) -> dict:
    """
    Build one synthetic sample.

    Layout: [BOS] [prompt...] [SEP] [rationale...] [SEP] [claim...]

    The claim tokens always come from ``latent_state``.  The rationale
    template comes from (first match wins):
        1. ``swap_rationale_from`` when it is not None,
        2. a uniformly chosen latent state *different from* ``latent_state``
           when ``shuffled`` is True,
        3. ``latent_state`` itself otherwise.

    Args:
        cfg: reads ``num_latent_states``, ``num_rationale_templates``,
            ``prompt_len`` and ``rationale_len``.
        vocab: mapping ``state -> {"rationale_templates", "claim_tokens"}``.
        latent_state: state that determines the claim tokens and the label.
        rng: generator used for every random choice (advanced in place).
        shuffled: pair the claim with another state's rationale.
        swap_rationale_from: explicit state to take the rationale from.

    Returns a dict with:
        token_ids        : (seq_len,) int64
        targets          : (seq_len,) int64  token_ids shifted left by one,
                                             last position filled with PAD
        rationale_mask   : (seq_len,) bool   True at rationale positions
        full_seq_mask    : (seq_len,) bool   True at every position
        earlier_tok_mask : (seq_len,) bool   True at prompt, first SEP and
                                             rationale positions
        latent_state     : int
        claim_label      : int  (the latent-state index, used as class label)
        is_shuffled      : int  (0 or 1, taken from ``shuffled``)

    Raises:
        ValueError: if ``shuffled`` is requested with fewer than two latent
            states, because no mismatching rationale exists.
    """
    # Decide which latent state supplies the rationale.
    if swap_rationale_from is not None:
        rationale_state = swap_rationale_from
    elif shuffled:
        n_states = cfg.num_latent_states
        if n_states < 2:
            raise ValueError("shuffled samples need at least 2 latent states")
        # A non-zero offset guarantees a state different from latent_state
        # while staying uniform over the remaining states.
        offset = int(rng.integers(1, n_states))
        rationale_state = (latent_state + offset) % n_states
    else:
        rationale_state = latent_state

    rat_tmpl_idx = int(rng.integers(0, cfg.num_rationale_templates))
    rationale_tokens = vocab[rationale_state]["rationale_templates"][rat_tmpl_idx]

    claim_tokens = vocab[latent_state]["claim_tokens"]
    prompt_tokens = _make_prompt_tokens(cfg, rng)

    # full sequence
    seq = (
        [_BOS_TOKEN]
        + prompt_tokens
        + [_SEP_TOKEN]
        + rationale_tokens
        + [_SEP_TOKEN]
        + claim_tokens
    )
    seq_len = len(seq)

    token_ids = torch.tensor(seq, dtype=torch.long)
    # next-token targets: shift by 1, last target = PAD
    targets = torch.cat([token_ids[1:], torch.tensor([_PAD_TOKEN], dtype=torch.long)])

    # Section boundaries (end indices are exclusive).
    prompt_start = 1                         # after BOS
    prompt_end = prompt_start + cfg.prompt_len
    rat_start = prompt_end + 1               # after first SEP
    rat_end = rat_start + cfg.rationale_len

    rationale_mask = torch.zeros(seq_len, dtype=torch.bool)
    rationale_mask[rat_start:rat_end] = True

    full_seq_mask = torch.ones(seq_len, dtype=torch.bool)

    # Everything between BOS and the second SEP; excludes the claim.
    earlier_tok_mask = torch.zeros(seq_len, dtype=torch.bool)
    earlier_tok_mask[prompt_start:rat_end] = True

    return {
        "token_ids": token_ids,
        "targets": targets,
        "rationale_mask": rationale_mask,
        "full_seq_mask": full_seq_mask,
        "earlier_tok_mask": earlier_tok_mask,
        "latent_state": latent_state,
        "claim_label": latent_state,  # one label per latent state (index)
        "is_shuffled": int(shuffled),
    }


class ClaimConsistencyDataset(Dataset):
    """In-memory dataset whose items are the dicts returned by ``make_sample``.

    All samples are generated eagerly in the constructor from a generator
    seeded with ``cfg.seed + seed_offset``; each sample first draws a latent
    state uniformly and then builds the sample with the same generator.
    """

    def __init__(
        self,
        cfg: ExperimentConfig,
        vocab: dict,
        n_samples: int,
        shuffled: bool = False,
        seed_offset: int = 0,
    ):
        generator = np.random.default_rng(cfg.seed + seed_offset)

        def draw_sample() -> dict:
            # The state draw must precede the sample build to keep the
            # random stream in the same order for every sample.
            state = int(generator.integers(0, cfg.num_latent_states))
            return make_sample(cfg, vocab, state, generator, shuffled=shuffled)

        self.samples = [draw_sample() for _ in range(n_samples)]

    def __len__(self):
        """Number of generated samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.samples[idx]


def collate_fn(batch):
    """Merge a list of sample dicts into one dict of batched tensors.

    Fields that are tensors in the first sample are stacked along a new
    leading dimension; every other field is converted to an int64 tensor.
    """
    reference = batch[0]
    collated = {}
    for key, ref_value in reference.items():
        column = [item[key] for item in batch]
        if isinstance(ref_value, torch.Tensor):
            collated[key] = torch.stack(column)
        else:
            collated[key] = torch.tensor(column, dtype=torch.long)
    return collated


# ---------------------------------------------------------------------------
# 3. Model
# ---------------------------------------------------------------------------

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where a position only sees itself and earlier positions."""

    def __init__(self, cfg: ExperimentConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (B, T, C) to an attended tensor of the same shape."""
        batch, seq_len, width = x.shape
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.n_heads, self.head_dim)
        # (3, B, H, T, D) -> three tensors of shape (B, H, T, D)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # True strictly above the diagonal, i.e. at positions in the future.
        future = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=x.device
        ).triu(diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))

        weights = self.drop(F.softmax(scores, dim=-1))
        mixed = torch.matmul(weights, values).transpose(1, 2).reshape(batch, seq_len, width)
        return self.proj(mixed)


class TransformerBlock(nn.Module):
    """Pre-LayerNorm block: causal attention, then a GELU feed-forward, each added residually."""

    def __init__(self, cfg: ExperimentConfig):
        super().__init__()
        width, hidden = cfg.d_model, cfg.d_ff
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(width)
        feed_forward_layers = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(cfg.dropout),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers to ``x``."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.ff(self.ln2(after_attention))


class ClaimConsistencyTransformer(nn.Module):
    """
    Decoder-only transformer with an optional consistency classification head.

    The LM head predicts the next token at every position.  The consistency
    head classifies the latent state from a masked mean of the final hidden
    states; the positions that are averaged depend on ``pooling_mode``:

        "no_consistency_loss"  - LM loss only, no consistency head used in training
        "rationale_only"       - pool hidden states over rationale mask
        "full_sequence"        - pool over the full-sequence mask
        "earlier_token_only"   - pool over prompt + rationale positions only

    Any other ``pooling_mode`` string falls back to the full-sequence mask.
    """
    def __init__(self, cfg: ExperimentConfig, pooling_mode: str = "no_consistency_loss"):
        super().__init__()
        self.cfg = cfg
        self.pooling_mode = pooling_mode

        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        # consistency head (always built, selectively used)
        self.consistency_head = nn.Linear(cfg.d_model, cfg.num_latent_states)

        self._init_weights()

    def _init_weights(self):
        """Initialise Linear/Embedding weights from N(0, 0.02^2) and zero the biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    def _encode(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``token_ids`` (B, T) and run all blocks; returns (B, T, d_model).

        Raises:
            ValueError: if T exceeds the size of the position-embedding table.
        """
        seq_len = token_ids.shape[1]
        max_len = self.pos_emb.num_embeddings
        if seq_len > max_len:
            # Fail with a readable message instead of an embedding index error.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={max_len}"
            )
        pos = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(token_ids) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        return self.ln_f(x)

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Average ``hidden`` (B, T, d) over the positions where ``mask`` (B, T) is True."""
        weights = mask.to(hidden.dtype)
        # clamp keeps an all-False mask from dividing by zero (result is 0)
        denom = weights.sum(dim=1, keepdim=True).clamp(min=1)
        return (hidden * weights.unsqueeze(-1)).sum(dim=1) / denom

    def forward(
        self,
        token_ids: torch.Tensor,          # (B, T)
        targets: Optional[torch.Tensor],  # (B, T)
        rationale_mask: torch.Tensor,     # (B, T)
        full_seq_mask: torch.Tensor,      # (B, T)
        earlier_tok_mask: torch.Tensor,   # (B, T)
        claim_labels: torch.Tensor,       # (B,)
        consistency_weight: float = 0.5,
    ) -> Tuple[torch.Tensor, dict]:
        """Compute the combined LM + consistency loss for one batch.

        Args:
            token_ids: input token ids.
            targets: next-token targets; PAD targets are ignored.  When None
                the LM loss is reported as 0.
            rationale_mask, full_seq_mask, earlier_tok_mask: candidate pooling
                masks; the one matching ``self.pooling_mode`` is used.
            claim_labels: class label per sequence for the consistency head.
            consistency_weight: multiplier on the consistency loss.

        Returns:
            ``(total_loss, info)`` where ``total_loss = lm_loss +
            consistency_weight * cons_loss`` and ``info`` holds the three loss
            values as floats plus ``hidden_states``, ``lm_logits`` and
            ``cons_logits`` (None when the consistency loss is disabled).
        """
        x = self._encode(token_ids)  # (B, T, d_model)

        # LM loss (placeholder zeros live on the same device/dtype as x)
        logits = self.lm_head(x)  # (B, T, V)
        lm_loss = x.new_zeros(())
        if targets is not None:
            lm_loss = F.cross_entropy(
                logits.reshape(-1, self.cfg.vocab_size),
                targets.reshape(-1),
                ignore_index=_PAD_TOKEN,
            )

        # Consistency loss
        cons_loss = x.new_zeros(())
        cons_logits = None
        if self.pooling_mode != "no_consistency_loss":
            pool_mask = self._get_pool_mask(
                rationale_mask, full_seq_mask, earlier_tok_mask
            )  # (B, T)
            pooled = self._masked_mean(x, pool_mask)     # (B, d)
            cons_logits = self.consistency_head(pooled)  # (B, num_latent_states)
            cons_loss = F.cross_entropy(cons_logits, claim_labels)

        total_loss = lm_loss + consistency_weight * cons_loss

        return total_loss, {
            "lm_loss": lm_loss.item(),
            "cons_loss": cons_loss.item(),
            "total_loss": total_loss.item(),
            "hidden_states": x,
            "lm_logits": logits,
            "cons_logits": cons_logits,
        }

    def _get_pool_mask(self, rationale_mask, full_seq_mask, earlier_tok_mask):
        """Select the pooling mask for ``self.pooling_mode`` (unknown modes use the full mask)."""
        if self.pooling_mode == "rationale_only":
            return rationale_mask
        elif self.pooling_mode == "full_sequence":
            return full_seq_mask
        elif self.pooling_mode == "earlier_token_only":
            return earlier_tok_mask
        else:
            return full_seq_mask  # fallback

    @torch.no_grad()
    def pool_hidden_rationale(self, token_ids, rationale_mask):
        """Pool hidden states over rationale mask. Used for eval regardless of pooling_mode.

        Returns the consistency-head logits, shape (B, num_latent_states).
        """
        hidden = self._encode(token_ids)
        return self.consistency_head(self._masked_mean(hidden, rationale_mask))

    @torch.no_grad()
    def generate_claim(
        self,
        token_ids: torch.Tensor,  # (B, prefix_len) = prompt+rationale portion
        max_new: int = 4,
    ) -> torch.Tensor:
        """Greedy generation of claim tokens after prompt+rationale.

        The whole sequence is re-encoded for every new token.  Returns only
        the newly generated tokens, shape (B, max_new).
        """
        generated = token_ids.clone()
        for _ in range(max_new):
            hidden = self._encode(generated)
            logits = self.lm_head(hidden[:, -1, :])         # (B, V)
            next_tok = logits.argmax(dim=-1, keepdim=True)  # (B, 1)
            generated = torch.cat([generated, next_tok], dim=1)
        return generated[:, token_ids.shape[1]:]  # return only new tokens


# ---------------------------------------------------------------------------
# 4. Training
# ---------------------------------------------------------------------------

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


def train_one_variant(
    cfg: ExperimentConfig,
    pooling_mode: str,
    train_loader: DataLoader,
    eval_loader: DataLoader,
    shuffled_loader: DataLoader,
    vocab: dict,
) -> Tuple[ClaimConsistencyTransformer, List[dict]]:
    """Train a fresh model for one pooling variant.

    The seed is reset before the model is built so that every variant starts
    from the same initialisation and random stream.

    Args:
        cfg: experiment configuration (seed, device, lr, epochs, loss weight).
        pooling_mode: consistency-pooling variant passed to the model.
        train_loader: batches produced by ``collate_fn``.
        eval_loader, shuffled_loader, vocab: accepted for a uniform call
            signature; not used during training.

    Returns:
        ``(model, history)`` where ``history`` has one dict per epoch with the
        keys ``epoch``, ``lm_loss`` and ``cons_loss`` (per-batch means).  The
        means are NaN for an epoch in which ``train_loader`` yields no batch.
    """
    set_seed(cfg.seed)
    model = ClaimConsistencyTransformer(cfg, pooling_mode=pooling_mode).to(cfg.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    history = []

    for epoch in range(cfg.num_epochs):
        model.train()
        total_lm, total_cons, n_batches = 0.0, 0.0, 0
        for batch in train_loader:
            batch = {k: v.to(cfg.device) for k, v in batch.items()}
            optimizer.zero_grad()
            loss, info = model(
                batch["token_ids"],
                batch["targets"],
                batch["rationale_mask"],
                batch["full_seq_mask"],
                batch["earlier_tok_mask"],
                batch["claim_label"],
                consistency_weight=cfg.consistency_loss_weight,
            )
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_lm += info["lm_loss"]
            total_cons += info["cons_loss"]
            n_batches += 1

        if n_batches:
            mean_lm = total_lm / n_batches
            mean_cons = total_cons / n_batches
        else:
            # An empty loader has no loss to average; report NaN rather than
            # dividing by zero or pretending the loss is 0.
            mean_lm = mean_cons = float("nan")

        history.append({
            "epoch": epoch + 1,
            "lm_loss": mean_lm,
            "cons_loss": mean_cons,
        })

    return model, history


# ---------------------------------------------------------------------------
# 5. Evaluation helpers
# ---------------------------------------------------------------------------

def _seq_prefix(cfg, sample):
    """Return the part of ``sample["token_ids"]`` that precedes the first claim token.

    The prefix covers BOS, the prompt, the first SEP, the rationale and the
    second SEP, so generation continues directly with the claim.
    """
    # BOS, prompt, SEP, rationale, SEP  (15 tokens with the default config)
    section_lengths = (1, cfg.prompt_len, 1, cfg.rationale_len, 1)
    return sample["token_ids"][: sum(section_lengths)]


def evaluate_claim_accuracy_generation(
    cfg: ExperimentConfig,
    model: ClaimConsistencyTransformer,
    dataset: ClaimConsistencyDataset,
    vocab: dict,
) -> float:
    """Greedy-generation accuracy of the first claim token.

    For every sample the model continues the prefix (everything before the
    claim) for ``cfg.claim_len`` tokens; the sample counts as correct when the
    first generated token equals the first claim token of the sample's label.
    Returns 0.0 for an empty dataset.
    """
    model.eval()
    # latent state -> id of its first claim token
    expected_first = {
        state: vocab[state]["claim_tokens"][0]
        for state in range(cfg.num_latent_states)
    }

    n_hits = 0
    n_seen = 0
    for sample in dataset.samples:
        prefix = _seq_prefix(cfg, sample).unsqueeze(0).to(cfg.device)
        new_tokens = model.generate_claim(prefix, max_new=cfg.claim_len)
        first_generated = new_tokens[0, 0].item()
        if first_generated == expected_first[sample["claim_label"]]:
            n_hits += 1
        n_seen += 1

    if n_seen == 0:
        return 0.0
    return n_hits / n_seen


def evaluate_claim_accuracy_classifier(
    cfg: ExperimentConfig,
    model: ClaimConsistencyTransformer,
    dataset: ClaimConsistencyDataset,
) -> float:
    """Accuracy of the consistency head applied to rationale-pooled hidden states.

    Each sample is scored on its own; the prediction is the arg-max of
    ``model.pool_hidden_rationale`` and is compared with the sample's
    ``claim_label``.  Returns 0.0 for an empty dataset.
    """
    model.eval()
    outcomes = []
    for sample in dataset.samples:
        ids = sample["token_ids"].unsqueeze(0).to(cfg.device)
        mask = sample["rationale_mask"].unsqueeze(0).to(cfg.device)
        predicted = model.pool_hidden_rationale(ids, mask).argmax(dim=-1).item()
        outcomes.append(bool(predicted == sample["claim_label"]))

    if not outcomes:
        return 0.0
    return sum(outcomes) / len(outcomes)


def evaluate_counterfactual_swap(
    cfg: ExperimentConfig,
    model: ClaimConsistencyTransformer,
    vocab: dict,
    n_samples: int = 64,
    seed_offset: int = 999,
) -> dict:
    """
    Counterfactual rationale swap test.

    For each sample:
      - original_latent: latent state that determines the claim
      - swapped_rationale: taken from a DIFFERENT latent state
      - Build sequence with original_latent's claim but swapped rationale
      - Measure:
          gen_follows_swap: first generated claim token matches the swapped
                            rationale's latent state (vs the original one)
          cls_follows_swap: classifier (rationale-pool) follows swapped rationale

    Samples are drawn from a generator seeded with ``cfg.seed + seed_offset``.

    Returns:
        dict with the four rates ``gen_follows_swap_rate``,
        ``gen_follows_orig_rate``, ``cls_follows_swap_rate``,
        ``cls_follows_orig_rate`` and the evaluated ``n_samples``.  All rates
        are 0.0 when no sample was evaluated.

    Raises:
        ValueError: if fewer than two latent states are configured, since no
            different state exists to swap the rationale from.
    """
    if cfg.num_latent_states < 2:
        # Without this guard the rejection loop below would never terminate.
        raise ValueError("counterfactual swap needs at least 2 latent states")

    model.eval()
    rng = np.random.default_rng(cfg.seed + seed_offset)
    gen_follows_swap = 0
    cls_follows_swap = 0
    gen_follows_orig = 0
    cls_follows_orig = 0
    total = 0

    claim_first_token = {s: vocab[s]["claim_tokens"][0] for s in range(cfg.num_latent_states)}

    for _ in range(n_samples):
        orig_state = int(rng.integers(0, cfg.num_latent_states))
        # Redraw until the swap source differs from the original state.
        swap_state = int(rng.integers(0, cfg.num_latent_states))
        while swap_state == orig_state:
            swap_state = int(rng.integers(0, cfg.num_latent_states))

        sample = make_sample(cfg, vocab, orig_state, rng, swap_rationale_from=swap_state)

        prefix = _seq_prefix(cfg, sample).unsqueeze(0).to(cfg.device)
        tids = sample["token_ids"].unsqueeze(0).to(cfg.device)
        rmask = sample["rationale_mask"].unsqueeze(0).to(cfg.device)

        # generation
        generated = model.generate_claim(prefix, max_new=cfg.claim_len)
        pred_gen = generated[0, 0].item()
        if pred_gen == claim_first_token[swap_state]:
            gen_follows_swap += 1
        if pred_gen == claim_first_token[orig_state]:
            gen_follows_orig += 1

        # classifier
        cls_logits = model.pool_hidden_rationale(tids, rmask)
        pred_cls = cls_logits.argmax(dim=-1).item()
        if pred_cls == swap_state:
            cls_follows_swap += 1
        if pred_cls == orig_state:
            cls_follows_orig += 1

        total += 1

    def rate(count: int) -> float:
        # Matches the other evaluators: no samples -> 0.0 instead of a crash.
        return count / total if total > 0 else 0.0

    return {
        "gen_follows_swap_rate": rate(gen_follows_swap),
        "gen_follows_orig_rate": rate(gen_follows_orig),
        "cls_follows_swap_rate": rate(cls_follows_swap),
        "cls_follows_orig_rate": rate(cls_follows_orig),
        "n_samples": total,
    }


def evaluate_shuffled_pairing(
    cfg: ExperimentConfig,
    model: ClaimConsistencyTransformer,
    shuffled_dataset: ClaimConsistencyDataset,
    vocab: dict,
) -> dict:
    """
    Score the model on the shuffled-pairing control (rationale-claim mismatch).

    Both metrics are measured against the CLAIM label, not against the state
    the shuffled rationale came from, so a drop relative to the regular eval
    set indicates that the model was relying on the rationale.

    Returns a dict with ``shuffled_gen_acc`` (generation accuracy) and
    ``shuffled_cls_acc`` (rationale-pooled classifier accuracy).
    """
    return {
        "shuffled_gen_acc": evaluate_claim_accuracy_generation(
            cfg, model, shuffled_dataset, vocab=vocab
        ),
        "shuffled_cls_acc": evaluate_claim_accuracy_classifier(
            cfg, model, shuffled_dataset
        ),
    }


# ---------------------------------------------------------------------------
# 6. Full experiment runner
# ---------------------------------------------------------------------------

def run_experiment(cfg: Optional[ExperimentConfig] = None) -> pd.DataFrame:
    """Train and evaluate every pooling variant and collect one result row each.

    Builds the vocabulary and the train / eval / shuffled datasets once, then
    for each entry of ``cfg.pooling_modes`` trains a model and runs the claim
    accuracy, counterfactual-swap and shuffled-pairing evaluations.  The table
    is written to ``cfg.results_path`` as CSV and returned.

    Args:
        cfg: experiment configuration; a default ``ExperimentConfig`` is used
            when omitted.
    """
    cfg = ExperimentConfig() if cfg is None else cfg

    set_seed(cfg.seed)
    print(f"[INFO] Building vocabulary for {cfg.num_latent_states} latent states ...")
    vocab = _build_vocabulary(cfg)

    print("[INFO] Building datasets ...")
    train_ds = ClaimConsistencyDataset(
        cfg, vocab, cfg.num_train_samples, shuffled=False, seed_offset=0
    )
    eval_ds = ClaimConsistencyDataset(
        cfg, vocab, cfg.num_eval_samples, shuffled=False, seed_offset=100
    )
    shuf_ds = ClaimConsistencyDataset(
        cfg, vocab, cfg.num_shuffled_samples, shuffled=True, seed_offset=200
    )

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset, batch_size=cfg.batch_size, shuffle=shuffle, collate_fn=collate_fn
        )

    train_loader = make_loader(train_ds, True)
    eval_loader = make_loader(eval_ds, False)
    shuf_loader = make_loader(shuf_ds, False)

    results = []

    for pooling_mode in cfg.pooling_modes:
        print(f"\n{'='*60}")
        print(f"[TRAIN] variant={pooling_mode}")
        model, history = train_one_variant(
            cfg, pooling_mode, train_loader, eval_loader, shuf_loader, vocab
        )
        last_epoch = history[-1]
        final_lm = last_epoch["lm_loss"]
        final_cons = last_epoch["cons_loss"]
        print(f"  Final LM loss={final_lm:.4f}  Cons loss={final_cons:.4f}")

        print(f"[EVAL]  variant={pooling_mode}")
        gen_acc = evaluate_claim_accuracy_generation(cfg, model, eval_ds, vocab)
        cls_acc = evaluate_claim_accuracy_classifier(cfg, model, eval_ds)
        cfact = evaluate_counterfactual_swap(cfg, model, vocab, n_samples=64)
        shuf = evaluate_shuffled_pairing(cfg, model, shuf_ds, vocab)

        # Insertion order here is the column order of the CSV.
        raw_metrics = {
            "final_lm_loss": final_lm,
            "final_cons_loss": final_cons,
            "gen_claim_acc": gen_acc,
            "cls_claim_acc (rationale_pool)": cls_acc,
            "cfact_gen_follows_swap": cfact["gen_follows_swap_rate"],
            "cfact_gen_follows_orig": cfact["gen_follows_orig_rate"],
            "cfact_cls_follows_swap": cfact["cls_follows_swap_rate"],
            "cfact_cls_follows_orig": cfact["cls_follows_orig_rate"],
            "shuffled_gen_acc": shuf["shuffled_gen_acc"],
            "shuffled_cls_acc": shuf["shuffled_cls_acc"],
        }
        rounded = {name: round(value, 4) for name, value in raw_metrics.items()}
        results.append({"variant": pooling_mode, **rounded})
        print("  " + str(rounded))

    df = pd.DataFrame(results)
    df.to_csv(cfg.results_path, index=False)
    print(f"\n[DONE] Results saved to {cfg.results_path}")
    return df


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Smoke test: a deliberately tiny model and dataset so the whole pipeline
    # finishes quickly on CPU.
    smoke_overrides = dict(
        num_latent_states=4,
        num_rationale_templates=3,
        num_train_samples=256,
        num_eval_samples=64,
        num_shuffled_samples=64,
        num_epochs=3,
        batch_size=32,
        d_model=32,
        n_heads=2,
        n_layers=2,
        d_ff=64,
        results_path="results_comparison.csv",
    )
    cfg = ExperimentConfig(**smoke_overrides)
    df = run_experiment(cfg)
    print("\n=== RESULTS TABLE ===")
    try:
        table = df.to_markdown(index=False)
    except ImportError:
        # to_markdown needs an optional dependency; fall back to plain text.
        table = df.to_string(index=False)
    print(table)
