"""
claim_consistency_experiment.py
================================
Minimal reproducible experiment for validating claim-consistency coupling
in a small decoder-only transformer.

Structure
---------
1. Config dataclass
2. Synthetic dataset generator
3. Small GPT-2-style decoder-only Transformer
4. Training loop (one variant per pooling mode; 5 by default)
5. Evaluation: claim accuracy, counterfactual swap, shuffled-pairing control,
   hidden-state intervention (causal patching)
6. Results table

Run as a script for a quick smoke test:
    python claim_consistency_experiment.py
"""

from __future__ import annotations

import copy
import math
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# ---------------------------------------------------------------------------
# 1. Config
# ---------------------------------------------------------------------------

@dataclass
class ExperimentConfig:
    """Hyperparameters for one run of the claim-consistency experiment.

    Fields are grouped as dataset / model / training / misc. Because this is a
    plain @dataclass, field order defines the positional constructor
    signature: append new fields at the end of a group rather than inserting.
    """

    # Dataset
    num_latent_states: int = 8          # both vocab builders assert <= 8
    num_rationale_templates: int = 4    # rationale templates per latent state (3-5 suggested)
    vocab_size: int = 128               # token ids 0..127 (claim tokens 80..95, SEP = 96)
    prompt_len: int = 4                 # generic prompt tokens between BOS and the first SEP
    rationale_len: int = 8              # tokens per rationale template
    claim_len: int = 2                  # claim tokens generated at eval; vocab builders emit exactly 2 per state
    num_train_samples: int = 512
    num_eval_samples: int = 128
    num_shuffled_samples: int = 128     # shuffled-pairing control
    seed: int = 42                      # base seed; datasets / evals add their own seed_offset

    # Model
    d_model: int = 64                   # must be divisible by n_heads
    n_heads: int = 4
    n_layers: int = 2
    d_ff: int = 128                     # hidden width of the per-block MLP
    dropout: float = 0.1                # embeddings, attention probs and MLP output
    max_seq_len: int = 64               # positional-embedding table size; must cover a full sample

    # Training
    batch_size: int = 32
    num_epochs: int = 5                 # small for smoke test
    lr: float = 3e-4
    consistency_loss_weight: float = 0.5  # multiplier on the consistency cross-entropy loss
    pooling_modes: Tuple[str, ...] = (  # modes understood by ClaimConsistencyTransformer
        "no_consistency_loss",
        "rationale_only",
        "full_sequence",
        "earlier_token_only",
        "claim_only_pooling",
    )

    # Hard overlapping-vocab dataset option
    hard_overlap_vocab: bool = False   # if True, rationale pools overlap across states
    overlap_fraction: float = 0.5      # per-position probability of drawing from the shared/group (non-local) pools

    # Misc
    device: str = "cpu"
    results_path: str = "results_comparison.csv"  # presumably where the results table is written


# ---------------------------------------------------------------------------
# 2. Synthetic dataset
# ---------------------------------------------------------------------------

# Token-id layout of the default (exclusive-range) vocabulary:
#   0          = PAD
#   1          = BOS
#   2..9       = generic prompt tokens
#   10..79     = per-state rationale token pools (must stay below the claim tokens)
#   80..95     = claim label tokens (2 per latent state * 8 latent states)
#   96..127    = misc / separator tokens (96 is SEP)
# The overlapping ("hard") layout is documented in _build_vocabulary_hard.

_PAD_TOKEN = 0      # padding id; also the ignore_index of the LM loss
_BOS_TOKEN = 1      # beginning-of-sequence marker
_SEP_TOKEN = 96     # separator between prompt, rationale, claim sections

def _build_vocabulary(cfg: ExperimentConfig):
    """Return the per-state vocabulary for ``cfg``.

    Selects the overlapping-pool builder when ``cfg.hard_overlap_vocab`` is
    truthy, and the exclusive-range builder otherwise.
    """
    builder = _build_vocabulary_hard if cfg.hard_overlap_vocab else _build_vocabulary_easy
    return builder(cfg)


def _build_vocabulary_easy(cfg: ExperimentConfig):
    """Build the non-overlapping vocabulary: each state owns an exclusive token range.

    Rationale pools start at token 10 and must end before the claim tokens,
    which start at 80. Each state gets up to 10 rationale tokens. When the
    70 available ids (10..79) cannot hold 10 per state (8 states), the pool is
    shrunk to ``70 // num_latent_states`` tokens. With 8 states that is 8
    tokens each (10..73). Otherwise the last state's pool would spill into
    80..89 and reuse other states' claim tokens.

    Args:
        cfg: experiment config; reads num_latent_states,
            num_rationale_templates and rationale_len.

    Returns:
        dict mapping state index -> {"rationale_templates": list of
        ``num_rationale_templates`` token lists of length ``rationale_len``,
        "claim_tokens": the state's 2 claim token ids}.
    """
    rationale_base = 10
    claim_base = 80
    tokens_per_state_claim = 2       # deterministic claim tokens
    assert cfg.num_latent_states <= 8, "Extend vocab if > 8 latent states"
    # Largest per-state pool (capped at 10) that keeps every rationale token
    # below claim_base. 70 // 7 == 10, so configs with <= 7 states keep the
    # 10-token pools; 8 states get 8 tokens each.
    tokens_per_state_rationale = min(
        10, (claim_base - rationale_base) // max(cfg.num_latent_states, 1)
    )
    vocab = {}
    for s in range(cfg.num_latent_states):
        rationale_start = rationale_base + s * tokens_per_state_rationale
        rationale_tokens = list(range(rationale_start, rationale_start + tokens_per_state_rationale))
        claim_start = claim_base + s * tokens_per_state_claim
        claim_tokens = list(range(claim_start, claim_start + tokens_per_state_claim))
        rng = np.random.default_rng(s)  # deterministic per state
        templates = [
            rng.choice(rationale_tokens, size=cfg.rationale_len, replace=True).tolist()
            for _ in range(cfg.num_rationale_templates)
        ]
        vocab[s] = {
            "rationale_templates": templates,
            "claim_tokens": claim_tokens,
        }
    return vocab


def _build_vocabulary_hard(cfg: ExperimentConfig):
    """Build the overlapping-vocabulary variant used when cfg.hard_overlap_vocab is set.

    Token layout (claim tokens are the same as in the exclusive builder):
      0          = PAD
      1          = BOS
      2..9       = generic prompt tokens
      10..17     = shared pool: 8 tokens available to every state
      18..       = group pools: 8 tokens per pair of adjacent states
                     states 0,1 -> 18..25    states 2,3 -> 26..33
                     states 4,5 -> 34..41    states 6,7 -> 42..49
      after that = local pools: 2 tokens owned by a single state
                     (50..65 when there are 8 states)
      80..95     = claim label tokens, 2 per state
      96..127    = misc / separator tokens

    Every rationale position is one draw from the union of the state's shared,
    group and local pools, with pool-level probabilities
        shared: overlap_fraction / 2
        group : overlap_fraction / 2
        local : 1 - overlap_fraction
    spread evenly across the tokens of each pool. Template ``t`` of state ``s``
    uses its own generator seeded with 100 + 100*s + t, so the output is
    fully deterministic.

    Shared tokens occur in every state and group tokens in two, so neither
    identifies a state by itself. Local tokens DO identify their state. How
    often they appear is set by 1 - overlap_fraction, which is about half of
    all positions at the default 0.5. Raise overlap_fraction to make the task
    harder.

    Returns:
        dict: state index -> {"rationale_templates": list of token lists,
                              "claim_tokens": list of 2 token ids}.
    """
    assert cfg.num_latent_states <= 8, "Hard vocab supports up to 8 latent states"
    assert 0.0 <= cfg.overlap_fraction <= 1.0

    n_states = cfg.num_latent_states
    shared_size, group_size, local_size = 8, 8, 2

    shared_start = 10
    shared = list(range(shared_start, shared_start + shared_size))

    group_start = shared_start + shared_size
    n_groups = (n_states + 1) // 2  # ceil(n_states / 2); an odd last state has a group alone
    group_pools = [
        list(range(group_start + i * group_size, group_start + (i + 1) * group_size))
        for i in range(n_groups)
    ]

    local_start = group_start + n_groups * group_size
    local_pools = [
        list(range(local_start + i * local_size, local_start + (i + 1) * local_size))
        for i in range(n_states)
    ]

    half_overlap = cfg.overlap_fraction * 0.5
    pool_probs = (half_overlap, half_overlap, 1.0 - cfg.overlap_fraction)

    vocab = {}
    for state in range(n_states):
        pools = (shared, group_pools[state // 2], local_pools[state])
        candidates = [tok for pool in pools for tok in pool]
        # Spread each pool's probability mass evenly over its tokens, then renormalise.
        raw = [prob / len(pool) for prob, pool in zip(pool_probs, pools) for _ in pool]
        norm = sum(raw)
        probs = [w / norm for w in raw]

        templates = [
            np.random.default_rng(100 + state * 100 + idx).choice(
                candidates, size=cfg.rationale_len, replace=True, p=probs
            ).tolist()
            for idx in range(cfg.num_rationale_templates)
        ]
        first_claim = 80 + 2 * state
        vocab[state] = {
            "rationale_templates": templates,
            "claim_tokens": list(range(first_claim, first_claim + 2)),
        }
    return vocab


def _make_prompt_tokens(cfg: ExperimentConfig, rng: np.random.Generator) -> List[int]:
    """Draw ``cfg.prompt_len`` generic prompt token ids uniformly from [2, 10)."""
    draws = rng.integers(low=2, high=10, size=cfg.prompt_len)
    return draws.tolist()


def make_sample(
    cfg: ExperimentConfig,
    vocab: dict,
    latent_state: int,
    rng: np.random.Generator,
    shuffled: bool = False,
    swap_rationale_from: Optional[int] = None,
) -> dict:
    """
    Build one sample.

    Layout: [BOS] [prompt...] [SEP] [rationale...] [SEP] [claim...]

    The claim tokens always belong to ``latent_state``. The rationale template
    is taken from:
      - ``swap_rationale_from`` when it is given (takes precedence),
      - a uniformly chosen state *different from* ``latent_state`` when
        ``shuffled`` is True,
      - ``latent_state`` itself otherwise.

    Returns a dict with:
        token_ids        : (seq_len,) int64
        targets          : (seq_len,) int64  token_ids shifted left by one, last = PAD
        rationale_mask   : (seq_len,) bool   True at rationale positions
        full_seq_mask    : (seq_len,) bool   True everywhere (samples contain no PAD)
        earlier_tok_mask : (seq_len,) bool   True at prompt+rationale positions
        claim_mask       : (seq_len,) bool   True at claim token positions only
        latent_state     : int
        claim_label      : int  latent-state index encoded by the claim (== latent_state)
        is_shuffled      : int  1 if ``shuffled`` was requested, else 0

    Raises:
        ValueError: if ``shuffled`` is requested without ``swap_rationale_from``
            and there are fewer than 2 latent states (no mismatched rationale exists).
    """
    # ---- choose whose rationale to use ----
    if swap_rationale_from is not None:
        rationale_state = swap_rationale_from
    elif shuffled:
        n_states = cfg.num_latent_states
        if n_states < 2:
            raise ValueError("shuffled samples need at least 2 latent states")
        # Draw uniformly from the n_states - 1 OTHER states by skipping over
        # latent_state. This uses one RNG call, like a plain draw.
        rationale_state = int(rng.integers(0, n_states - 1))
        if rationale_state >= latent_state:
            rationale_state += 1
    else:
        rationale_state = latent_state
    rat_tmpl_idx = rng.integers(0, cfg.num_rationale_templates)
    rationale_tokens = vocab[rationale_state]["rationale_templates"][rat_tmpl_idx]

    claim_tokens = vocab[latent_state]["claim_tokens"]
    prompt_tokens = _make_prompt_tokens(cfg, rng)

    # ---- full sequence ----
    seq = [_BOS_TOKEN, *prompt_tokens, _SEP_TOKEN, *rationale_tokens, _SEP_TOKEN, *claim_tokens]
    seq_len = len(seq)

    token_ids = torch.tensor(seq, dtype=torch.long)
    # next-token targets: shift by 1, last target = PAD
    targets = torch.cat([token_ids[1:], torch.tensor([_PAD_TOKEN], dtype=torch.long)])

    # ---- positional masks, derived from the actual section lengths ----
    prompt_start = 1                                 # after BOS
    rat_start = prompt_start + len(prompt_tokens) + 1  # after first SEP
    rat_end = rat_start + len(rationale_tokens)      # exclusive
    claim_start = rat_end + 1                        # after second SEP
    claim_end = claim_start + len(claim_tokens)      # exclusive

    rationale_mask = torch.zeros(seq_len, dtype=torch.bool)
    rationale_mask[rat_start:rat_end] = True

    full_seq_mask = torch.ones(seq_len, dtype=torch.bool)

    earlier_tok_mask = torch.zeros(seq_len, dtype=torch.bool)
    earlier_tok_mask[prompt_start:rat_end] = True  # prompt + rationale, not claim

    claim_mask = torch.zeros(seq_len, dtype=torch.bool)
    claim_mask[claim_start:claim_end] = True  # only claim token positions

    return {
        "token_ids": token_ids,
        "targets": targets,
        "rationale_mask": rationale_mask,
        "full_seq_mask": full_seq_mask,
        "earlier_tok_mask": earlier_tok_mask,
        "claim_mask": claim_mask,
        "latent_state": latent_state,
        "claim_label": latent_state,  # one label per latent state (index)
        "is_shuffled": int(shuffled),
    }


class ClaimConsistencyDataset(Dataset):
    """In-memory dataset of samples pre-generated with ``make_sample``.

    One generator, seeded with ``cfg.seed + seed_offset``, draws each sample's
    latent state uniformly and is then handed to ``make_sample``. The same
    arguments therefore always reproduce the same samples.
    """

    def __init__(
        self,
        cfg: ExperimentConfig,
        vocab: dict,
        n_samples: int,
        shuffled: bool = False,
        seed_offset: int = 0,
    ):
        generator = np.random.default_rng(cfg.seed + seed_offset)
        built = []
        for _ in range(n_samples):
            state = int(generator.integers(0, cfg.num_latent_states))
            built.append(make_sample(cfg, vocab, state, generator, shuffled=shuffled))
        self.samples = built

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def collate_fn(batch):
    """Merge a list of sample dicts into one dict of batched tensors.

    Tensor fields are stacked along a new leading dimension; every other field
    (ints / bools) becomes a 1-D int64 tensor. Keys follow the first sample.
    """
    def _merge(key):
        column = [item[key] for item in batch]
        if isinstance(column[0], torch.Tensor):
            return torch.stack(column)
        return torch.tensor(column, dtype=torch.long)

    return {key: _merge(key) for key in batch[0]}


# ---------------------------------------------------------------------------
# 3. Model
# ---------------------------------------------------------------------------

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused, bias-free QKV projection.

    Each position attends only to itself and earlier positions. Attention
    probabilities pass through dropout before they are applied to the values.
    """

    def __init__(self, cfg: ExperimentConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        # Submodule creation order is kept fixed: it determines parameter
        # initialisation order and the state_dict keys.
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq, width = x.shape
        # (B, T, 3C) -> (3, B, H, T, D): one slice each for queries, keys, values.
        q, k, v = (
            self.qkv(x)
            .reshape(batch, seq, 3, self.n_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, T, T)
        allowed = torch.ones(seq, seq, device=x.device).tril().view(1, 1, seq, seq)
        scores = scores.masked_fill(allowed == 0, float("-inf"))
        weights = self.drop(F.softmax(scores, dim=-1))
        mixed = weights @ v  # (B, H, T, D)
        return self.proj(mixed.transpose(1, 2).reshape(batch, seq, width))


class TransformerBlock(nn.Module):
    """Pre-LayerNorm decoder block: causal attention then a GELU MLP, each with a residual.

    ``forward`` optionally applies a ``patch_callback`` to the block output. This
    is the hook used by activation patching to overwrite hidden states before
    they reach later blocks.
    """

    def __init__(self, cfg: ExperimentConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(cfg.d_model)
        width, hidden = cfg.d_model, cfg.d_ff
        self.ff = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(cfg.dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        patch_callback: Optional[callable] = None,
    ) -> torch.Tensor:
        """Apply the block to ``x`` of shape (B, T, d_model).

        Args:
            x: hidden states entering this block.
            patch_callback: optional ``f(h) -> h_patched``, invoked on the block
                output (after both residual additions) so callers can
                overwrite specific positions before downstream blocks.
        """
        attended = x + self.attn(self.ln1(x))
        out = attended + self.ff(self.ln2(attended))
        return out if patch_callback is None else patch_callback(out)


class ClaimConsistencyTransformer(nn.Module):
    """
    Decoder-only transformer with an optional consistency classification head.

    pooling_mode:
        "no_consistency_loss"  - LM loss only, no consistency head used in training
        "rationale_only"       - pool hidden states over rationale mask
        "full_sequence"        - pool over entire non-pad sequence
        "earlier_token_only"   - pool over prompt + rationale positions only
        "claim_only_pooling"   - pool over claim token positions only (negative control)

    Any other string is accepted and pools over ``full_seq_mask``.

    Every entry point shares one trunk: ``_embed`` -> blocks -> ``ln_f``. The
    consistency head is always built, but it only receives a training signal
    when pooling_mode is not "no_consistency_loss".
    """
    def __init__(self, cfg: ExperimentConfig, pooling_mode: str = "no_consistency_loss"):
        super().__init__()
        self.cfg = cfg
        self.pooling_mode = pooling_mode

        # NOTE: registration order fixes parameter-init order and state_dict keys.
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        # consistency head (always built, selectively used)
        self.consistency_head = nn.Linear(cfg.d_model, cfg.num_latent_states)

        self._init_weights()

    def _init_weights(self):
        """Init Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    # ---- shared trunk helpers ------------------------------------------------

    def _embed(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Token + learned position embeddings, then dropout: (B, T) -> (B, T, d_model)."""
        positions = torch.arange(token_ids.shape[1], device=token_ids.device).unsqueeze(0)
        return self.drop(self.tok_emb(token_ids) + self.pos_emb(positions))

    def _encode(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embeddings -> all blocks -> final LayerNorm: (B, T) -> (B, T, d_model)."""
        x = self._embed(token_ids)
        for block in self.blocks:
            x = block(x)
        return self.ln_f(x)

    @staticmethod
    def _masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean of x (B, T, d) over positions where mask (B, T) is set.

        Rows with an empty mask yield zeros (the denominator is clamped to 1).
        """
        weights = mask.float()
        denom = weights.sum(dim=1, keepdim=True).clamp(min=1)
        return (x * weights.unsqueeze(-1)).sum(dim=1) / denom

    # ---- training / scoring forward -------------------------------------------

    def forward(
        self,
        token_ids: torch.Tensor,          # (B, T)
        targets: Optional[torch.Tensor],  # (B, T)
        rationale_mask: torch.Tensor,     # (B, T)
        full_seq_mask: torch.Tensor,      # (B, T)
        earlier_tok_mask: torch.Tensor,   # (B, T)
        claim_labels: torch.Tensor,       # (B,)
        consistency_weight: float = 0.5,
        claim_mask: Optional[torch.Tensor] = None,  # (B, T) claim token positions
    ) -> Tuple[torch.Tensor, dict]:
        """Compute the LM loss plus a weighted consistency loss.

        Returns:
            (total_loss, info) where ``total_loss = lm_loss + consistency_weight * cons_loss``.
            ``info`` holds float "lm_loss", "cons_loss" and "total_loss", the
            final-LayerNorm "hidden_states" (B, T, d_model), "lm_logits"
            (B, T, V) and "cons_logits" (B, num_latent_states), or None when
            pooling_mode is "no_consistency_loss". Each loss term is 0 when it
            is not computed.
        """
        x = self._encode(token_ids)  # (B, T, d_model)

        # LM loss (zero placeholder created on the same device/dtype as the activations)
        logits = self.lm_head(x)  # (B, T, V)
        lm_loss = x.new_zeros(())
        if targets is not None:
            lm_loss = F.cross_entropy(
                logits.reshape(-1, self.cfg.vocab_size),
                targets.reshape(-1),
                ignore_index=_PAD_TOKEN,
            )

        # Consistency loss: mean-pool over the mode's mask, then classify the latent state
        cons_loss = x.new_zeros(())
        cons_logits = None
        if self.pooling_mode != "no_consistency_loss":
            pool_mask = self._get_pool_mask(
                rationale_mask, full_seq_mask, earlier_tok_mask, claim_mask
            )  # (B, T)
            cons_logits = self.consistency_head(self._masked_mean(x, pool_mask))
            cons_loss = F.cross_entropy(cons_logits, claim_labels)

        total_loss = lm_loss + consistency_weight * cons_loss

        return total_loss, {
            "lm_loss": lm_loss.item(),
            "cons_loss": cons_loss.item(),
            "total_loss": total_loss.item(),
            "hidden_states": x,
            "lm_logits": logits,
            "cons_logits": cons_logits,
        }

    def _get_pool_mask(self, rationale_mask, full_seq_mask, earlier_tok_mask, claim_mask=None):
        """Select the (B, T) pooling mask for ``self.pooling_mode``.

        Raises:
            ValueError: in "claim_only_pooling" mode when claim_mask is None.
        """
        if self.pooling_mode == "rationale_only":
            return rationale_mask
        elif self.pooling_mode == "full_sequence":
            return full_seq_mask
        elif self.pooling_mode == "earlier_token_only":
            return earlier_tok_mask
        elif self.pooling_mode == "claim_only_pooling":
            # Pool only over claim token positions for consistency training.
            # claim_mask must be provided when this mode is active.
            if claim_mask is None:
                raise ValueError(
                    "claim_mask must be provided when pooling_mode='claim_only_pooling'"
                )
            return claim_mask
        else:
            return full_seq_mask  # fallback

    # ---- evaluation / interpretability helpers -------------------------------

    @torch.no_grad()
    def pool_hidden_rationale(self, token_ids, rationale_mask):
        """Consistency-head logits from hidden states mean-pooled over rationale_mask.

        Used for evaluation regardless of pooling_mode. Returns (B, num_latent_states).
        """
        return self.consistency_head(self._masked_mean(self._encode(token_ids), rationale_mask))

    @torch.no_grad()
    def forward_blocks_with_cache(
        self,
        token_ids: torch.Tensor,   # (B, T)
        rationale_mask: torch.Tensor,  # (B, T) bool
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run a full forward pass, caching the hidden state after every block.

        Args:
            token_ids: (B, T) input ids.
            rationale_mask: accepted for API symmetry with
                ``forward_blocks_with_patch``; not used here. The caller selects
                rationale positions from the cached tensors.

        Returns:
            lm_logits: (B, T, V) final LM logits.
            block_hs: list of length n_layers. Each entry is a full (B, T, d_model)
                copy of the hidden states right after that block (before the
                final LayerNorm), NOT only the rationale slice.
        """
        x = self._embed(token_ids)
        block_hs: List[torch.Tensor] = []
        for block in self.blocks:
            x = block(x)
            block_hs.append(x.clone())  # (B, T, d_model)
        return self.lm_head(self.ln_f(x)), block_hs

    @torch.no_grad()
    def forward_blocks_with_patch(
        self,
        token_ids: torch.Tensor,        # (B, T) – swapped-rationale sequence
        rationale_mask: torch.Tensor,   # (B, T) bool
        patch_layer: int,               # which block index (0-based) to patch AFTER
        cached_hs: torch.Tensor,        # (B, T, d_model) full hidden state from original run
    ) -> torch.Tensor:
        """Run the swapped-rationale sequence, but after block ``patch_layer``
        replace the hidden states at rationale positions with ``cached_hs``.

        Because the transformer is causal, this injection propagates forward into
        all subsequent blocks and into the claim-token prediction. A
        ``patch_layer`` outside ``range(n_layers)`` applies no patch.

        Args:
            token_ids: Sequence with the SWAPPED rationale tokens.
            rationale_mask: Mask of rationale positions (same layout as original);
                any non-zero entry counts as a rationale position.
            patch_layer: Index of the block after which to inject. 0 = after block 0, etc.
            cached_hs: Full (B, T, d_model) tensor from forward_blocks_with_cache at
                       patch_layer, i.e. its returned list at index patch_layer.

        Returns:
            lm_logits: (B, T, V) final logits after patching.
        """
        patch_mask = rationale_mask.bool().unsqueeze(-1)  # (B, T, 1)

        def _inject_cached(h: torch.Tensor) -> torch.Tensor:
            # Exact selection: cached (original-run) activations at rationale
            # positions, current (swapped-run) activations everywhere else.
            return torch.where(patch_mask, cached_hs, h)

        x = self._embed(token_ids)
        for layer_idx, block in enumerate(self.blocks):
            callback = _inject_cached if layer_idx == patch_layer else None
            x = block(x, patch_callback=callback)
        return self.lm_head(self.ln_f(x))  # (B, T, V)

    @torch.no_grad()
    def generate_claim(
        self,
        token_ids: torch.Tensor,  # (B, prefix_len) = prompt+rationale portion
        max_new: int = 4,
    ) -> torch.Tensor:
        """Greedily append ``max_new`` tokens and return only the new ones: (B, max_new)."""
        generated = token_ids.clone()
        for _ in range(max_new):
            last_hidden = self._encode(generated)[:, -1, :]
            next_tok = self.lm_head(last_hidden).argmax(dim=-1, keepdim=True)  # (B, 1)
            generated = torch.cat([generated, next_tok], dim=1)
        return generated[:, token_ids.shape[1]:]

# ---------------------------------------------------------------------------
# 4. Training
# ---------------------------------------------------------------------------

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and torch with the same ``seed``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


def train_one_variant(
    cfg: ExperimentConfig,
    pooling_mode: str,
    train_loader: DataLoader,
    eval_loader: DataLoader,
    shuffled_loader: DataLoader,
    vocab: dict,
) -> Tuple[ClaimConsistencyTransformer, List[dict]]:
    """Train a fresh model for one pooling mode; return it with its loss history.

    The global RNGs are reseeded with ``cfg.seed`` before the model is built,
    so every variant starts from the same initial weights. Each epoch runs
    AdamW (``lr=cfg.lr``) over ``train_loader`` with gradient-norm clipping at 1.0.

    Args:
        cfg: experiment configuration.
        pooling_mode: consistency pooling mode for ClaimConsistencyTransformer.
        train_loader: expected to yield dict batches like those built by
            ``collate_fn`` (tensors for every key read below).
        eval_loader, shuffled_loader, vocab: accepted for call-site
            compatibility; not used during training.

    Returns:
        (model, history). ``history`` has one dict per epoch with the mean
        "lm_loss" and "cons_loss" over that epoch's batches.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch (the
            per-epoch means would otherwise divide by zero).
    """
    set_seed(cfg.seed)
    model = ClaimConsistencyTransformer(cfg, pooling_mode=pooling_mode).to(cfg.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    history = []

    for epoch in range(cfg.num_epochs):
        model.train()
        total_lm, total_cons, n_batches = 0.0, 0.0, 0
        for batch in train_loader:
            batch = {k: v.to(cfg.device) for k, v in batch.items()}
            optimizer.zero_grad()
            loss, info = model(
                batch["token_ids"],
                batch["targets"],
                batch["rationale_mask"],
                batch["full_seq_mask"],
                batch["earlier_tok_mask"],
                batch["claim_label"],
                consistency_weight=cfg.consistency_loss_weight,
                claim_mask=batch.get("claim_mask"),
            )
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_lm += info["lm_loss"]
            total_cons += info["cons_loss"]
            n_batches += 1

        if n_batches == 0:
            raise ValueError(
                f"train_loader yielded no batches in epoch {epoch + 1}; nothing to train on"
            )

        history.append({
            "epoch": epoch + 1,
            "lm_loss": total_lm / n_batches,
            "cons_loss": total_cons / n_batches,
        })

    return model, history


# ---------------------------------------------------------------------------
# 5. Evaluation helpers
# ---------------------------------------------------------------------------

def _seq_prefix(cfg, sample):
    """Return the generation prompt: every token before the first claim token.

    With the layout BOS + prompt + SEP + rationale + SEP + claim, this is the
    first ``1 + prompt_len + 1 + rationale_len + 1`` tokens. It therefore
    INCLUDES the SEP that precedes the claim (15 tokens with the default config).
    """
    n_special = 3  # BOS plus the two SEP separators
    cut = cfg.prompt_len + cfg.rationale_len + n_special
    return sample["token_ids"][:cut]


def evaluate_claim_accuracy_generation(
    cfg: ExperimentConfig,
    model: ClaimConsistencyTransformer,
    dataset: ClaimConsistencyDataset,
    vocab: dict,
) -> float:
    """Fraction of samples whose first greedily generated claim token is correct.

    For each sample, the prefix up to and including the SEP before the claim is
    fed to ``model.generate_claim(max_new=cfg.claim_len)``. The first generated
    token is compared with the first claim token of the sample's ``claim_label``
    state. The model is switched to eval mode. An empty dataset scores 0.0.
    """
    model.eval()
    expected_token = {
        state: vocab[state]["claim_tokens"][0] for state in range(cfg.num_latent_states)
    }

    hits = 0
    for sample in dataset.samples:
        prefix = _seq_prefix(cfg, sample).unsqueeze(0).to(cfg.device)
        first_generated = model.generate_claim(prefix, max_new=cfg.claim_len)[0, 0].item()
        hits += first_generated == expected_token[sample["claim_label"]]

    count = len(dataset.samples)
    return hits / count if count > 0 else 0.0


def evaluate_claim_accuracy_classifier(
    cfg: ExperimentConfig,
    model: ClaimConsistencyTransformer,
    dataset: ClaimConsistencyDataset,
) -> float:
    """Fraction of samples whose rationale-pooled classifier prediction equals ``claim_label``.

    Uses ``model.pool_hidden_rationale`` on each full sequence. The model is
    switched to eval mode. An empty dataset scores 0.0.
    """
    model.eval()
    hits = 0
    for sample in dataset.samples:
        ids = sample["token_ids"].unsqueeze(0).to(cfg.device)
        mask = sample["rationale_mask"].unsqueeze(0).to(cfg.device)
        predicted = model.pool_hidden_rationale(ids, mask).argmax(dim=-1).item()
        hits += predicted == sample["claim_label"]

    count = len(dataset.samples)
    return hits / count if count > 0 else 0.0


def evaluate_counterfactual_swap(
    cfg: ExperimentConfig,
    model: ClaimConsistencyTransformer,
    vocab: dict,
    n_samples: int = 64,
    seed_offset: int = 999,
) -> dict:
    """
    Counterfactual rationale swap test.

    For each sample:
      - orig_state: latent state that determines the claim tokens
      - swap_state: a DIFFERENT latent state whose rationale is used instead
      - the sequence therefore carries orig_state's claim but swap_state's rationale
      - Measure:
          gen_follows_*: first greedily generated claim token equals the first
                         claim token of swap_state / orig_state
          cls_follows_*: rationale-pooled classifier prediction equals
                         swap_state / orig_state

    Returns:
        dict with the four rates (fractions of the evaluated samples) and
        "n_samples". All rates are 0.0 when n_samples is 0.

    Raises:
        ValueError: if cfg.num_latent_states < 2. No different state exists, and
            the swap-state rejection loop would otherwise never terminate.
    """
    if cfg.num_latent_states < 2:
        raise ValueError(
            f"counterfactual swap needs at least 2 latent states, got {cfg.num_latent_states}"
        )

    model.eval()
    rng = np.random.default_rng(cfg.seed + seed_offset)
    gen_follows_swap = 0
    cls_follows_swap = 0
    gen_follows_orig = 0
    cls_follows_orig = 0
    total = 0

    claim_first_token = {s: vocab[s]["claim_tokens"][0] for s in range(cfg.num_latent_states)}

    for _ in range(n_samples):
        orig_state = int(rng.integers(0, cfg.num_latent_states))
        # Rejection sampling (kept to preserve the RNG stream of earlier runs).
        swap_state = int(rng.integers(0, cfg.num_latent_states))
        while swap_state == orig_state:
            swap_state = int(rng.integers(0, cfg.num_latent_states))

        sample = make_sample(cfg, vocab, orig_state, rng, swap_rationale_from=swap_state)

        prefix = _seq_prefix(cfg, sample).unsqueeze(0).to(cfg.device)
        tids = sample["token_ids"].unsqueeze(0).to(cfg.device)
        rmask = sample["rationale_mask"].unsqueeze(0).to(cfg.device)

        # generation
        generated = model.generate_claim(prefix, max_new=cfg.claim_len)
        pred_gen = generated[0, 0].item()
        if pred_gen == claim_first_token[swap_state]:
            gen_follows_swap += 1
        if pred_gen == claim_first_token[orig_state]:
            gen_follows_orig += 1

        # classifier
        cls_logits = model.pool_hidden_rationale(tids, rmask)
        pred_cls = cls_logits.argmax(dim=-1).item()
        if pred_cls == swap_state:
            cls_follows_swap += 1
        if pred_cls == orig_state:
            cls_follows_orig += 1

        total += 1

    def _rate(count: int) -> float:
        # Same empty-input convention as the accuracy evaluators: 0.0, not a crash.
        return count / total if total > 0 else 0.0

    return {
        "gen_follows_swap_rate": _rate(gen_follows_swap),
        "gen_follows_orig_rate": _rate(gen_follows_orig),
        "cls_follows_swap_rate": _rate(cls_follows_swap),
        "cls_follows_orig_rate": _rate(cls_follows_orig),
        "n_samples": total,
    }


def evaluate_shuffled_pairing(
    cfg: ExperimentConfig,
    model: ClaimConsistencyTransformer,
    shuffled_dataset: ClaimConsistencyDataset,
    vocab: dict,
) -> dict:
    """
    Score a model on the shuffled-pairing control set.

    The control set pairs rationales with claims that (presumably) do not match
    them; see ``ClaimConsistencyDataset(shuffled=True)`` for the exact pairing
    rule. Both metrics reuse the standard accuracy evaluators.

    Returns
    -------
    dict
        ``shuffled_gen_acc``: generation accuracy from
        ``evaluate_claim_accuracy_generation`` on the shuffled set.
        ``shuffled_cls_acc``: classifier accuracy from
        ``evaluate_claim_accuracy_classifier`` on the shuffled set.

    NOTE(review): a model that ignores the claim label and follows the
    rationale should score near chance here (roughly 1/num_latent_states) or
    lower, depending on the shuffling scheme. Confirm against the dataset code.
    """
    # Generation is scored against the dataset's CLAIM label rather than the
    # latent state the shuffled rationale came from, so a drop relative to the
    # clean eval set indicates the model was relying on the rationale.
    generation_accuracy = evaluate_claim_accuracy_generation(
        cfg, model, shuffled_dataset, vocab=vocab
    )
    classifier_accuracy = evaluate_claim_accuracy_classifier(cfg, model, shuffled_dataset)
    return dict(
        shuffled_gen_acc=generation_accuracy,
        shuffled_cls_acc=classifier_accuracy,
    )


# ---------------------------------------------------------------------------
# 6. Hidden-state intervention / causal patching evaluation
# ---------------------------------------------------------------------------

def evaluate_hidden_state_intervention(
    cfg: Optional[ExperimentConfig] = None,
    models: Optional[Dict[str, "ClaimConsistencyTransformer"]] = None,
    n_samples: int = 64,
    seed_offset: int = 1337,
) -> pd.DataFrame:
    """
    Causal patching / hidden-state intervention evaluation.

    Protocol (per sample, per layer, per variant)
    ---------------------------------------------
    1. Build an "original" sample with latent state A and its matching rationale.
    2. Build a "swapped" sample: same latent-state A claim tokens but rationale
       tokens drawn from a *different* latent state B.
    3. Run forward_blocks_with_cache() on the original sequence to cache the
       post-block hidden states for every layer.
    4. For each transformer block i (0 .. n_layers-1):
       a. Run forward_blocks_with_patch() on the swapped-rationale sequence,
          injecting the cached original hidden states at rationale positions
          *after* block i.  All subsequent blocks see the patched states.
       b. Read logits at the SEP position that immediately precedes the claim
          (i.e. the last position of the prefix, from which the model predicts
          the first claim token).  The argmax is the "patched prediction".
       c. intervention_follows_original_hs = 1 if patched prediction equals
          the first claim token of the ORIGINAL latent state A.
       d. intervention_follows_swapped_tokens = 1 if patched prediction equals
          the first claim token of the SWAPPED latent state B.

    Claim-token identification
    --------------------------
    The prefix fed to the model is:
        [BOS] [prompt...] [SEP] [rationale...] [SEP]
    Length = 1 + prompt_len + 1 + rationale_len + 1  (= prefix_len).
    We use the logit at position prefix_len-1 (the final SEP token) as the
    prediction of the first claim token, consistent with greedy next-token
    generation used elsewhere in the codebase.

    If `models` is None, the function trains one model per entry in
    ``cfg.pooling_modes`` from scratch, using the provided cfg or the default
    ``ExperimentConfig()``.  Pre-trained models can be passed as a dict
    ``{pooling_mode: model}``.  Every model is put into ``eval()`` mode.

    Returns a DataFrame with columns:
        variant, layer, patch_layer_id,
        intervention_follows_original_hs, intervention_follows_swapped_tokens,
        n_samples
    Rates are rounded to 4 decimals, or None when no samples were evaluated.

    Raises
    ------
    ValueError
        If ``cfg.num_latent_states < 2``.  No swap state distinct from the
        original can exist, and the rejection sampling would loop forever.
    """
    if cfg is None:
        cfg = ExperimentConfig()
    # Validate before any (potentially long) training run: the loop that draws
    # swap_state != orig_state can never terminate with fewer than two states.
    if cfg.num_latent_states < 2:
        raise ValueError(
            "evaluate_hidden_state_intervention needs cfg.num_latent_states >= 2 "
            f"to draw a distinct swap state; got {cfg.num_latent_states}"
        )

    set_seed(cfg.seed)
    vocab = _build_vocabulary(cfg)

    # ----- train / reuse models -----
    if models is None:
        models = _hsi_train_variants(cfg, vocab)

    # ----- claim token map -----
    claim_first_token = {
        s: vocab[s]["claim_tokens"][0] for s in range(cfg.num_latent_states)
    }

    # Layout: BOS + prompt + SEP + rationale + SEP  (positions 0..prefix_len-1)
    prefix_len = 1 + cfg.prompt_len + 1 + cfg.rationale_len + 1
    # The logit at position (prefix_len - 1) predicts the first claim token.
    claim_pred_pos = prefix_len - 1

    rows = []
    for pooling_mode, model in models.items():
        model.eval()
        print(f"\n[hidden_state_intervention] Evaluating variant={pooling_mode} ...")

        follows_orig, follows_swapped, total_counted = _hsi_count_patch_outcomes(
            cfg, model, vocab, claim_first_token, claim_pred_pos, n_samples, seed_offset
        )

        for layer_idx in range(cfg.n_layers):
            n = total_counted[layer_idx]
            rows.append({
                "variant": pooling_mode,
                "layer": layer_idx,
                "patch_layer_id": f"block_{layer_idx}",
                "intervention_follows_original_hs": round(follows_orig[layer_idx] / n, 4) if n else None,
                "intervention_follows_swapped_tokens": round(follows_swapped[layer_idx] / n, 4) if n else None,
                "n_samples": n,
            })

    return pd.DataFrame(rows)


def _hsi_train_variants(
    cfg: ExperimentConfig, vocab: dict
) -> Dict[str, "ClaimConsistencyTransformer"]:
    """Train one model per ``cfg.pooling_modes`` entry for the intervention eval."""
    print("[hidden_state_intervention] Training all 4 variants (no saved models provided) ...")
    # Datasets/loaders are built in the same order and with the same seed
    # offsets as run_experiment, so the global RNG is consumed identically.
    train_ds = ClaimConsistencyDataset(
        cfg, vocab, cfg.num_train_samples, shuffled=False, seed_offset=0
    )
    eval_ds = ClaimConsistencyDataset(
        cfg, vocab, cfg.num_eval_samples, shuffled=False, seed_offset=100
    )
    shuf_ds = ClaimConsistencyDataset(
        cfg, vocab, cfg.num_shuffled_samples, shuffled=True, seed_offset=200
    )
    train_loader = DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn
    )
    eval_loader = DataLoader(
        eval_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn
    )
    shuf_loader = DataLoader(
        shuf_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn
    )

    trained: Dict[str, "ClaimConsistencyTransformer"] = {}
    for pooling_mode in cfg.pooling_modes:
        print(f"  Training variant={pooling_mode} ...")
        model, _ = train_one_variant(
            cfg, pooling_mode, train_loader, eval_loader, shuf_loader, vocab
        )
        trained[pooling_mode] = model
    return trained


def _hsi_count_patch_outcomes(
    cfg: ExperimentConfig,
    model: "ClaimConsistencyTransformer",
    vocab: dict,
    claim_first_token: Dict[int, int],
    claim_pred_pos: int,
    n_samples: int,
    seed_offset: int,
) -> Tuple[List[int], List[int], List[int]]:
    """Run the per-layer patching protocol on one model; return raw per-layer counts.

    Returns ``(follows_orig, follows_swapped, total_counted)``, each a list of
    length ``cfg.n_layers``.
    """
    follows_orig = [0] * cfg.n_layers
    follows_swapped = [0] * cfg.n_layers
    total_counted = [0] * cfg.n_layers

    # A fresh generator with the same seed on every call, so each variant is
    # evaluated on the same sampled (orig_state, swap_state, sample) sequence.
    rng = np.random.default_rng(cfg.seed + seed_offset)

    # Evaluation only: 1 caching pass + n_layers patched passes per sample do
    # not need autograd bookkeeping.
    with torch.no_grad():
        for _ in range(n_samples):
            # Sample original and swap latent states (distinct; cfg was validated
            # to have at least two states, so this loop terminates).
            orig_state = int(rng.integers(0, cfg.num_latent_states))
            swap_state = int(rng.integers(0, cfg.num_latent_states))
            while swap_state == orig_state:
                swap_state = int(rng.integers(0, cfg.num_latent_states))

            # Original sample: rationale matches orig_state.
            orig_sample = make_sample(cfg, vocab, orig_state, rng)
            # Swapped sample: claim label = orig_state, rationale from swap_state.
            swap_sample = make_sample(
                cfg, vocab, orig_state, rng, swap_rationale_from=swap_state
            )

            # Full sequences (all tokens are needed for the caching pass).
            orig_ids = orig_sample["token_ids"].unsqueeze(0).to(cfg.device)
            swap_ids = swap_sample["token_ids"].unsqueeze(0).to(cfg.device)
            rat_mask = orig_sample["rationale_mask"].unsqueeze(0).to(cfg.device)

            # ---- Step 3: cache original hidden states at every block ----
            _orig_logits, cached_hs_list = model.forward_blocks_with_cache(
                orig_ids, rat_mask
            )  # cached_hs_list: one entry per block

            # ---- Step 4: patch each layer individually ----
            for layer_idx in range(cfg.n_layers):
                patched_logits = model.forward_blocks_with_patch(
                    swap_ids,
                    rat_mask,
                    patch_layer=layer_idx,
                    cached_hs=cached_hs_list[layer_idx],
                )

                # Logit at the claim prediction position (last SEP).
                pred_tok = patched_logits[0, claim_pred_pos, :].argmax().item()

                if pred_tok == claim_first_token[orig_state]:
                    follows_orig[layer_idx] += 1
                if pred_tok == claim_first_token[swap_state]:
                    follows_swapped[layer_idx] += 1
                total_counted[layer_idx] += 1

    return follows_orig, follows_swapped, total_counted


# ---------------------------------------------------------------------------
# 7. Full experiment runner
# ---------------------------------------------------------------------------

def run_experiment(
    cfg: Optional[ExperimentConfig] = None,
    *,
    n_cfact_samples: int = 64,
) -> pd.DataFrame:
    """
    Train every variant in ``cfg.pooling_modes`` and tabulate its evaluation metrics.

    For each variant this reports final training losses, generation and
    classifier claim accuracy on the clean eval set, counterfactual-swap
    follow rates, and shuffled-pairing accuracies.  The table is also written
    to ``cfg.results_path`` as CSV.

    Parameters
    ----------
    cfg:
        Experiment configuration; ``ExperimentConfig()`` when None.
    n_cfact_samples:
        Number of samples drawn by ``evaluate_counterfactual_swap``
        (default 64, matching the previous hard-coded value).

    Returns
    -------
    pd.DataFrame
        One row per variant.  Loss columns are NaN when training returned an
        empty history (e.g. ``num_epochs=0``).
    """
    if cfg is None:
        cfg = ExperimentConfig()

    set_seed(cfg.seed)
    print(f"[INFO] Building vocabulary for {cfg.num_latent_states} latent states ...")
    vocab = _build_vocabulary(cfg)

    print("[INFO] Building datasets ...")
    train_ds = ClaimConsistencyDataset(cfg, vocab, cfg.num_train_samples, shuffled=False, seed_offset=0)
    eval_ds  = ClaimConsistencyDataset(cfg, vocab, cfg.num_eval_samples,  shuffled=False, seed_offset=100)
    shuf_ds  = ClaimConsistencyDataset(cfg, vocab, cfg.num_shuffled_samples, shuffled=True, seed_offset=200)

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    eval_loader  = DataLoader(eval_ds,  batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)
    shuf_loader  = DataLoader(shuf_ds,  batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)

    results = []

    for pooling_mode in cfg.pooling_modes:
        print(f"\n{'='*60}")
        print(f"[TRAIN] variant={pooling_mode}")
        model, history = train_one_variant(
            cfg, pooling_mode, train_loader, eval_loader, shuf_loader, vocab
        )
        # Guard against an empty history (e.g. num_epochs == 0): report NaN
        # losses instead of crashing with IndexError on history[-1].
        if history:
            final_lm = history[-1]["lm_loss"]
            final_cons = history[-1]["cons_loss"]
        else:
            final_lm = final_cons = math.nan
        print(f"  Final LM loss={final_lm:.4f}  Cons loss={final_cons:.4f}")

        print(f"[EVAL]  variant={pooling_mode}")
        gen_acc = evaluate_claim_accuracy_generation(cfg, model, eval_ds, vocab)
        cls_acc = evaluate_claim_accuracy_classifier(cfg, model, eval_ds)
        cfact   = evaluate_counterfactual_swap(cfg, model, vocab, n_samples=n_cfact_samples)
        shuf    = evaluate_shuffled_pairing(cfg, model, shuf_ds, vocab)

        row = {
            "variant": pooling_mode,
            "final_lm_loss": round(final_lm, 4),
            "final_cons_loss": round(final_cons, 4),
            "gen_claim_acc": round(gen_acc, 4),
            "cls_claim_acc (rationale_pool)": round(cls_acc, 4),
            "cfact_gen_follows_swap": round(cfact["gen_follows_swap_rate"], 4),
            "cfact_gen_follows_orig": round(cfact["gen_follows_orig_rate"], 4),
            "cfact_cls_follows_swap": round(cfact["cls_follows_swap_rate"], 4),
            "cfact_cls_follows_orig": round(cfact["cls_follows_orig_rate"], 4),
            "shuffled_gen_acc": round(shuf["shuffled_gen_acc"], 4),
            "shuffled_cls_acc": round(shuf["shuffled_cls_acc"], 4),
        }
        results.append(row)
        print("  " + str({k: v for k, v in row.items() if k != "variant"}))

    df = pd.DataFrame(results)
    df.to_csv(cfg.results_path, index=False)
    print(f"\n[DONE] Results saved to {cfg.results_path}")
    return df


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    def _show_table(table: pd.DataFrame) -> None:
        """Print *table* as markdown, or as plain text if `tabulate` is unavailable."""
        try:
            rendered = table.to_markdown(index=False)
        except ImportError:
            rendered = table.to_string(index=False)
        print(rendered)

    # Deliberately tiny configuration: a quick end-to-end smoke test rather
    # than a real experiment.
    cfg = ExperimentConfig(
        num_latent_states=4,
        num_rationale_templates=3,
        num_train_samples=256,
        num_eval_samples=64,
        num_shuffled_samples=64,
        num_epochs=3,
        batch_size=32,
        d_model=32,
        n_heads=2,
        n_layers=2,
        d_ff=64,
        results_path="results_comparison.csv",
    )

    df = run_experiment(cfg)
    print("\n=== RESULTS TABLE ===")
    _show_table(df)

    print("\n=== HIDDEN-STATE INTERVENTION EVAL (smoke test) ===")
    # No models are passed, so this call trains its own set of variants.
    df_hs = evaluate_hidden_state_intervention(cfg=cfg, n_samples=32)
    _show_table(df_hs)
