"""
model.py — GPT-2-style causal Transformer for the consistency loss experiment.

Architecture:
  • Causal (autoregressive) self-attention with causal masking.
  • Explanation tokens cannot attend to claim tokens (enforced via attention bias).
  • Output head: standard LM head.
  • Consistency head: linear classifier over pooled explanation hidden states.

Full-size GPT-2 config is documented; a tiny config is used for smoke runs.

Stronger ablation extensions (added for the v2 ablation ladder):
  • build_no_claim_to_claim_mask: blocks claim queries from attending previous claim keys
  • build_claims_from_explanation_only_mask: blocks claim queries from attending code AND
    other claim keys — claims attend only to explanation tokens (+ current pos under causal)
  • SurfaceBottleneckHeads: linear classifiers over mean-pooled *softmax* distributions at
    explanation target positions; gradients flow into LM logits (no detach)
  • compute_lm_loss_masked: LM loss with optional per-token mask to disable loss on
    mismatched explanation positions while preserving claim-token LM loss
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple


# ──────────────────────────────────────────────────────────────────────────────
# 1. Config
# ──────────────────────────────────────────────────────────────────────────────

@dataclass
class TransformerConfig:
    """
    Hyperparameters for the consistency-experiment Transformer.

    The field defaults describe a small model; the classmethods below return
    the named presets.  Field order is part of the positional constructor
    signature, so it must not be rearranged.
    """
    # Vocabulary / sequence geometry
    vocab_size:      int   = 1000     # placeholder; set after tokenizer is built
    max_seq_len:     int   = 256      # sizes the positional table and causal mask
    # Transformer width / depth
    d_model:         int   = 128      # hidden dim
    n_heads:         int   = 4        # must divide d_model evenly
    n_layers:        int   = 4
    d_ff:            int   = 512      # feedforward dim
    dropout:         float = 0.1
    # Output sizes of the three claim classifiers
    n_time_classes:  int   = 3
    n_space_classes: int   = 3
    n_correct_classes: int = 2
    # Token id treated as padding (embedding padding_idx / LM-loss ignore_index)
    pad_id:          int   = 0

    # ── Named presets ─────────────────────────────────────────────────────────

    @classmethod
    def smoke(cls, vocab_size: int, pad_id: int = 0) -> "TransformerConfig":
        """
        Return the tiny preset used for fast smoke tests.

        The original notes describe it as "~1M params"; the real count depends
        on ``vocab_size``.
        """
        dims = dict(max_seq_len=128, d_model=64, n_heads=2, n_layers=2, d_ff=128)
        return cls(vocab_size=vocab_size, pad_id=pad_id, **dims)

    @classmethod
    def small(cls, vocab_size: int, pad_id: int = 0) -> "TransformerConfig":
        """
        Return the small preset intended for CPU training.

        The original notes describe it as "~10M params"; the real count depends
        on ``vocab_size``.
        """
        dims = dict(max_seq_len=256, d_model=256, n_heads=4, n_layers=4, d_ff=1024)
        return cls(vocab_size=vocab_size, pad_id=pad_id, **dims)

    @classmethod
    def gpt2_small(cls, vocab_size: int, pad_id: int = 0) -> "TransformerConfig":
        """
        Return a preset with GPT-2-small dimensions (12 layers, 12 heads,
        d_model=768, d_ff=3072, context 1024).

        The original notes quote ~117M params and recommend a GPU with at
        least 8GB of VRAM.  Only the dimensions are GPT-2-like: the model in
        this file is a from-scratch GPT-2-style Transformer and no pretrained
        GPT-2 weights are loaded (that would need the `transformers` library).
        """
        dims = dict(max_seq_len=1024, d_model=768, n_heads=12, n_layers=12, d_ff=3072)
        return cls(vocab_size=vocab_size, pad_id=pad_id, **dims)


# ──────────────────────────────────────────────────────────────────────────────
# 2. Building blocks
# ──────────────────────────────────────────────────────────────────────────────

class MultiHeadCausalAttention(nn.Module):
    """
    Multi-head causal self-attention.

    Supports an optional additive attention bias (``extra_mask``) for
    enforcing structural constraints such as "explanation tokens cannot
    attend to claim tokens".

    A lower-triangular boolean mask of size (max_seq_len, max_seq_len) is
    registered as the ``causal_mask`` buffer.  Inputs longer than
    ``max_seq_len`` are handled by building a temporary mask of the required
    size; the registered buffer itself is never resized, so the state_dict
    layout does not depend on the inputs seen.
    """

    def __init__(self, cfg: "TransformerConfig"):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = cfg.n_heads
        self.d_head  = cfg.d_model // cfg.n_heads
        # Scores are divided by sqrt(d_head) before the softmax.
        self.scale   = math.sqrt(self.d_head)

        # Fused projection: one matmul produces q, k and v side by side.
        self.qkv  = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.out  = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.drop = nn.Dropout(cfg.dropout)

        # Pre-computed causal mask for sequences up to max_seq_len.
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len)).bool()
        )

    def _causal_mask_for(self, T: int, device: torch.device) -> torch.Tensor:
        """Return a (T, T) bool mask that is True where query i may see key j <= i."""
        if T <= self.causal_mask.size(0):
            return self.causal_mask[:T, :T]
        # Longer than the pre-computed buffer: build a one-off mask rather than
        # failing with a shape mismatch in masked_fill.
        return torch.tril(torch.ones(T, T, device=device)).bool()

    def forward(
        self,
        x: torch.Tensor,                       # (B, T, d_model)
        extra_mask: Optional[torch.Tensor] = None  # (B, T, T) additive float mask
    ) -> torch.Tensor:
        """
        Apply causal self-attention.

        Args:
            x: input of shape (B, T, d_model).
            extra_mask: optional additive bias of shape (B, T, T), added to the
                attention scores after causal masking and broadcast over heads.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, D = x.shape
        q, k, v = (
            t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
            for t in self.qkv(x).split(D, dim=-1)
        )  # each (B, H, T, d_head)

        scores = (q @ k.transpose(-2, -1)) / self.scale  # (B, H, T, T)

        # Causal mask: token i cannot attend to j > i
        causal = self._causal_mask_for(T, x.device)      # (T, T)
        scores = scores.masked_fill(~causal.unsqueeze(0).unsqueeze(0), float('-inf'))

        # Extra mask (e.g. block explanation→claim attention)
        if extra_mask is not None:
            scores = scores + extra_mask.unsqueeze(1)  # broadcast over heads

        attn = F.softmax(scores, dim=-1)
        attn = self.drop(attn)
        # Merge heads back: (B, H, T, d_head) -> (B, T, D)
        out  = (attn @ v).transpose(1, 2).contiguous().view(B, T, D)
        return self.out(out)


class FeedForward(nn.Module):
    """
    Position-wise MLP: Linear(d_model → d_ff) → GELU → Dropout →
    Linear(d_ff → d_model) → Dropout.
    """

    def __init__(self, cfg: TransformerConfig):
        super().__init__()
        # Layer positions inside the Sequential determine the state_dict keys
        # (net.0.*, net.3.*), so the order here is fixed.
        layers = [
            nn.Linear(cfg.d_model, cfg.d_ff),
            nn.GELU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.d_ff, cfg.d_model),
            nn.Dropout(cfg.dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP on the last dimension of ``x``; the output has the same shape as ``x``."""
        return self.net(x)


class TransformerBlock(nn.Module):
    """
    Pre-LayerNorm Transformer block: each sub-layer (attention, then
    feed-forward) is applied to a layer-normed copy of the residual stream
    and its output is added back onto the stream.
    """

    def __init__(self, cfg: TransformerConfig):
        super().__init__()
        # Attribute names and creation order define the state_dict layout.
        self.ln1  = nn.LayerNorm(cfg.d_model)
        self.attn = MultiHeadCausalAttention(cfg)
        self.ln2  = nn.LayerNorm(cfg.d_model)
        self.ff   = FeedForward(cfg)

    def forward(self, x, extra_mask=None):
        """
        Args:
            x: residual stream, (B, T, d_model).
            extra_mask: optional additive attention bias, forwarded unchanged
                to the attention sub-layer.

        Returns:
            Updated residual stream with the same shape as ``x``.
        """
        # Attention sub-layer with residual connection.
        attended = self.attn(self.ln1(x), extra_mask=extra_mask)
        x = x + attended
        # Feed-forward sub-layer with residual connection.
        transformed = self.ff(self.ln2(x))
        return x + transformed



# ──────────────────────────────────────────────────────────────────────────────
# 2b. Surface Bottleneck Heads
# ──────────────────────────────────────────────────────────────────────────────

class SurfaceBottleneckHeads(nn.Module):
    """
    Consistency classifiers over *surface-level* LM output distributions.

    Instead of reading hidden states, these heads read the softmax of the LM
    logits at the positions selected by ``explanation_mask``:

      1. softmax over the vocabulary at every position      (B, T, V)
      2. masked mean over the selected positions             (B, V)
      3. one linear projection per claim type                (B, n_classes)

    Nothing is detached, so the consistency loss back-propagates through the
    softmax into the LM logits.

    NOTE: the mask is applied to logit positions exactly as given; this module
    does not shift it.  If "explanation target positions" is meant in the
    next-token sense, the caller has to supply an appropriately shifted mask.
    """

    def __init__(self, vocab_size: int, n_time: int, n_space: int, n_correct: int):
        super().__init__()
        self.time_head    = nn.Linear(vocab_size, n_time)
        self.space_head   = nn.Linear(vocab_size, n_space)
        self.correct_head = nn.Linear(vocab_size, n_correct)

    def forward(
        self,
        lm_logits: torch.Tensor,        # (B, T, V) — full sequence logits
        explanation_mask: torch.Tensor, # (B, T) bool — True at explanation positions
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Classify each sequence from its pooled explanation-position
        distributions.

        Returns:
            (time_logits, space_logits, correct_logits), each (B, n_classes).
            A row whose mask selects no position pools to all zeros (the
            divisor is clamped to 1), so its logits equal the head biases.
        """
        n_batch, n_pos, n_vocab = lm_logits.shape  # unpacking enforces a rank-3 input

        # Per-position vocabulary distributions (differentiable, no detach).
        token_probs = lm_logits.softmax(dim=-1)                  # (B, T, V)

        # Masked mean over positions: sum of selected rows / number selected.
        weights = explanation_mask.float().unsqueeze(-1)         # (B, T, 1)
        summed  = (token_probs * weights).sum(dim=1)             # (B, V)
        counts  = weights.sum(dim=1).clamp(min=1.0)              # (B, 1)
        pooled  = summed / counts                                # (B, V)

        heads = (self.time_head, self.space_head, self.correct_head)
        return tuple(head(pooled) for head in heads)


# ──────────────────────────────────────────────────────────────────────────────
# 3. Main model
# ──────────────────────────────────────────────────────────────────────────────

class ConsistencyTransformer(nn.Module):
    """
    GPT-2-style causal Transformer with:
      (a) standard LM head for next-token prediction (weights tied to the
          token embedding)
      (b) consistency head: linear classifiers over pooled explanation
          hidden states from the last layer
      (c) optional surface_heads (SurfaceBottleneckHeads): classifiers over
          mean-pooled softmax distributions at explanation positions

    Causal masking semantics:
      Explanation tokens come BEFORE claim tokens in the sequence.
      Standard causal attention already ensures explanation tokens
      cannot attend to (future) claim tokens.
      The `extra_attn_mask` argument of forward() provides additional
      structural enforcement.

    Original variants:
      consistency_loss         : pool explanation token hiddens → classify
      no_consistency_loss      : same architecture, consistency loss weight = 0
      claim_only_pooling       : pool claim token hiddens instead of explanation
      random_label_consistency : pool explanation hiddens but use shuffled labels

    New stronger ablation variants (v2):
      no_claim_to_claim_attention     : blocks claim→claim attention in mask
      claims_from_explanation_only    : claims attend only explanation tokens
      surface_bottleneck_consistency  : consistency via surface LM-logit distributions
      surface_bottleneck_no_expl_lm   : surface bottleneck + no LM loss on expl tokens
    """

    def __init__(self, cfg: "TransformerConfig", use_surface_heads: bool = False):
        super().__init__()
        self.cfg = cfg
        self.use_surface_heads = use_surface_heads

        # Embedding layers
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model, padding_idx=cfg.pad_id)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.emb_drop = nn.Dropout(cfg.dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_f   = nn.LayerNorm(cfg.d_model)

        # LM head
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        # Hidden-state consistency heads (one per claim type)
        self.time_head    = nn.Linear(cfg.d_model, cfg.n_time_classes)
        self.space_head   = nn.Linear(cfg.d_model, cfg.n_space_classes)
        self.correct_head = nn.Linear(cfg.d_model, cfg.n_correct_classes)

        # Surface-bottleneck heads: operate on LM logit distributions, not hidden
        # states.  Always instantiated, so the set of parameters is the same for
        # both settings of use_surface_heads (original note: keeps the model
        # pickle-able); only used in forward() when use_surface_heads=True.
        self.surface_heads = SurfaceBottleneckHeads(
            vocab_size=cfg.vocab_size,
            n_time=cfg.n_time_classes,
            n_space=cfg.n_space_classes,
            n_correct=cfg.n_correct_classes,
        )

        # Weight tying: lm_head and tok_emb share one parameter tensor.
        self.lm_head.weight = self.tok_emb.weight

        self._init_weights()

    def _init_weights(self):
        """
        Initialise Linear/Embedding weights from N(0, 0.02²) and zero all
        Linear biases, then zero the padding row of every Embedding.

        The padding rows are zeroed in a separate final pass on purpose:
        lm_head shares its weight tensor with tok_emb, so when the loop reaches
        lm_head it re-randomises that tensor, including the padding row.
        Zeroing inside the first loop would therefore be undone.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

        # Final pass — consumes no random numbers, so every other weight gets
        # exactly the same values as it would without this pass.
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding) and m.padding_idx is not None:
                    m.weight[m.padding_idx].zero_()

    def forward(
        self,
        input_ids: torch.Tensor,           # (B, T)
        explanation_mask: Optional[torch.Tensor] = None,  # (B, T) bool: True = expl token
        claim_mask: Optional[torch.Tensor] = None,         # (B, T) bool: True = claim token
        pool_claims: bool = False,         # if True, pool claim tokens (negative control)
        extra_attn_mask: Optional[torch.Tensor] = None,   # (B, T, T) additive float bias
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run the Transformer and produce LM logits plus the three consistency
        logits.

        Args:
            input_ids: token ids, (B, T).  T indexes the positional embedding
                table, whose size is cfg.max_seq_len.
            explanation_mask: True at explanation tokens.  Selects the pooled
                positions (both paths).  If None, all positions are pooled.
            claim_mask: True at claim tokens.  Only read when pool_claims=True
                on the hidden-state path.
            pool_claims: pool claim-token hidden states instead of explanation
                ones.  If claim_mask is None this silently falls back to
                explanation_mask (or all tokens).
            extra_attn_mask: additive attention bias passed to every block.

        Returns:
            lm_logits   : (B, T, V)
            time_logits : (B, n_time_classes)
            space_logits: (B, n_space_classes)
            correct_logits: (B, n_correct_classes)

        When use_surface_heads=True (set at construction time), the consistency
        logits are produced by SurfaceBottleneckHeads from lm_logits, not hidden
        states, and pool_claims / claim_mask are ignored. Gradients flow through
        the softmax into lm_logits.
        """
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)  # (1, T)

        x = self.emb_drop(self.tok_emb(input_ids) + self.pos_emb(pos))

        for block in self.blocks:
            x = block(x, extra_mask=extra_attn_mask)

        x = self.ln_f(x)

        # LM logits
        lm_logits = self.lm_head(x)

        # ── Surface-bottleneck path ───────────────────────────────────────────
        if self.use_surface_heads:
            # Use LM output distributions (softmax of logits) at explanation
            # positions to derive consistency logits.  No hidden-state pooling.
            expl_mask_for_surface = (
                explanation_mask
                if explanation_mask is not None
                else torch.ones(B, T, dtype=torch.bool, device=x.device)
            )
            time_logits, space_logits, correct_logits = self.surface_heads(
                lm_logits, expl_mask_for_surface
            )
            return lm_logits, time_logits, space_logits, correct_logits

        # ── Hidden-state path (original + attention-mask variants) ────────────
        # Choose which positions to pool for consistency classification
        if pool_claims and claim_mask is not None:
            pool_mask = claim_mask
        elif explanation_mask is not None:
            pool_mask = explanation_mask
        else:
            # Fallback: pool all tokens
            pool_mask = torch.ones(B, T, dtype=torch.bool, device=x.device)

        # Mean pooling over selected tokens; the clamp avoids a division by
        # zero for rows whose mask selects nothing (those pool to zeros).
        pool_mask_f = pool_mask.float().unsqueeze(-1)  # (B, T, 1)
        denom = pool_mask_f.sum(dim=1).clamp(min=1.0)  # (B, 1)
        pooled = (x * pool_mask_f).sum(dim=1) / denom  # (B, d_model)

        time_logits    = self.time_head(pooled)
        space_logits   = self.space_head(pooled)
        correct_logits = self.correct_head(pooled)

        return lm_logits, time_logits, space_logits, correct_logits

    def num_parameters(self) -> int:
        """Return the number of trainable parameters (tied weights counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# ──────────────────────────────────────────────────────────────────────────────
# 4. Loss computation
# ──────────────────────────────────────────────────────────────────────────────

def compute_lm_loss(
    lm_logits: torch.Tensor,   # (B, T, V)
    input_ids: torch.Tensor,   # (B, T)
    pad_id: int,
) -> torch.Tensor:
    """
    Standard next-token prediction loss on the full sequence.

    Targets are shifted: the logits at position i are scored against
    token[i+1].  Targets equal to ``pad_id`` are ignored.

    The result is the mean cross-entropy over non-pad targets, computed as
    (sum of per-token losses) / max(number of non-pad targets, 1).  A batch
    whose targets are all padding therefore yields a differentiable 0 instead
    of the NaN that a plain mean over zero elements would produce.
    """
    vocab = lm_logits.size(-1)
    # Shift: logits at positions 0..T-2 predict tokens at positions 1..T-1
    logits  = lm_logits[:, :-1, :].reshape(-1, vocab)
    targets = input_ids[:, 1:].reshape(-1)

    n_valid  = (targets != pad_id).sum().clamp(min=1)
    loss_sum = F.cross_entropy(logits, targets, ignore_index=pad_id, reduction="sum")
    return loss_sum / n_valid


def compute_lm_loss_masked(
    lm_logits: torch.Tensor,          # (B, T, V)
    input_ids: torch.Tensor,          # (B, T)
    pad_id: int,
    lm_loss_mask: Optional[torch.Tensor] = None,  # (B, T) bool: True = compute loss here
) -> torch.Tensor:
    """
    Next-token LM loss with an optional per-position mask.

    When lm_loss_mask is provided, only positions where lm_loss_mask is True
    contribute to the loss.  This enables ablations like:
      - Disable LM loss on mismatched explanation tokens
        (set lm_loss_mask = ~explanation_mask, keeping code + claim positions)
      - Full sequence loss (lm_loss_mask=None)

    Shifting: position t in lm_logits predicts position t+1 in input_ids.
    The mask is indexed by *target* position, so lm_loss_mask[:, 1:] is what
    lines up with the shifted targets.  Pad targets are always excluded.

    The loss is (sum of per-token cross-entropy over active targets) /
    max(number of active targets, 1).  If no target is active — with or
    without a mask — the result is a differentiable 0 rather than NaN.
    Non-bool masks are accepted and interpreted by truthiness.
    """
    vocab   = lm_logits.size(-1)
    logits  = lm_logits[:, :-1, :].reshape(-1, vocab)   # (B*(T-1), V)
    targets = input_ids[:, 1:].reshape(-1)              # (B*(T-1),)

    # A target is active when it is not padding and, if a mask is given,
    # the mask selects its position.
    active = targets != pad_id                           # (B*(T-1),)
    if lm_loss_mask is not None:
        active = active & lm_loss_mask[:, 1:].reshape(-1).bool()

    # Relabel inactive targets as pad_id so ignore_index drops them.  This
    # avoids boolean indexing and a data-dependent Python branch.
    routed   = targets.masked_fill(~active, pad_id)
    loss_sum = F.cross_entropy(logits, routed, ignore_index=pad_id, reduction="sum")
    return loss_sum / active.sum().clamp(min=1)


def compute_consistency_loss(
    time_logits:    torch.Tensor,  # (B, 3)
    space_logits:   torch.Tensor,  # (B, 3)
    correct_logits: torch.Tensor,  # (B, 2)
    time_labels:    torch.Tensor,  # (B,)
    space_labels:   torch.Tensor,  # (B,)
    correct_labels: torch.Tensor,  # (B,)
) -> torch.Tensor:
    """
    Consistency loss: the unweighted mean of the three per-claim-type
    cross-entropy losses (time, space, correctness).
    """
    heads = (
        (time_logits,    time_labels),
        (space_logits,   space_labels),
        (correct_logits, correct_labels),
    )
    per_head = [F.cross_entropy(logits, labels) for logits, labels in heads]
    time_loss, space_loss, correct_loss = per_head
    return (time_loss + space_loss + correct_loss) / 3.0


def total_loss(
    lm_loss:           torch.Tensor,
    consistency_loss:  torch.Tensor,
    lambda_consistency: float = 1.0,
) -> torch.Tensor:
    """Combine the two objectives: lm_loss + lambda_consistency * consistency_loss."""
    weighted_consistency = consistency_loss * lambda_consistency
    return lm_loss + weighted_consistency


# ──────────────────────────────────────────────────────────────────────────────
# 5. Build causal extra-mask for explanation→claim blocking
# ──────────────────────────────────────────────────────────────────────────────

def build_explanation_claim_mask(
    explanation_mask: torch.Tensor,  # (B, T) bool
    claim_mask: torch.Tensor,        # (B, T) bool
) -> torch.Tensor:
    """
    Build the additive attention bias that stops explanation queries from
    attending to claim keys.

    Returns a float tensor of shape (B, T, T) whose entry [b, i, j] is -1e9
    (a large negative stand-in for -inf) when token i is an explanation token
    and token j is a claim token, and zero otherwise.

    When claims follow explanations in the sequence, causal masking already
    blocks these pairs; this bias states the constraint explicitly.
    """
    batch, seq_len = explanation_mask.shape
    query_is_expl = explanation_mask.float().reshape(batch, seq_len, 1)  # (B, T, 1)
    key_is_claim  = claim_mask.float().unsqueeze(1)                      # (B, 1, T)
    # Outer product per batch element: 1 exactly on (expl query, claim key) pairs.
    blocked = query_is_expl * key_is_claim                               # (B, T, T)
    return blocked * (-1e9)


def build_no_claim_to_claim_mask(
    claim_mask: torch.Tensor,  # (B, T) bool
) -> torch.Tensor:
    """
    Build the additive attention bias that stops claim queries from attending
    to *other* claim keys.

    Returns a float tensor of shape (B, T, T) whose entry [b, i, j] is -1e9
    when i and j are both claim positions and i != j, and zero otherwise.
    The diagonal is left open so a claim token can still attend to itself;
    pairs with j > i are blocked here too, which is redundant with causal
    masking.

    Effect under causal attention: claim tokens see code + explanation tokens
    and themselves, but no earlier claim tokens, which ablates cross-claim
    information flow.
    """
    _, seq_len = claim_mask.shape
    is_claim = claim_mask.float()
    # 1 where both query and key are claim positions.
    both_claim = is_claim.unsqueeze(2) * is_claim.unsqueeze(1)          # (B, T, T)
    # 1 everywhere except on the diagonal (self-attention stays allowed).
    positions = torch.arange(seq_len, device=claim_mask.device)
    off_diagonal = (positions.unsqueeze(0) != positions.unsqueeze(1)).float().unsqueeze(0)  # (1, T, T)
    return (both_claim * off_diagonal) * (-1e9)


def build_claims_from_explanation_only_mask(
    explanation_mask: torch.Tensor,  # (B, T) bool
    claim_mask: torch.Tensor,        # (B, T) bool
) -> torch.Tensor:
    """
    Strict flow mask: a claim query may attend only to explanation keys and to
    its own position.

    Returns an additive float bias of shape (B, T, T) that is -1e9 on every
    blocked (query, key) pair and zero elsewhere.  Three kinds of pair are
    blocked:

      1. claim query → "code" key, where "code" means every token that is
         neither explanation nor claim (BOS, code, SEP, PAD all fall here)
      2. claim query → a different claim key (the diagonal stays open)
      3. explanation query → claim key (the base explanation→claim block)

    Together with the causal mask this means a claim token at position i
      - can attend to explanation tokens j < i and to itself,
      - cannot attend to code / BOS / SEP tokens,
      - cannot attend to other claim tokens,
    while future positions are handled by the causal mask.
    """
    _, seq_len = explanation_mask.shape
    device = explanation_mask.device

    is_claim = claim_mask.float()
    is_expl  = explanation_mask.float()
    # Everything outside both spans counts as "code" (De Morgan of ~expl & ~claim).
    is_code  = (~(explanation_mask | claim_mask)).float()

    claim_query = is_claim.unsqueeze(2)   # (B, T, 1)
    claim_key   = is_claim.unsqueeze(1)   # (B, 1, T)

    # 1 off the diagonal, 0 on it — keeps self-attention open for claim queries.
    not_self = (~torch.eye(seq_len, device=device).bool()).float().unsqueeze(0)  # (1, T, T)

    claim_to_code  = claim_query * is_code.unsqueeze(1)
    claim_to_claim = (claim_query * claim_key) * not_self
    expl_to_claim  = is_expl.unsqueeze(2) * claim_key

    # Union of the three block sets, saturated at 1 after each merge.
    blocked = (claim_to_code + claim_to_claim).clamp(max=1.0)
    blocked = (blocked + expl_to_claim).clamp(max=1.0)
    return blocked * (-1e9)


if __name__ == "__main__":
    # Smoke self-test: builds the preset models, runs one forward pass per
    # variant and verifies output shapes / finiteness before reporting success.

    def _expect(condition, message: str) -> None:
        """Raise AssertionError(message) unless condition holds (not stripped by -O)."""
        if not condition:
            raise AssertionError(message)

    batch_size, seq_len = 4, 32

    cfg = TransformerConfig.smoke(vocab_size=500, pad_id=0)
    model = ConsistencyTransformer(cfg)
    print(f"Smoke model params: {model.num_parameters():,}")

    cfg2 = TransformerConfig.small(vocab_size=5000, pad_id=0)
    model2 = ConsistencyTransformer(cfg2)
    print(f"Small model params: {model2.num_parameters():,}")

    # Mock batch: ids drawn from 1..499 so no token equals pad_id; explanation
    # span at positions 5..19, claim span at positions 20..27.
    ids = torch.randint(1, 500, (batch_size, seq_len))
    expl_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    expl_mask[:, 5:20] = True
    claim_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    claim_mask[:, 20:28] = True

    mask_shape = (batch_size, seq_len, seq_len)

    # Test original mask
    extra = build_explanation_claim_mask(expl_mask, claim_mask)
    _expect(extra.shape == mask_shape, "explanation→claim mask has wrong shape")
    lm_l, t_l, s_l, c_l = model(ids, expl_mask, claim_mask, extra_attn_mask=extra)
    print(f"LM logits: {lm_l.shape}, time: {t_l.shape}")
    _expect(lm_l.shape == (batch_size, seq_len, cfg.vocab_size), "LM logits have wrong shape")
    _expect(t_l.shape == (batch_size, cfg.n_time_classes), "time logits have wrong shape")
    _expect(s_l.shape == (batch_size, cfg.n_space_classes), "space logits have wrong shape")
    _expect(c_l.shape == (batch_size, cfg.n_correct_classes), "correct logits have wrong shape")
    _expect(bool(torch.isfinite(lm_l).all()), "LM logits contain non-finite values")

    # Test no_claim_to_claim mask
    mask_no_c2c = build_no_claim_to_claim_mask(claim_mask)
    print(f"no_claim_to_claim mask shape: {mask_no_c2c.shape}")
    _expect(mask_no_c2c.shape == mask_shape, "no_claim_to_claim mask has wrong shape")
    _expect(
        int(torch.count_nonzero(mask_no_c2c.diagonal(dim1=1, dim2=2))) == 0,
        "no_claim_to_claim mask must leave the diagonal open",
    )

    # Test claims_from_explanation_only mask
    mask_cfo = build_claims_from_explanation_only_mask(expl_mask, claim_mask)
    print(f"claims_from_explanation_only mask shape: {mask_cfo.shape}")
    _expect(mask_cfo.shape == mask_shape, "claims_from_explanation_only mask has wrong shape")
    _expect(
        int(torch.count_nonzero(mask_cfo.diagonal(dim1=1, dim2=2))) == 0,
        "claims_from_explanation_only mask must leave the diagonal open",
    )

    # Test surface bottleneck model
    model_surf = ConsistencyTransformer(cfg, use_surface_heads=True)
    lm_l2, t_l2, s_l2, c_l2 = model_surf(ids, expl_mask, claim_mask, extra_attn_mask=extra)
    print(f"Surface bottleneck - LM: {lm_l2.shape}, time: {t_l2.shape}")
    _expect(lm_l2.shape == (batch_size, seq_len, cfg.vocab_size), "surface LM logits have wrong shape")
    _expect(t_l2.shape == (batch_size, cfg.n_time_classes), "surface time logits have wrong shape")
    _expect(s_l2.shape == (batch_size, cfg.n_space_classes), "surface space logits have wrong shape")
    _expect(c_l2.shape == (batch_size, cfg.n_correct_classes), "surface correct logits have wrong shape")

    # Test masked LM loss
    lm_mask = ~expl_mask  # disable loss on explanation tokens
    loss_masked = compute_lm_loss_masked(lm_l, ids, pad_id=0, lm_loss_mask=lm_mask)
    loss_full   = compute_lm_loss(lm_l, ids, pad_id=0)
    print(f"LM loss full: {loss_full.item():.4f}, masked: {loss_masked.item():.4f}")
    _expect(bool(torch.isfinite(loss_full)), "full LM loss is not finite")
    _expect(bool(torch.isfinite(loss_masked)), "masked LM loss is not finite")
    print("All model.py self-tests passed.")
