# Consistency Loss Experiment — Results Report

**Generated:** 2026-04-30 18:54:58

## 1. Experiment Configuration

### Full Configuration (target)
| Parameter | Value |
|---|---|
| Dataset size (full config) | 3,000 examples |
| Validation set (full config) | 500 examples |
| Epochs (full config) | 20 |
| Batch size (full config) | 32 |
| Learning rate | 5e-5 |
| Lambda (consistency weight) | 1.0 |
| Model (full config) | GPT-2-style Transformer, `small` config (~10M params); a `gpt2_small` config (~117M params) is also provided |
| Optimizer | AdamW (weight decay=0.01, grad clip=1.0) |
| Checkpoint interval | every 5 epochs |
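The step and checkpoint arithmetic implied by this table can be sketched with a hypothetical `FullConfig` dataclass. The class and field names are illustrative, not the project's own. The sketch assumes all 3,000 examples are training examples and that the final partial batch is kept, which gives ⌈3000 / 32⌉ = 94 optimizer steps per epoch.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class FullConfig:
    """Hypothetical mirror of the full-configuration table (names are illustrative)."""
    train_examples: int = 3000
    val_examples: int = 500
    epochs: int = 20
    batch_size: int = 32
    learning_rate: float = 5e-5
    consistency_lambda: float = 1.0
    weight_decay: float = 0.01
    grad_clip_norm: float = 1.0
    checkpoint_every: int = 5

    def steps_per_epoch(self) -> int:
        # Assumes the last, partial batch is kept rather than dropped.
        return math.ceil(self.train_examples / self.batch_size)

    def total_steps(self) -> int:
        return self.steps_per_epoch() * self.epochs

    def checkpoint_epochs(self) -> list[int]:
        # One checkpoint at the end of every `checkpoint_every`-th epoch.
        return list(range(self.checkpoint_every, self.epochs + 1, self.checkpoint_every))


cfg = FullConfig()
print(cfg.steps_per_epoch(), cfg.total_steps(), cfg.checkpoint_epochs())  # 94 1880 [5, 10, 15, 20]
```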

## 2. Model Architecture

The model is a **GPT-2-style causal Transformer** implemented from scratch
in PyTorch, using the same causal masking semantics as GPT-2.

**Key design choices:**

- **Causal masking**: Standard lower-triangular attention mask.
  Explanation tokens appear *before* claim tokens in the sequence,
  so causal attention already prevents explanation tokens from attending
  to future claim tokens. An explicit additive attention bias further
  enforces this structural constraint.

- **Sequence format**: `<bos> [code] <sep> [explanation] <claim>time_complexity=X</claim>`
  `<claim>space_complexity=Y</claim> <claim>correctness=Z</claim> <eos>`

- **LM head**: Tied to token embeddings. Loss computed over full sequence
  (next-token prediction).

- **Consistency head**: Three linear classifiers (time complexity, space
  complexity, correctness) applied to *mean-pooled hidden states* of
  explanation tokens from the final Transformer layer.

- **Note on GPT-2 weights**: This implementation is a from-scratch
  GPT-2-compatible architecture. Loading pretrained GPT-2 weights would
  require the `transformers` library. The `gpt2_small` config (768 dim,
  12 heads, 12 layers, ~117M params) is provided but requires a GPU with
  ≥8 GB of VRAM.
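The sequence layout, the additive causal bias, and the explanation-token pooling described above can be sketched in plain Python. This is an illustrative sketch, not the project's PyTorch code: every helper name is hypothetical, the claim tags are assumed to be separate tokens, and the explanation span is assumed to be every token strictly between `<sep>` and the first `<claim>`.

```python
import math

NEG_INF = float("-inf")


def format_sequence(code_tokens, expl_tokens, time_c, space_c, correct):
    """Lay out one example as <bos> code <sep> explanation claims <eos> (hypothetical helper)."""
    seq = ["<bos>", *code_tokens, "<sep>", *expl_tokens]
    for claim in (f"time_complexity={time_c}", f"space_complexity={space_c}",
                  f"correctness={correct}"):
        seq += ["<claim>", claim, "</claim>"]
    return seq + ["<eos>"]


def explanation_mask(tokens):
    """1 for tokens strictly between <sep> and the first <claim>, else 0."""
    start = tokens.index("<sep>") + 1
    end = tokens.index("<claim>", start)
    return [1 if start <= k < end else 0 for k in range(len(tokens))]


def causal_bias(seq_len):
    """Additive attention bias: 0.0 where key j <= query i, -inf for future keys."""
    return [[0.0 if j <= i else NEG_INF for j in range(seq_len)] for i in range(seq_len)]


def masked_softmax(scores, bias_row):
    """Softmax over scores + bias; positions with -inf bias get exactly zero weight."""
    logits = [s + b for s, b in zip(scores, bias_row)]
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def masked_mean_pool(hidden, mask):
    """Average the hidden-state vectors at positions where mask == 1."""
    count = sum(mask)
    if count == 0:
        raise ValueError("no explanation tokens to pool")
    dim = len(hidden[0])
    return [sum(h[d] for h, m in zip(hidden, mask) if m) / count for d in range(dim)]


tokens = format_sequence(["def", "f", "(", ")"], ["Returns", "x", "."], "O(1)", "O(1)", 1)
mask = explanation_mask(tokens)
print([t for t, m in zip(tokens, mask) if m])  # ['Returns', 'x', '.']
```

In the PyTorch model, the pooled vector would be the input to the three linear classifiers of the consistency head, and the `claim_only_pooling` control would use a mask over claim tokens instead.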

## 3. Experimental Variants

| Variant | Description |
|---|---|
| `consistency_loss` | Full mechanism: LM loss + consistency loss on explanation token pooling |
| `no_consistency_loss` | LM loss only; no gradient through consistency head |
| `claim_only_pooling` | Negative control: pool *claim* tokens instead of explanation tokens |
| `random_label_consistency` | Negative control: consistency loss with shuffled ground-truth labels |

**Critical prediction:** `consistency_loss` should develop stronger coupling
between explanation hidden states and ground-truth claims. Baselines should
fail to develop this coupling or show degraded explanation quality.
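The variant table can be read as three switches feeding a combined objective, assumed here to be `lm_loss + lambda * consistency_loss` with lambda = 1.0 from the configuration. The sketch below is hypothetical: `VariantSpec`, `total_loss`, and `training_labels` are illustrative names, and the random-label control is assumed to shuffle labels once with a fixed seed. For `no_consistency_loss`, the pooling span matters only if the head is still evaluated for the coupling metric, which is also an assumption.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class VariantSpec:
    pool_span: str         # tokens the consistency head mean-pools: "explanation" or "claim"
    use_consistency: bool  # whether the consistency term is added to the training loss
    shuffle_labels: bool   # whether claim labels are permuted across examples


VARIANTS = {
    "consistency_loss":         VariantSpec("explanation", True, False),
    "no_consistency_loss":      VariantSpec("explanation", False, False),
    "claim_only_pooling":       VariantSpec("claim", True, False),
    "random_label_consistency": VariantSpec("explanation", True, True),
}


def total_loss(variant, lm_loss, consistency_loss, lam=1.0):
    """LM loss plus lam * consistency loss, unless the variant disables the consistency term."""
    if VARIANTS[variant].use_consistency:
        return lm_loss + lam * consistency_loss
    return lm_loss


def training_labels(variant, labels, seed=0):
    """Return the claim labels the consistency head is trained on for this variant."""
    labels = list(labels)
    if VARIANTS[variant].shuffle_labels:
        random.Random(seed).shuffle(labels)  # breaks the code-to-claim correspondence
    return labels


print(total_loss("consistency_loss", 2.0, 0.5))     # 2.5
print(total_loss("no_consistency_loss", 2.0, 0.5))  # 2.0
```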

## 4. Final-Epoch Validation Metrics

*Metrics at epoch 20 (final epoch).*

| Variant | Coupling Strength | BLEU-1 | ROUGE-L | Swap Influence | Claim Accuracy | Val LM Loss |
|---|---|---|---|---|---|---|
| Consistency Loss | 1.0000 | 0.0673 | 0.0820 | 1.0000 | 1.0000 | 0.0718 |
| No Consistency Loss | 0.1747 | 0.0600 | 0.0791 | 0.0500 | 1.0000 | 0.0739 |
| Claim Only Pooling | 1.0000 | 0.0770 | 0.0969 | 0.8500 | 1.0000 | 0.0732 |
| Random Label Consistency | 0.6967 | 0.0665 | 0.0877 | -0.1500 | 1.0000 | 0.0732 |
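For reference, the BLEU-1 and ROUGE-L columns follow standard sentence-level definitions. BLEU-1 is clipped unigram precision times a brevity penalty. ROUGE-L is the F1 of precision and recall derived from the longest common subsequence (LCS). The sketch below is an assumption-laden illustration: the report does not specify its tokenization, smoothing, or aggregation, so whitespace splitting and per-sentence scoring are assumptions, and its values will not necessarily match the table.

```python
import math
from collections import Counter


def bleu1(candidate, reference):
    """Sentence-level BLEU-1: clipped unigram precision times the brevity penalty."""
    if not candidate:
        return 0.0
    ref_counts = Counter(reference)
    clipped = sum(min(n, ref_counts[w]) for w, n in Counter(candidate).items())
    precision = clipped / len(candidate)
    c, r = len(candidate), len(reference)
    brevity_penalty = 1.0 if c > r else math.exp(1 - r / c)
    return brevity_penalty * precision


def lcs_length(a, b):
    """Length of the longest common subsequence, via a rolling dynamic-programming row."""
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, start=1):
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]


def rouge_l_f1(candidate, reference):
    """ROUGE-L F1 from LCS-based precision and recall."""
    lcs = lcs_length(candidate, reference)
    if lcs == 0:
        return 0.0
    p, r = lcs / len(candidate), lcs / len(reference)
    return 2 * p * r / (p + r)


cand = "checks a list".split()
ref = "checks whether a list is sorted".split()
print(round(bleu1(cand, ref), 4), round(rouge_l_f1(cand, ref), 4))  # 0.3679 0.6667
```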

## 5. Statistical Test: consistency_loss vs no_consistency_loss

- **Metric**: BLEU-1 (explanation-quality proxy)
- **Test**: Welch's independent-samples t-test on per-epoch values from epochs 11–20 (epochs > 10)

| Statistic | consistency\_loss | no\_consistency\_loss |
|---|---|---|
| Mean BLEU-1 | 0.0728 | 0.0707 |
| n epochs | 10 | 10 |

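**t = 0.632, p = 0.5358**: not significant at the 0.05 level (p ≥ 0.05).

`consistency_loss` has a numerically **higher** mean BLEU-1 than `no_consistency_loss` (0.0728 vs 0.0707, a difference of 0.0021), but the difference is not statistically significant.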
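The Welch statistic can be recomputed from per-epoch BLEU-1 values with the standard-library sketch below. The helper names `welch_t` and `late_epochs`, and the `(epoch, value)` history format, are hypothetical; the per-epoch values themselves are not listed in this report. In SciPy, `scipy.stats.ttest_ind(a, b, equal_var=False)` computes the same t statistic together with a two-sided p-value. The test treats the ten late epochs of each run as independent samples, although consecutive epochs of one training run are correlated.

```python
import math
from statistics import mean, variance


def welch_t(a, b):
    """Welch's t statistic and Welch-Satterthwaite degrees of freedom for two samples."""
    se2_a = variance(a) / len(a)  # squared standard error of mean(a), from the sample variance
    se2_b = variance(b) / len(b)
    t = (mean(a) - mean(b)) / math.sqrt(se2_a + se2_b)
    df = (se2_a + se2_b) ** 2 / (se2_a ** 2 / (len(a) - 1) + se2_b ** 2 / (len(b) - 1))
    return t, df


def late_epochs(history, after=10):
    """Keep metric values for epochs strictly greater than `after` (epochs 11-20 here)."""
    return [value for epoch, value in history if epoch > after]
```

Passing `late_epochs(...)` of each variant's BLEU-1 history as `a` and `b` mirrors the "epochs > 10" selection above.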

## 6. Metric Trajectory Summary

### Consistency Loss

| Metric | Epoch 1 | Epoch 20 | Δ |
|---|---|---|---|
| Coupling Strength | 1.0000 | 1.0000 | +0.0000 |
| BLEU-1 | 0.0545 | 0.0673 | +0.0128 |
| ROUGE-L | 0.0525 | 0.0820 | +0.0295 |
| Swap Influence | 1.0000 | 1.0000 | +0.0000 |
| Claim Accuracy | 0.6333 | 1.0000 | +0.3667 |

### No Consistency Loss

| Metric | Epoch 1 | Epoch 20 | Δ |
|---|---|---|---|
| Coupling Strength | 0.1440 | 0.1747 | +0.0307 |
| BLEU-1 | 0.0623 | 0.0600 | -0.0023 |
| ROUGE-L | 0.0712 | 0.0791 | +0.0079 |
| Swap Influence | 0.0500 | 0.0500 | +0.0000 |
| Claim Accuracy | 0.7500 | 1.0000 | +0.2500 |

### Claim Only Pooling

| Metric | Epoch 1 | Epoch 20 | Δ |
|---|---|---|---|
| Coupling Strength | 1.0000 | 1.0000 | +0.0000 |
| BLEU-1 | 0.0335 | 0.0770 | +0.0436 |
| ROUGE-L | 0.0363 | 0.0969 | +0.0606 |
| Swap Influence | 1.0000 | 0.8500 | -0.1500 |
| Claim Accuracy | 0.6667 | 1.0000 | +0.3333 |

### Random Label Consistency

| Metric | Epoch 1 | Epoch 20 | Δ |
|---|---|---|---|
| Coupling Strength | 0.7047 | 0.6967 | -0.0080 |
| BLEU-1 | 0.0477 | 0.0665 | +0.0188 |
| ROUGE-L | 0.0512 | 0.0877 | +0.0364 |
| Swap Influence | 0.0500 | -0.1500 | -0.2000 |
| Claim Accuracy | 0.5167 | 1.0000 | +0.4833 |

## 7. Qualitative Examples: Epoch-1 vs Final Epoch Generations

The following examples are drawn from the `consistency_loss` variant.
They compare the model's generated explanation after the first epoch of
training with its generation at the final epoch.

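> **Smoke-run note**: With a tiny model and few epochs, generations are
> short and may not yet form coherent prose. Generations are reproduced
> verbatim; words appear split into character and sub-word fragments
> (e.g., `C o m p u t e s` for "Computes"). The explanation text changes
> between epoch 1 and the final epoch in 13 of the 15 examples (Examples 1
> and 13 are unchanged), which indicates that the model is still adapting,
> even if fluency is limited.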

### Example 1: `def check_all_pairs_...`

**Ground-truth claims:** time=O(n^2), space=O(1), correct=1

**Mismatched explanation (training input):**
> Searches a list sequentially for a target value. O(n) time, O(1) space.

**True explanation (reference):**
> Checks if all elements are equal via pairwise comparison. O(n^2) time, O(1) space.

**Code snippet:**
```python
def check_all_pairs_equal(lst):
    for i in range(len(lst)):
        for j in range(i + 1, len(lst)):
            if lst[i] != lst[j]:
                return False
    return True
```

**Epoch-1 generation:**
> `C o m p u t e s n ! i t er a t i v e l y . O(n) time , O(1) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C o m p u t e s n ! i t er a t i v e l y . O(n) time , O(1) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

---

### Example 2: `def generate_all_pai...`

**Ground-truth claims:** time=O(n^2), space=O(n^2), correct=1

**Mismatched explanation (training input):**
> Builds a frequency map in O(n) time and O(n) space.

**True explanation (reference):**
> Generates all ordered pairs. O(n^2) time and O(n^2) space.

**Code snippet:**
```python
def generate_all_pairs(lst):
    pairs = []
    for i in range(len(lst)):
        for j in range(len(lst)):
            pairs.append((lst[i], lst[j]))
    return pairs
```

**Epoch-1 generation:**
> `C h e c k s w h e t h er a list i s o p a t e d in a s ce n d in a s . O(n) time , O(1) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(n^2) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `Re t u r n s u n i q u e e l e m e n t s p r e se r v in g o r d er . O(n) time and O(n) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(n^2) </claim> <claim> correctness=1 </claim>`

---

### Example 3: `def selection_sort(l...`

**Ground-truth claims:** time=O(n^2), space=O(1), correct=1

**Mismatched explanation (training input):**
> Detects duplicates with nested loops. O(n^2) time, O(1) space.

**True explanation (reference):**
> Selection sort: selects the minimum in O(n^2) nested passes, O(1) space.

**Code snippet:**
```python
def selection_sort(lst):
    for i in range(len(lst)):
        min_idx = i
        for j in range(i + 1, len(lst)):
            if lst[j] < lst[min_idx]:
                min_idx = j
        lst[i], lst[min_idx] = lst[min_idx], lst[i]
```

**Epoch-1 generation:**
> `C o m p u t e s a r a y . O(n) time and O(n) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C o m p u t e s t h e s u m of a list w i t h a s in g l e p a s s . T i m e c o m pl e x i t y O(n) , sp a ce O(1) . <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1`

---

### Example 4: `def selection_sort(l...`

**Ground-truth claims:** time=O(n^2), space=O(1), correct=1

**Mismatched explanation (training input):**
> Detects duplicates with nested loops. O(n^2) time, O(1) space.

**True explanation (reference):**
> Selection sort: selects the minimum in O(n^2) nested passes, O(1) space.

**Code snippet:**
```python
def selection_sort(lst):
    for i in range(len(lst)):
        min_idx = i
        for j in range(i + 1, len(lst)):
            if lst[j] < lst[min_idx]:
                min_idx = j
        lst[i], lst[min_idx] = lst[min_idx], lst[i]
```

**Epoch-1 generation:**
> `C o m p u t e s a r a y . O(n) time and O(n) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C o m p u t e s t h e s u m of a list w i t h a s in g l e p a s s . T i m e c o m pl e x i t y O(n) , sp a ce O(1) . <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1`

---

### Example 5: `def matrix_multiply_...`

**Ground-truth claims:** time=O(n^2), space=O(1), correct=0

**Mismatched explanation (training input):**
> Builds a frequency map in O(n) time and O(n) space.

**True explanation (reference):**
> Attempts 2x2 matrix multiplication but is buggy (missing accumulation). O(n^2) time, O(1) auxiliary space.

**Code snippet:**
```python
def matrix_multiply_2x2(A, B):
    C = [[0, 0], [0, 0]]
    for i in range(2):
        for j in range(2):
            C[i][j] = A[i][0] * B[0][j]  # bug: missing second term
    return C
```

**Epoch-1 generation:**
> `C h e c k s w h e t h er a list i s s o r t e d in a s ce n d in g o r d er . O(n) time , O(1) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=0 </claim>`

**Final-epoch generation:**
> `C h e c k s if a s t r in g c o n t a in s a ch a r a c t er v i a l in e a r s c a n . O(n) time , O(1) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=0 </claim>`

---

### Example 6: `def matrix_multiply_...`

**Ground-truth claims:** time=O(n^2), space=O(1), correct=0

**Mismatched explanation (training input):**
> Attempts to swap two values but contains a logic bug, returning incorrect results.

**True explanation (reference):**
> Attempts 2x2 matrix multiplication but is buggy (missing accumulation). O(n^2) time, O(1) auxiliary space.

**Code snippet:**
```python
def matrix_multiply_2x2(A, B):
    C = [[0, 0], [0, 0]]
    for i in range(2):
        for j in range(2):
            C[i][j] = A[i][0] * B[0][j]  # bug: missing second term
    return C
```

**Epoch-1 generation:**
> `C h e c k s w h e t h er a list i s s o r t e d in a s ce n d in g o r d er . O(n) time , O(1) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=0 </claim>`

**Final-epoch generation:**
> `C h e c k s if a s t r in g c o n t a in s a ch a r a c t er v i a l in e a r s c a n . O(n) time , O(1) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=0 </claim>`

---

### Example 7: `def is_sorted(lst):...`

**Ground-truth claims:** time=O(n), space=O(1), correct=1

**Mismatched explanation (training input):**
> Returns unique elements preserving order. O(n) time and O(n) space.

**True explanation (reference):**
> Checks whether a list is sorted in ascending order. O(n) time, O(1) space.

**Code snippet:**
```python
def is_sorted(lst):
    for i in range(len(lst) - 1):
        if lst[i] > lst[i + 1]:
            return False
    return True
```

**Epoch-1 generation:**
> `C h e c k s w h er a list i t e v e n in t h e n in t h er i s o o d u s in g t o o p er a s in O(1) time and O(1) space. <claim> time_complexity=O(n) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C o m p u t e s p r e f i x s u m a r r a y . O(n) time and O(n) space. <claim> time_complexity=O(n) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

---

### Example 8: `def insertion_sort(l...`

**Ground-truth claims:** time=O(n^2), space=O(1), correct=1

**Mismatched explanation (training input):**
> Searches a list sequentially for a target value. O(n) time, O(1) space.

**True explanation (reference):**
> Insertion sort: O(n^2) worst-case time, O(1) space.

**Code snippet:**
```python
def insertion_sort(lst):
    for i in range(1, len(lst)):
        key = lst[i]
        j = i - 1
        while j >= 0 and lst[j] > key:
            lst[j + 1] = lst[j]
            j -= 1
        lst[j + 1] = key
```

**Epoch-1 generation:**
> `C h e c k s w h e t h er a list i s o o p a s ce n d in g o r d er . O(n) time , O(1) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C o m p u t e s p r e f i x s u m a r r a y . O(n) time and O(n) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

---

### Example 9: `def selection_sort(l...`

**Ground-truth claims:** time=O(n^2), space=O(1), correct=1

**Mismatched explanation (training input):**
> Insertion sort: O(n^2) worst-case time, O(1) space.

**True explanation (reference):**
> Selection sort: selects the minimum in O(n^2) nested passes, O(1) space.

**Code snippet:**
```python
def selection_sort(lst):
    for i in range(len(lst)):
        min_idx = i
        for j in range(i + 1, len(lst)):
            if lst[j] < lst[min_idx]:
                min_idx = j
        lst[i], lst[min_idx] = lst[min_idx], lst[i]
```

**Epoch-1 generation:**
> `C o m p u t e s a r a y . O(n) time and O(n) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C o m p u t e s t h e s u m of a list w i t h a s in g l e p a s s . T i m e c o m pl e x i t y O(n) , sp a ce O(1) . <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1`

---

### Example 10: `def matrix_multiply_...`

**Ground-truth claims:** time=O(n^2), space=O(1), correct=0

**Mismatched explanation (training input):**
> Attempts to reverse a list in-place but has a bug: it swaps an element with itself.

**True explanation (reference):**
> Attempts 2x2 matrix multiplication but is buggy (missing accumulation). O(n^2) time, O(1) auxiliary space.

**Code snippet:**
```python
def matrix_multiply_2x2(A, B):
    C = [[0, 0], [0, 0]]
    for i in range(2):
        for j in range(2):
            C[i][j] = A[i][0] * B[0][j]  # bug: missing second term
    return C
```

**Epoch-1 generation:**
> `C h e c k s w h e t h er a list i s s o r t e d in a s ce n d in g o r d er . O(n) time , O(1) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=0 </claim>`

**Final-epoch generation:**
> `C h e c k s if a s t r in g c o n t a in s a ch a r a c t er v i a l in e a r s c a n . O(n) time , O(1) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=0 </claim>`

---

### Example 11: `def sign(x): if ...`

**Ground-truth claims:** time=O(1), space=O(1), correct=1

**Mismatched explanation (training input):**
> Checks if all elements are equal via pairwise comparison. O(n^2) time, O(1) space.

**True explanation (reference):**
> Returns the sign of a number (-1, 0, or 1) in O(1) time and space.

**Code snippet:**
```python
def sign(x):
    if x > 0: return 1
    if x < 0: return -1
    return 0
```

**Epoch-1 generation:**
> `C o m p u t e s p r i x s u m a r a y . O(n) time and O(n) space. <claim> time_complexity=O(1) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C o m p u t e s n ! i t er a t i v e l y . O(n) time , O(1) space. <claim> time_complexity=O(1) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

---

### Example 12: `def string_contains(...`

**Ground-truth claims:** time=O(n), space=O(1), correct=1

**Mismatched explanation (training input):**
> Reverses a string character by character. O(n) time, O(n) space due to string accumulation.

**True explanation (reference):**
> Checks if a string contains a character via linear scan. O(n) time, O(1) space.

**Code snippet:**
```python
def string_contains(s, ch):
    for c in s:
        if c == ch:
            return True
    return False
```

**Epoch-1 generation:**
> `C o m p u t e s p r i x . T h e f i m e t i m u m e y . T h e bu i l t - in O(n) ; o v er a l O(n) time , O(1) space. <claim> time_complexity=O(n) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C o m p u t e s p r e f i x s u m a r r a y . O(n) time and O(n) space. <claim> time_complexity=O(n) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

---

### Example 13: `def all_pairs_sum(ls...`

**Ground-truth claims:** time=O(n^2), space=O(1), correct=1

**Mismatched explanation (training input):**
> Checks whether a list is sorted in ascending order. O(n) time, O(1) space.

**True explanation (reference):**
> Sums all pairs from the list. Double nested loop gives O(n^2) time, O(1) space.

**Code snippet:**
```python
def all_pairs_sum(lst):
    total = 0
    for i in range(len(lst)):
        for j in range(len(lst)):
            total += lst[i] + lst[j]
    return total
```

**Epoch-1 generation:**
> `C o m p u t e s p r e f i x s u m a r r a y . O(n) time and O(n) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C o m p u t e s p r e f i x s u m a r r a y . O(n) time and O(n) space. <claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

---

### Example 14: `def max_of_two(a, b)...`

**Ground-truth claims:** time=O(1), space=O(1), correct=1

**Mismatched explanation (training input):**
> Naive polynomial/array convolution. O(n^2) time, O(n) output space.

**True explanation (reference):**
> Returns the larger of two numbers using a conditional expression. O(1) time and space.

**Code snippet:**
```python
def max_of_two(a, b):
    return a if a >= b else b
```

**Epoch-1 generation:**
> `C h e c k s w h e t h er a list i s s o r t e d in a s ce n d in g o r d er . O(n) time , O(1) space. <claim> time_complexity=O(1) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C o m p u t e s n ! i t er a t i v e l y . O(n) time , O(1) space. <claim> time_complexity=O(1) </claim> <claim> space_complexity=O(1) </claim> <claim> correctness=1 </claim>`

---

### Example 15: `def string_reverse(s...`

**Ground-truth claims:** time=O(n), space=O(n), correct=1

**Mismatched explanation (training input):**
> Insertion sort: O(n^2) worst-case time, O(1) space.

**True explanation (reference):**
> Reverses a string character by character. O(n) time, O(n) space due to string accumulation.

**Code snippet:**
```python
def string_reverse(s):
    result = ''
    for ch in s:
        result = ch + result
    return result
```

**Epoch-1 generation:**
> `C o m p u t e s p r r a y . O(n) time and O(n) space. <claim> time_complexity=O(n) </claim> <claim> space_complexity=O(n) </claim> <claim> correctness=1 </claim>`

**Final-epoch generation:**
> `C h e c k s if a l l e l e m e n t s a r e e q u a l v i a p a i r w i se c o m p a r i s o n . O( n ^ 2 ) time , O(1) space. <claim> time_complexity=O(n) </claim> <claim>`

---

## 8. Limitations and Interpretation

1. **Smoke run constraints**: The smoke configuration uses a tiny Transformer
   (~0.5–2M params), a reduced dataset, and few epochs. These constraints
   are expected to keep the model below the performance levels expected in the full run.

2. **Tokenizer**: A simple whitespace tokenizer is used (no BPE/SentencePiece).
   This means token sequences are longer than with a subword tokenizer, and
   the vocabulary may not generalize as well.

3. **BLEU/ROUGE proxy**: BLEU-1 and ROUGE-L are computed against ground-truth
   explanation templates, not against diverse human references. They measure
   whether the model recovers the training-set language, not open-ended quality.

4. **Claim emission accuracy**: Measured by string matching on greedy-decoded
   output. A model could emit the correct claim tokens by memorizing
   code-to-claim associations rather than by generalizing.

5. **Coupling vs causality**: Classifier accuracy on explanation hidden states
   measures *correlation*, not causal coupling. Running the full experiment with
   multiple random seeds, together with probing experiments, would provide stronger evidence.

6. **Full 20-epoch run**: The full configuration (3,000 examples, 20 epochs,
   `small` model) is estimated to take roughly 2–4 hours on a modern GPU. The
   `gpt2_small` config (~117M params) requires a GPU with ≥8 GB of VRAM and
   significantly more compute.
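Item 4 describes claim accuracy as string matching on greedy-decoded output. A minimal sketch of such matching follows, assuming the space-separated `<claim> key=value </claim>` layout visible in the Section 7 generations. The regular expression, the function names, and the rule that a claim cut off before its closing tag still counts are all assumptions, not the report's actual scorer.

```python
import re

# Assumed claim layout in decoded text: "<claim> key=value </claim>", closing tag optional.
CLAIM_RE = re.compile(r"<claim>\s*(\w+)=(\S+)")


def parse_claims(generation):
    """Map claim key -> value for every <claim> tag found in a decoded generation."""
    return {key: value for key, value in CLAIM_RE.findall(generation)}


def claim_accuracy(generation, truth):
    """Fraction of ground-truth claims whose value appears verbatim in the generation."""
    found = parse_claims(generation)
    return sum(found.get(k) == v for k, v in truth.items()) / len(truth)


# Claim portion of Example 3's final-epoch generation (its last closing tag is missing).
gen = ("<claim> time_complexity=O(n^2) </claim> <claim> space_complexity=O(1) </claim> "
       "<claim> correctness=1")
truth = {"time_complexity": "O(n^2)", "space_complexity": "O(1)", "correctness": "1"}
print(claim_accuracy(gen, truth))  # 1.0
```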

## 9. Run Instructions

See `README.md` for full setup and run instructions.

```bash
# Smoke run (fast, ~2-5 min on CPU):
python run_experiment.py --smoke

# Full run (20 epochs, 3000 examples):
python run_experiment.py --full

# Small model, custom config:
python run_experiment.py --full --model small --epochs 20 --batch 32

# GPT-2-style config (GPU required):
python run_experiment.py --full --model gpt2_small
```
