# Tabular Model — Usage Guide

This guide explains how to train, load, and evaluate the tabular model (TabularACE), and how to generate the comparison plots used in our analyses.

## Install

- uv (recommended):
  - `uv sync --group dev`
- pip (alternative):
  - `python -m venv .venv && source .venv/bin/activate`
  - `pip install -e . && pip install ruff`

## Train

There are two ways to train: the general Hydra entrypoint or the online tabular script.

- Hydra (general):
  - `python train.py --config-name=train_tabular`
  - Override example (Hydra `key=value` syntax): `python train.py training.num_epochs=1 device=cpu`

- Online SCM training (recommended for tabular):
  - `python train_tabular_online.py --config-name=train_tabular_online`
  - Default config: `configs/train_tabular_online.yaml`
  - Checkpoints are saved under `checkpoint.save_dir` (the path is templated by timestamp).

Tip: Keep track of run names and save dirs so it’s easy to compare checkpoints later.

Note on running Python: if you use uv (recommended), invoke scripts as `uv run python <script> ...` instead of plain `python` so the right environment is used.
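
If you launch scripts from Python (for example, in a sweep), the uv note above can be folded into a small helper. This is a minimal sketch: `python_cmd` is a hypothetical name for illustration, not part of the repository. It prefixes the command with `uv run python` when uv is found on `PATH` and otherwise falls back to the current interpreter.

```python
import shutil
import sys


def python_cmd(script, *args, use_uv=None):
    """Build an argv list for running `script`.

    Hypothetical helper for illustration only, not part of the repository.
    If `use_uv` is None, uv is used whenever it is found on PATH.
    """
    if use_uv is None:
        use_uv = shutil.which("uv") is not None
    prefix = ["uv", "run", "python"] if use_uv else [sys.executable]
    return [*prefix, script, *args]


if __name__ == "__main__":
    # Print the command instead of running it; pass the list to
    # subprocess.run(cmd, check=True) to actually execute it.
    print(python_cmd("train_tabular_online.py", "--config-name=train_tabular_online"))
```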

## Checkpoints

- Suggested layout: `checkpoints/<run_name>/best_model.pt` or the default run folder created by Hydra.
- You can also place stand‑alone files anywhere, e.g. `checkpoints/best/checkpoint_final_step_157000.pt`.
- When comparing models, keep separate output image paths to avoid overwriting.

## Quick Visual Evaluation (2D functions)

Two evaluation scripts are available. Choose based on how strictly the loaded model must match the original training architecture.

### Option A — Architecture‑Adaptive Loader

- Script: `scripts/eval_tabular_grid.py`
- Use when the checkpoint architecture may differ slightly from the current code. The script infers shapes from the checkpoint and loads only the compatible subset of weights.
- Outputs a side‑by‑side grid plot of “true target function” vs “model mean prediction”.

Examples:
- MLP‑SCM prior functions (default source):
  - `python scripts/eval_tabular_grid.py --checkpoint path/to/model.pt --source mlpscm --nfuncs 6 --out outputs/compare/runA/tabular_eval_grid_mlpscm.png`
- Synthetic 2D toy functions with uniform context selection:
  - `python scripts/eval_tabular_grid.py --checkpoint path/to/model.pt --source synth --uniform-context --nfuncs 6 --out outputs/compare/runA/tabular_eval_grid_synth_uniform.png`

Notes:
- The default output path is `outputs/tabular_eval_grid.png`; pass `--out` to avoid overwriting earlier plots.

### Option B — Exact‑Architecture, Strict Loader

- Script: `scripts/eval_tabular_grid_exact.py`
- Builds the exact architecture we trained (embed_dim=128, concat CLS=2, backbone 12×4, ff_factor=2, ISAB=3×128, row encoder=3×4, etc.) and loads checkpoints with `strict=True`.
- Also exposes mask hyperparameters to match training behavior.

Recommended settings (match training):
- `--nc 400` (or larger), `--nb 32`, `--attending-chunks 16`, `--q-block-size 128`, `--kv-block-size 128`
- Optional smoothing: `--mc-samples 100` (averages 100 Monte Carlo samples per target)

Examples:
- SCM prior:
  - `python scripts/eval_tabular_grid_exact.py --checkpoint checkpoints/best/checkpoint_final_step_157000.pt --source mlpscm --nfuncs 6 --nc 400 --nb 32 --attending-chunks 16 --q-block-size 128 --kv-block-size 128 --mc-samples 100 --out outputs/compare/new_run/tabular_eval_grid_mlpscm_nc400_mc100.png`
- Synthetic + uniform context:
  - `python scripts/eval_tabular_grid_exact.py --checkpoint checkpoints/best/checkpoint_final_step_157000.pt --source synth --uniform-context --nfuncs 6 --nc 400 --nb 32 --attending-chunks 16 --q-block-size 128 --kv-block-size 128 --mc-samples 100 --out outputs/compare/new_run/tabular_eval_grid_synth_uniform_nc400_mc100.png`

## What the Plot Shows

- True panel: Interpolated surface from target points `(Xt, Yt)`.
- Predicted panel: Model’s mean prediction at targets.
- Context points are overlaid in both panels (colored by their y values).
- Color scaling is shared so intensities are comparable.
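
The shared color scaling can be pictured as computing one value range over both panels. The sketch below is an illustration only: `shared_color_limits` is a hypothetical helper, and the evaluation scripts may compute their limits differently (for example, by also including context values).

```python
def shared_color_limits(true_values, pred_values):
    """Return (vmin, vmax) spanning both value sets so both panels share one color scale."""
    combined = list(true_values) + list(pred_values)
    if not combined:
        raise ValueError("need at least one value to compute color limits")
    return min(combined), max(combined)


# Toy values standing in for the true and predicted target values.
vmin, vmax = shared_color_limits([0.0, 1.0], [-2.0, 0.5])
```

The resulting pair would typically be passed as `vmin`/`vmax` to each panel's plotting call (matplotlib's `scatter` accepts both), so equal colors mean equal values in either panel.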

## Normalization (Correctness)

- We use `TabICLScaler` to match training semantics:
  - Fit only on context `(xc, yc)`.
  - Transform `xc/xb/xt` with the chosen x normalization (power/quantile, etc.).
  - Transform `yc/yb/yt` by z‑score using context y stats.
- Both truth and predictions are plotted in the same normalized space.
- You can change normalization with `--norm {power,quantile,quantile_rtdl,none}` to sanity‑check sensitivity.
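
The "fit on context only" rule for y can be sketched in a few lines. This is a simplified stand-in, not `TabICLScaler` itself: it covers only the y z-score, and the choice of population standard deviation and the `eps` guard are assumptions made for the example.

```python
from statistics import fmean, pstdev


def fit_context_stats(yc, eps=1e-8):
    """Fit z-score statistics on the context targets only."""
    mean = fmean(yc)
    std = pstdev(yc)  # population std; an assumption for this sketch
    # Guard against a constant context, which would give std == 0.
    return mean, max(std, eps)


def zscore(values, mean, std):
    """Apply context-fitted statistics to any split (context, buffer, or target)."""
    return [(v - mean) / std for v in values]


yc = [1.0, 2.0, 3.0]  # context targets: the only values used for fitting
yt = [2.0, 4.0]       # query targets: transformed, never used for fitting
mean, std = fit_context_stats(yc)
yc_norm = zscore(yc, mean, std)
yt_norm = zscore(yt, mean, std)
```

Because the statistics come from the context alone, normalized targets can fall well outside the range of the normalized context, which is expected.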

## Avoid Overwriting Plots

- Always pass a unique `--out` filename, e.g. under `outputs/compare/<run_name>/...`.
  - Example layout:
    - `outputs/compare/old_run/tabular_eval_grid_mlpscm.png`
    - `outputs/compare/new_run/tabular_eval_grid_mlpscm.png`
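
If you generate many plots from a driver script, a small helper can pick a non-clashing path for `--out`. This is a sketch with a hypothetical helper name (`unique_out_path`); it only computes a path and does not create directories or files.

```python
from pathlib import Path


def unique_out_path(run_name, filename, root="outputs/compare"):
    """Return a path under root/run_name that does not exist yet.

    If the file already exists, a numeric suffix (_1, _2, ...) is appended to the stem.
    """
    base = Path(root) / run_name / filename
    candidate = base
    counter = 1
    while candidate.exists():
        candidate = base.with_name(f"{base.stem}_{counter}{base.suffix}")
        counter += 1
    return candidate


if __name__ == "__main__":
    print(unique_out_path("new_run", "tabular_eval_grid_mlpscm.png"))
```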

## Tips for Good‑Looking Plots

- Increase context (`--nc`) and buffer (`--nb`) to provide more signal.
- Match mask hyperparameters to training: `--attending-chunks 16 --q-block-size 128 --kv-block-size 128`.
- Try `--uniform-context` for synthetic grids to ensure the context covers the domain.
- Add `--mc-samples 100` to smooth out artifacts by averaging samples per target.
- For smoother interpolation, increase targets `--nt` (default 512).

## Troubleshooting

- “Skipped/Missing keys” with `eval_tabular_grid.py`:
  - This is expected. The script adapts to shape drift by loading only the compatible weights, and the model still runs. Use the exact script if you need strict loading.
- Strict load fails in `eval_tabular_grid_exact.py`:
  - Ensure the checkpoint matches the documented architecture. The script maps between the `embedder.*` and `tabular_embedder.*` key prefixes but does not resize tensors.
- TabICL not installed:
  - The scaler falls back to standard scaling (mean/std). You can still evaluate.

## Comparing Two Models

1) Put checkpoints under separate folders: `checkpoints/old_run/...`, `checkpoints/new_run/...`.
2) Generate plots for both runs with unique `--out` paths (see examples above).
3) Compare the resulting images side by side.
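
The steps above can be scripted so both runs use identical flags and differ only in checkpoint and output path. This is a minimal sketch: the checkpoint paths are placeholders, `eval_command` is a hypothetical helper, and only flags documented in this guide are used.

```python
from pathlib import Path


def eval_command(checkpoint, run_name, source="mlpscm", nfuncs=6):
    """Build the eval_tabular_grid.py command for one run with a run-specific --out path."""
    out = Path("outputs/compare") / run_name / f"tabular_eval_grid_{source}.png"
    return [
        "uv", "run", "python", "scripts/eval_tabular_grid.py",
        "--checkpoint", str(checkpoint),
        "--source", source,
        "--nfuncs", str(nfuncs),
        "--out", out.as_posix(),
    ]


# Placeholder checkpoint paths; substitute your own files.
runs = {
    "old_run": "checkpoints/old_run/best_model.pt",
    "new_run": "checkpoints/new_run/best_model.pt",
}
commands = {name: eval_command(ckpt, name) for name, ckpt in runs.items()}

if __name__ == "__main__":
    for cmd in commands.values():
        # Print for inspection; use subprocess.run(cmd, check=True) to execute.
        print(" ".join(cmd))
```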

---

Questions or gaps? Open an issue or ping the project chat with your command and any error output. We can extend the scripts as needed (e.g., raw‑space plotting or metrics like RMSE).

## Log‑Likelihood Evaluation (1D GP)

We provide a unified log‑likelihood (LL) evaluator for TabularACE that supports three modes. Always run scripts via uv:

- General form: `uv run python scripts/eval_tabular_ll.py --mode <mode> --checkpoint <ckpt> [flags...]`
- Modes:
  - `evaluate_ll_independent`: context‑only, independent targets (exact mixture LL)
  - `evaluate_ll_ar`: teacher‑forcing AR with re‑encode (insert true target as context; no AR tokens)
  - `evaluate_ll_ar_buffer`: teacher‑forcing AR with buffer tokens (true target as buffer; uses AR tokens; K‑chunked)
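
To make "exact mixture LL" concrete, the sketch below evaluates the log-likelihood of independent targets under per-target mixtures. It assumes the predictive head is a Gaussian mixture with strictly positive weights. The guide itself only says "mixture", so treat the component family and all function names here as assumptions.

```python
import math


def gaussian_logpdf(y, mean, std):
    """Log density of a univariate normal distribution."""
    z = (y - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def mixture_loglik(y, weights, means, stds):
    """Exact log p(y) under a Gaussian mixture, via a numerically stable log-sum-exp.

    Weights must be strictly positive and sum to 1.
    """
    terms = [
        math.log(w) + gaussian_logpdf(y, m, s)
        for w, m, s in zip(weights, means, stds)
    ]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def independent_loglik(targets, mixtures):
    """Sum per-target log-likelihoods, treating targets as independent given the context.

    mixtures: one (weights, means, stds) tuple per target.
    """
    return sum(mixture_loglik(y, *mix) for y, mix in zip(targets, mixtures))
```

No Monte Carlo is needed in this mode because each target's density is available in closed form. The AR modes instead depend on the order in which targets are revealed.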

Common flags
- `--device cpu|cuda`
- `--num-fns`, `--nc`, `--nt` (number of functions, context points, and targets; e.g., `--nt 32`)
- `--noise` (observation noise in GP sampler)
- `--seed` (sets the Python, NumPy, and Torch RNGs; use the same seed across modes to evaluate the same functions)
- `--batch-path` / `--save-batch` (optional: load/save a pre‑generated batch to guarantee identical functions across runs)
- `--attn-backend auto|flash|mem|math` (SDPA backend)

AR‑mode flags
- `--n-mc` (default 64): number of permutations per function for Monte Carlo over target orders
- `--K` (default 32; `evaluate_ll_ar_buffer` only): chunk size; must be ≤ `max_buffer_size` (32)
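
The two AR flags can be pictured as follows: `--n-mc` random orderings of the targets are drawn, and in buffer mode each ordering is processed in chunks of at most `--K` targets. The sketch below illustrates the idea only. It is not the evaluator's implementation, and the function names are hypothetical.

```python
import random


def sample_target_orders(num_targets, n_mc, seed=0):
    """Draw n_mc random orderings of the target indices."""
    rng = random.Random(seed)
    orders = []
    for _ in range(n_mc):
        order = list(range(num_targets))
        rng.shuffle(order)
        orders.append(order)
    return orders


def chunk_order(order, k, max_buffer_size=32):
    """Split one ordering into consecutive chunks of at most k targets."""
    if not 1 <= k <= max_buffer_size:
        raise ValueError(f"K must be in [1, {max_buffer_size}], got {k}")
    return [order[i:i + k] for i in range(0, len(order), k)]


orders = sample_target_orders(num_targets=70, n_mc=4, seed=1)
chunks = chunk_order(orders[0], k=32)  # chunk lengths: 32, 32, 6
```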

Examples (normalized LL; K=32)
- Independent (exact):
  - `uv run python scripts/eval_tabular_ll.py --mode evaluate_ll_independent --checkpoint checkpoints/best/checkpoint_final_step_157000.pt --device cpu --num-fns 8 --nc 64 --nt 32 --noise 0.0 --seed 1`
- AR (re‑encode, teacher forcing):
  - `uv run python scripts/eval_tabular_ll.py --mode evaluate_ll_ar --checkpoint checkpoints/best/checkpoint_final_step_157000.pt --device cpu --num-fns 8 --nc 64 --nt 32 --noise 0.0 --n-mc 64 --seed 1`
- AR buffer (teacher forcing, K‑chunked):
  - `uv run python scripts/eval_tabular_ll.py --mode evaluate_ll_ar_buffer --checkpoint checkpoints/best/checkpoint_final_step_157000.pt --device cpu --num-fns 8 --nc 64 --nt 32 --noise 0.0 --n-mc 64 --K 32 --seed 1`

Tips
- Use the same `--seed` and identical data flags to evaluate the same GP functions across modes. For absolute reproducibility, generate once with `--save-batch outputs/gp_batch.pt` and reuse with `--batch-path outputs/gp_batch.pt`.
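- Reported LL is in normalized space (the scaler is fit on context). For original‑scale LL, apply a Jacobian correction per function: subtract `\sum_t \log(\sigma_y)`, which equals `n_t \log(\sigma_y)` for `n_t` targets, where `\sigma_y` is the context standard deviation used for the y z‑score.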
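
The Jacobian correction can be checked numerically. The sketch below assumes the reported LL is summed over targets and that y was z‑scored with a single `sigma_y` per function. The function names and the toy numbers are illustrative only.

```python
import math


def normal_logpdf(y, mean, std):
    """Log density of a univariate normal distribution."""
    z = (y - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def to_original_scale_ll(ll_normalized, sigma_y, num_targets):
    """Convert a summed LL from z-scored y space back to the original y scale.

    With y_norm = (y - mu_y) / sigma_y, each target density picks up a factor
    1 / sigma_y, so the summed LL drops by num_targets * log(sigma_y).
    """
    return ll_normalized - num_targets * math.log(sigma_y)


# Toy check with one target: the model predicts N(0.2, 0.5^2) in normalized space.
mu_y, sigma_y = 1.0, 2.0
y = 3.0
ll_norm = normal_logpdf((y - mu_y) / sigma_y, 0.2, 0.5)
# The same prediction expressed directly in original units.
ll_orig_direct = normal_logpdf(y, mu_y + sigma_y * 0.2, sigma_y * 0.5)
ll_orig_corrected = to_original_scale_ll(ll_norm, sigma_y, num_targets=1)
```

Here `ll_orig_corrected` agrees with `ll_orig_direct` up to floating-point error. If the evaluator reports a per-target average rather than a sum, subtract `log(sigma_y)` once instead.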

## Sequence Sampling (1D GP)

Use the unified sampler to draw sequences under three decoding modes. Always run via uv:

- General form:
  - `uv run python scripts/sample_tabular_sequences.py --decode-mode <all|ar_reencode|ar_buffer> --checkpoint <ckpt> [flags...]`
- Shared flags:
  - `--device cpu|cuda`, `--num-fns`, `--nc`, `--nt`, `--noise`
  - `--seed` (use the same seed across modes for identical functions)
  - `--batch-path` / `--save-batch` (optional: load/save a pre‑generated batch)
  - `--attn-backend auto|flash|mem|math`
- Mode‑specific:
  - `all`: context‑only; uses head sampling with `--n-mc` samples per target, averaged as the predictive mean
  - `ar_reencode`: teacher‑forcing AR with re‑encode (true → context); MC via batch expansion with `--n-mc`
  - `ar_buffer`: teacher‑forcing AR with buffer tokens (true → buffer + AR token); MC via batch expansion with `--n-mc`, chunk size `--K` (default 32, must be ≤ max_buffer_size)

Examples
- Independent (context‑only), 64 samples per target:
  - `uv run python scripts/sample_tabular_sequences.py --decode-mode all --checkpoint checkpoints/best/checkpoint_final_step_157000.pt --device cpu --num-fns 6 --nc 64 --nt 256 --noise 0.0 --n-mc 64 --seed 1 --out outputs/seq_all.png`
- AR re‑encode (teacher‑forcing), MC via batch expansion:
  - `uv run python scripts/sample_tabular_sequences.py --decode-mode ar_reencode --checkpoint checkpoints/best/checkpoint_final_step_157000.pt --device cpu --num-fns 6 --nc 64 --nt 256 --n-mc 64 --seed 1 --out outputs/seq_ar_reencode.png`
- AR buffer (teacher‑forcing, K=32 chunk), MC via batch expansion:
  - `uv run python scripts/sample_tabular_sequences.py --decode-mode ar_buffer --checkpoint checkpoints/best/checkpoint_final_step_157000.pt --device cpu --num-fns 6 --nc 64 --nt 256 --n-mc 64 --K 32 --seed 1 --out outputs/seq_ar_buffer.png`

Notes
- Plots are in normalized space (fit on context). Use a consistent `--seed` to compare modes on identical functions.
- `--batch-path / --save-batch` lets you pre‑generate a batch of functions and reuse it across runs for strict apples‑to‑apples comparisons.

## Publication Plots (2×3 Composite)

This section documents the exact commands to reproduce the 2×3 composite figure we use in the paper, which shows:

- First two columns (2×2 grid): smoothed predictions (top row = GP prior; bottom row = SCM prior). Left column uses Autoregressive (AR); right column uses Ours (AR buffer). Styling uses green for AR and red for Ours.
- Third column (two panels): sampling-time vs. context size (Nc) for B=256 (TNPD‑AR vs. Ours) drawn with the same serif/white theme.

We break it into two steps: (A) generate the smoothed prediction files, and (B) render the composite.

### A. Generate Smoothed Predictions

Use the unified generator to create two smoothed `.pt` files: one for AR re‑encode (green) and one for AR buffer (red). Each file contains `xc`/`yc`/`xt`/`yt`, the smoothed mean (`y_mean`), 95% CIs, and optionally a few trajectories for plotting.
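
For intuition, reducing Monte Carlo samples to a mean and a 95% interval per target could look like the sketch below. It assumes empirical 2.5/97.5 percentiles; the generator may compute its intervals differently (for example, mean ± 1.96 standard deviations), and the function names are hypothetical.

```python
from statistics import fmean


def percentile(sorted_values, q):
    """Linear-interpolation percentile of pre-sorted values, with q in [0, 100]."""
    if not sorted_values:
        raise ValueError("need at least one sample")
    pos = (len(sorted_values) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(sorted_values) - 1)
    frac = pos - lo
    return sorted_values[lo] * (1.0 - frac) + sorted_values[hi] * frac


def summarize_samples(samples_per_target):
    """Reduce MC samples to (mean, lower, upper) for each target.

    samples_per_target: one inner list of sampled y values per target.
    """
    summary = []
    for samples in samples_per_target:
        ordered = sorted(samples)
        summary.append(
            (fmean(ordered), percentile(ordered, 2.5), percentile(ordered, 97.5))
        )
    return summary
```

Larger `--n-mc` values give more samples per target, which stabilizes both the mean and the interval endpoints.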

AR (re‑encode; GP prior example)

```
uv run python scripts/generate_smoothed_sequences.py \
  --decode-mode ar_reencode \
  --checkpoint checkpoints/best/checkpoint_final_step_157000.pt \
  --device cpu \
  --num-fns 6 \
  --nc 12 \
  --nt 256 \
  --n-mc 256 \
  --seed 7 \
  --out outputs/tabular_model_smoothed_ar_sequences.pt
```

Ours (AR buffer; SCM prior example)

```
uv run python scripts/generate_smoothed_sequences.py \
  --decode-mode ar_buffer \
  --checkpoint checkpoints/best/checkpoint_final_step_157000.pt \
  --device cpu \
  --num-fns 6 \
  --nc 12 \
  --nt 256 \
  --n-mc 256 \
  --K 32 \
  --refresh-mode context_chunk \
  --seed 7 \
  --out outputs/tabular_model_smoothed_arbuffer_sequences.pt
```

Notes
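- These commands produce the two files used in the composite plot's first two columns. As written, the composite command in step B passes the same two files for both rows; supply separate per‑prior files to `--quad` if the GP and SCM rows should differ.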
- By default, the generator already outputs smoothed means suitable for the figure. To apply the buffer “context‑chunk” mini re‑encode at sampling time (within K, without AR tokens), use the online sampler shown below and save your own sample paths before smoothing.

Optional (online sampling + smoothing for a single function)

```
uv run python scripts/sample_tabular_sequences.py \
  --decode-mode ar_buffer \
  --checkpoint checkpoints/best/checkpoint_final_step_157000.pt \
  --device cpu \
  --num-fns 1 \
  --nc 12 \
  --nt 256 \
  --n-mc 256 \
  --K 32 \
  --seed 7 \
  --refresh-mode context_chunk \
  --smooth-mean \
  --out outputs/seq_ar_buffer_ctxchunk_smoothed.png
```

### B. Render the 2×3 Composite

Make sure the sampling‑time JSON files (compiled and Triton) exist under `outputs/fast_times/results/`. Then run:

```
uv run python scripts/fast_times/plot_composite_2x3.py \
  --json outputs/fast_times/results/compiled_sampling.json \
        outputs/fast_times/results/triton_sampling.json \
  --B 256 \
  --quad outputs/tabular_model_smoothed_ar_sequences.pt \
        outputs/tabular_model_smoothed_arbuffer_sequences.pt \
        outputs/tabular_model_smoothed_ar_sequences.pt \
        outputs/tabular_model_smoothed_arbuffer_sequences.pt \
  --fn-index-left 0 \
  --fn-index-right 0 \
  --out outputs/fast_times/plots/composite_2x3.png
```

What this does
- First two columns: 2×2 grid of smoothed predictions with publication styling:
  - Top row = “GP prior”; bottom row = “SCM prior” (labels on far left).
  - Left column = “Autoregressive” (green); right column = “Ours” (red).
  - Axis labels: `$y$` on the leftmost plots and `$x$` on the bottom prediction panels. Ticks are cleaned up, and context points are drawn as plus markers.
- Third column: sampling‑time vs Nc for B=256 (TNPD‑AR vs Ours) with a log y‑axis and a consistent theme. Only the bottom panel has the `$N$` x‑label; y‑ticks are shown in this column for reference.

Sampling‑only (2×3 grid)

```
uv run python scripts/fast_times/plot_grid_b256_two_methods.py \
  --json outputs/fast_times/results/compiled_sampling.json \
        outputs/fast_times/results/triton_sampling.json \
  --B 256 \
  --out outputs/fast_times/plots/b256_two_methods_grid.png
```

Color mapping
- Paired palette indices (consistent across scripts):
  - AR (green): dark = 3, light = 2
  - Ours (red): dark = 5, light = 4
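
For reference, the indices above can be resolved to hex colors without importing a plotting library. This sketch assumes "Paired" refers to the ColorBrewer Paired palette as shipped with matplotlib and seaborn; the dictionary keys are illustrative names.

```python
# First six colors of the ColorBrewer "Paired" palette, by index.
PAIRED = ["#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c"]

# Keys are illustrative; indices follow the mapping listed above.
METHOD_COLORS = {
    "ar": {"dark": PAIRED[3], "light": PAIRED[2]},
    "ours": {"dark": PAIRED[5], "light": PAIRED[4]},
}

if __name__ == "__main__":
    for method, shades in METHOD_COLORS.items():
        print(method, shades["dark"], shades["light"])
```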

Repro tips
- Use `--seed` to keep functions consistent across runs.
- If you regenerate smoothed `.pt` files, re-run the composite command with the same `--quad` paths.
- All plotting scripts use the same serif/white theme and clean axes for a paper‑ready look.
