# Stats or Facts: The Role of Contextual Structure and Diversity in Language Model Generalization (Anonymous Submission)

This repository contains the code used to generate the results in ***Stats or Facts: The Role of Contextual Structure and Diversity in Language Model Generalization*** (currently under double‑blind review).

---

## 1 ‑ What the code does

The code trains a small Transformer on the synthetic setup described in Section 3 of the paper at a specified **diversity level** and logs three metrics throughout training:

* `Loss_stat` – token‑level KL divergence on non‑knowledge tokens
* `Acc_pos` – accuracy of predicting the correct knowledge token at the *designated* position
* `Acc_fact` – accuracy of recalling the correct fact *anywhere* in the completion
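The three metrics can be illustrated in plain Python. This is a minimal sketch, not the repository's implementation. The function names, the representation of distributions as per-position probability lists, the knowledge-token mask, and the assumption that each fact is a single token are all hypothetical. `Loss_stat` is assumed to be KL(true ‖ model) averaged over non-knowledge positions.

```python
import math
from typing import List, Sequence

# Hypothetical helpers illustrating the three logged metrics; the code in src/
# may define and compute them differently.


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    # Terms with p_i = 0 contribute nothing; q_i is clamped to eps to avoid log(0).
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def loss_stat(true_dists: Sequence[Sequence[float]],
              model_dists: Sequence[Sequence[float]],
              is_knowledge: Sequence[bool]) -> float:
    """Mean KL(true || model) over positions that are NOT knowledge tokens (assumed)."""
    kls = [kl_divergence(p, q)
           for p, q, k in zip(true_dists, model_dists, is_knowledge) if not k]
    return sum(kls) / len(kls) if kls else 0.0


def acc_pos(completions: Sequence[List[int]], facts: Sequence[int],
            positions: Sequence[int]) -> float:
    """Share of sequences whose token at the designated position is the correct fact."""
    hits = [c[pos] == f for c, f, pos in zip(completions, facts, positions)]
    return sum(hits) / len(hits)


def acc_fact(completions: Sequence[List[int]], facts: Sequence[int]) -> float:
    """Share of sequences in which the correct fact appears anywhere in the completion."""
    hits = [f in c for c, f in zip(completions, facts)]
    return sum(hits) / len(hits)


if __name__ == "__main__":
    completions = [[5, 7, 42, 9], [3, 42, 1, 2]]
    facts = [42, 42]
    positions = [2, 3]
    print(acc_pos(completions, facts, positions))  # 0.5
    print(acc_fact(completions, facts))            # 1.0
```

The toy example shows how the two accuracies can diverge: in the second sequence the fact is recalled, but not at its designated position, so it counts toward `Acc_fact` only.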

A subset of the generated figures is included in `fig/` for reviewers’ convenience.


---

## 2 ‑ Quick‑start


```bash
# Train for 10,000 iterations in the MC1Pos10 setting at medium diversity (ood_frac = 0.5)
python3 ./src/main.py \
    --num_iters 10000 \
    --V 203 \
    --M 100 \
    --n_embd 32 \
    --n_layer 4 \
    --n_head 4 \
    --seq_len 50 \
    --num_chain_tmpl 1 \
    --num_pos_tmpl 10 \
    --ood_frac 0.5 \
    --num_ckpt 10 \
    --skip_spectok
```

### Pre‑defined diversity settings

| Shorthand     | `--num_chain_tmpl` | `--num_pos_tmpl` |
| ------------- | ------------------ | ---------------- |
| **MC10Pos1**  | `10`               | `1`              |
| **MC1Pos10**  | `1`                | `10`             |
| **MC10Pos10** | `10`               | `10`             |

To switch between these settings, change the values of `--num_chain_tmpl` and `--num_pos_tmpl` in the command above.
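To run all three settings in sequence, a small launcher can map each shorthand to its two flags and reuse the remaining quick-start arguments. The sketch below is hypothetical: `SETTINGS`, `BASE_ARGS`, `build_command`, and `launch` are not part of the repository, and the script is assumed to be run from the repository root. By default it only prints the commands (dry run).

```python
import subprocess
from typing import Dict, List, Tuple

# Hypothetical launcher: each shorthand maps to its
# (--num_chain_tmpl, --num_pos_tmpl) pair from the table above.
SETTINGS: Dict[str, Tuple[int, int]] = {
    "MC10Pos1": (10, 1),
    "MC1Pos10": (1, 10),
    "MC10Pos10": (10, 10),
}

# Flags shared by every run, copied from the quick-start command.
BASE_ARGS: List[str] = [
    "--num_iters", "10000",
    "--V", "203",
    "--M", "100",
    "--n_embd", "32",
    "--n_layer", "4",
    "--n_head", "4",
    "--seq_len", "50",
    "--num_ckpt", "10",
    "--skip_spectok",
]


def build_command(setting: str, ood_frac: float = 0.5) -> List[str]:
    """Return the argv list for src/main.py under the given diversity setting."""
    num_chain, num_pos = SETTINGS[setting]
    return ["python3", "./src/main.py"] + BASE_ARGS + [
        "--num_chain_tmpl", str(num_chain),
        "--num_pos_tmpl", str(num_pos),
        "--ood_frac", str(ood_frac),
    ]


def launch(setting: str, ood_frac: float = 0.5, dry_run: bool = True) -> List[str]:
    """Print the command (dry run) or execute it, failing loudly on a non-zero exit."""
    cmd = build_command(setting, ood_frac)
    if dry_run:
        print(" ".join(cmd))
    else:
        subprocess.run(cmd, check=True)
    return cmd


if __name__ == "__main__":
    for name in SETTINGS:
        launch(name)  # pass dry_run=False to actually start training
```

Passing the command as a list rather than a shell string avoids quoting issues, and `check=True` stops the sweep as soon as one run fails.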

---

## 3 ‑ Key command‑line arguments

| Flag               | Purpose                                                        |
| ------------------ | -------------------------------------------------------------- |
| `--num_iters`      | Number of training iterations                                  |
| `--V`              | Total vocabulary size (`V_K + V_D`)                            |
| `--M`              | Size of the fact set (`K` in the paper)                        |
| `--n_embd`         | Transformer embedding dimension                                |
| `--n_layer`        | Number of Transformer layers                                   |
| `--n_head`         | Number of attention heads                                      |
| `--seq_len`        | Sequence length (`T`)                                          |
| `--num_chain_tmpl` | Number of Markov‑chain transition matrices                     |
| `--num_pos_tmpl`   | Number of position templates (index pairs)                     |
| `--ood_frac`       | Fraction of templates held out as OOD (`1 − DIV` in the paper) |
| `--num_ckpt`       | Number of training checkpoints to save                         |

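For reference, the flags above can be mirrored with Python's standard `argparse`. This is a hypothetical sketch, not the parser in `src/main.py`. The argument types are inferred from the quick-start values, and no defaults are assumed, so every valued flag is marked required. `--skip_spectok` is modeled as an on/off switch only because it takes no value in the quick-start command; its meaning is not documented here. The `diversity` helper encodes the relation `ood_frac = 1 − DIV` stated in the table.

```python
import argparse

# Hypothetical parser mirroring the documented flags; types inferred from examples.
INT_FLAGS = {
    "--num_iters": "Number of training iterations",
    "--V": "Total vocabulary size (V_K + V_D)",
    "--M": "Size of the fact set (K in the paper)",
    "--n_embd": "Transformer embedding dimension",
    "--n_layer": "Number of Transformer layers",
    "--n_head": "Number of attention heads",
    "--seq_len": "Sequence length (T)",
    "--num_chain_tmpl": "Number of Markov-chain transition matrices",
    "--num_pos_tmpl": "Number of position templates (index pairs)",
    "--num_ckpt": "Number of training checkpoints to save",
}

# The quick-start command's arguments, as an argv list.
QUICKSTART_ARGV = [
    "--num_iters", "10000", "--V", "203", "--M", "100",
    "--n_embd", "32", "--n_layer", "4", "--n_head", "4",
    "--seq_len", "50", "--num_chain_tmpl", "1", "--num_pos_tmpl", "10",
    "--ood_frac", "0.5", "--num_ckpt", "10", "--skip_spectok",
]


def make_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Sketch of the training CLI")
    for flag, help_text in INT_FLAGS.items():
        parser.add_argument(flag, type=int, required=True, help=help_text)
    parser.add_argument("--ood_frac", type=float, required=True,
                        help="Fraction of templates held out as OOD (1 - DIV)")
    # Meaning not documented in this README; modeled as a boolean switch.
    parser.add_argument("--skip_spectok", action="store_true")
    return parser


def diversity(ood_frac: float) -> float:
    """DIV in the paper's notation, since ood_frac = 1 - DIV."""
    if not 0.0 <= ood_frac <= 1.0:
        raise ValueError("ood_frac must lie in [0, 1]")
    return 1.0 - ood_frac


if __name__ == "__main__":
    args = make_parser().parse_args(QUICKSTART_ARGV)
    print(args.V, args.ood_frac, diversity(args.ood_frac))  # 203 0.5 0.5
```

Note that `argparse` derives the attribute names from the flags, so `--V` and `--M` are read back as `args.V` and `args.M`.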