# JAXBench Full-Suite Optimization Plan

Optimize all 50 JAXBench benchmarks on TPU v6e-1 (Trillium) with Autocomp.

## Benchmark Inventory

| Group | Count | Starting point | Autocomp `prob_type` |
|-------|------:|----------------|----------------------|
| A — Pallas-optimized priority kernels | 8 | `optimized.py` (hand-written Pallas) | `jaxbench-pallas` |
| B — Baseline-only priority kernels | 9 | `baseline.py` (vanilla JAX) | `jaxbench-baseline` |
| C — KernelBench fused operators | 33 | `baseline.py` (vanilla JAX) | `jaxbench-baseline` |
| D — End-to-end Pallas (from JAX) | 8 | `baseline.py` (vanilla JAX) | `jaxbench-baseline` |

Group D re-runs the 8 Group A benchmarks from `baseline.py` to measure end-to-end translation quality vs. the hand-written Pallas starting point.

**Total runs: 58** (8 + 9 + 33 + 8).

---

## Search Configuration

All 58 runs use the same `run_batch.py` configuration. Groups B/C/D use a **two-phase pipeline**: a translation phase followed by an optimization phase, each with its own output directory. Group A runs optimization only.

```python
backend_name = "jaxbench"
agent_name = "built:tpu-v6e"
hw_config = TpuHardwareConfig("v6e-1")
metric = "latency"

# --- Translation phase (Groups B, C, D only) ---
TRANSLATE_ITERATIONS = 4
TRANSLATE_USE_EDITS = False          # full code generation during translation
TRANSLATE_PERF_THRESHOLD = 15
TRANSLATE_DROP_ORIGINAL = True
TRANSLATE_SCORE = True

# --- Optimization phase (all groups) ---
OPT_ITERATIONS = 4
OPT_USE_EDITS = True                 # edit-based code generation during optimization

# --- Shared settings ---
search_strategy = "beam"
num_plan_candidates = 1              # skip_planning=True, so 1 "plan" per parent
num_code_candidates = 6              # 6 code samples per parent per iteration
beam_size = 3
dropout_menu_options = 0.25
early_stop_iters = 3                 # stop if no improvement for 3 consecutive iters
early_stop_threshold = 0.98          # ratio >= 0.98 triggers stop (< 2% improvement = stalled)
skip_planning = True                 # bypass plan phase, generate code directly
prevent_duplicate_level = -1         # allow same-parent candidates; eval is the bottleneck
reimplement_failed = False

# Reduced trial counts during search (final eval uses full 100 trials separately)
AUTOCOMP_TPU_NUM_WARMUP = 3
AUTOCOMP_TPU_NUM_TRIALS = 20

# Models
models = [
    "gcp::gemini-3.1-pro-preview",
    "gcp::gemini-3-flash-preview",
]
code_models = None

# Built-agent options
menu_strategy = "one-shot"
fine_grained_isa = True
example_rate = 0.25

# Feedback
give_score_feedback = 1
give_util_feedback = 0
give_hw_feedback = 0
include_ancestors = False
```

### The only parameter that varies between groups

| Group | `prob_type` | Pipeline | Effective iterations |
|-------|-------------|----------|----------------------|
| A | `jaxbench-pallas` | Optimize only | 8 optimization iterations (edits) |
| B | `jaxbench-baseline` | Translate → Optimize | 4 translation (full rewrite) + 4 optimization (edits) |
| C | `jaxbench-baseline` | Translate → Optimize | 4 translation (full rewrite) + 4 optimization (edits) |
| D | `jaxbench-baseline` | Translate → Optimize | 4 translation (full rewrite) + 4 optimization (edits) |
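
The 4 + 4 split for Groups B/C/D is a budget rather than a fixed schedule: as the Rationale notes, once every beam candidate reaches a translation score of 9.0 or higher, the remaining translation iterations are converted to optimization iterations. Below is a minimal sketch of that accounting. The `split_budget` helper is hypothetical (it is not part of `run_batch.py`) and assumes the caller already knows the iteration at which translation early-stopped.

```python
def split_budget(translate_budget, opt_budget, early_stop_iter=None):
    """Return (translation_iters, optimization_iters) actually spent.

    Hypothetical helper for illustration only. `early_stop_iter` is the
    1-based translation iteration at which every beam candidate scored
    >= 9.0, or None if translation never early-stopped.
    """
    if early_stop_iter is None:
        return translate_budget, opt_budget
    if not 1 <= early_stop_iter <= translate_budget:
        raise ValueError("early_stop_iter must fall within the translation budget")
    unused = translate_budget - early_stop_iter
    # Unused translation iterations are converted to optimization iterations.
    return early_stop_iter, opt_budget + unused


# Group B/C/D budget from the configuration: 4 translation + 4 optimization.
print(split_budget(4, 4))                     # (4, 4): no early stop
print(split_budget(4, 4, early_stop_iter=1))  # (1, 7): early stop at iteration 1
```

Either way the total stays at 8 iterations, matching the 8 optimization iterations listed for Group A.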

### Rationale

- **Two-phase pipeline** — translation (XLA baseline → Pallas) and optimization (improve Pallas code) are separate runs with separate output directories. The optimization phase continues from the best translation result. This isolates translation failures from optimization and allows resuming either phase independently.
- **4 + 4 iterations** — 4 translation iterations give enough budget for complex kernels; 4 optimization iterations refine the Pallas code. Based on Group B results, 8 total iterations capture >97% of the gains vs. 10 iterations. Early stopping at 3 stalled iterations with a 0.98 threshold ensures compute isn't wasted on converged benchmarks.
- **`translate_perf_threshold = 15`** — during translation, keep candidates up to 15× slower than baseline (correctness-first exploration).
- **`translate_score = True`** — enables translation completeness scoring and early stopping. When all beam candidates score >= 9.0, the remaining translation iterations are converted to optimization iterations.
- **Full rewrite for translation, edits for optimization** — translation requires restructuring the entire file (JAX → Pallas), so `use_edits=False`. Optimization applies localized changes to existing Pallas code, so `use_edits=True` avoids accidentally dropping unrelated code.
- **1 plan × 6 code × 3 beam** — `skip_planning=True` bypasses the plan phase; the LLM generates code directly. 6 code samples per parent × 3 beam parents = up to 18 candidates evaluated per iteration.
- **Two Gemini models (Pro + Flash)** — consistent model family for all runs; Pro handles complex reasoning, Flash provides diverse candidates at lower cost.
- **Reduced eval trials (3 warmup, 20 trials)** — speeds up search iterations. Final evaluation uses full 100 trials separately for accurate benchmarking.
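
The early-stopping rule from the 4 + 4 rationale can be sketched as follows. This is one plausible reading, not Autocomp's actual implementation: it assumes the ratio compares the new best latency against the previous best, and that `early_stop_iters` counts consecutive stalled iterations. The `should_stop` helper and its arguments are hypothetical.

```python
def should_stop(best_latency_per_iter, early_stop_iters=3, early_stop_threshold=0.98):
    """Return True once `early_stop_iters` consecutive iterations have stalled.

    An iteration counts as stalled when new_best / previous_best >= threshold,
    i.e. it improved the best latency by less than 2% at threshold 0.98.
    Assumed semantics for illustration; not taken from Autocomp's source.
    """
    if not best_latency_per_iter:
        return False
    best_so_far = best_latency_per_iter[0]
    stalled = 0
    for latency in best_latency_per_iter[1:]:
        new_best = min(best_so_far, latency)
        ratio = new_best / best_so_far
        if ratio >= early_stop_threshold:
            stalled += 1
        else:
            stalled = 0  # a real improvement resets the counter
        best_so_far = new_best
        if stalled >= early_stop_iters:
            return True
    return False


# Latencies in ms, one entry per iteration (made-up numbers for illustration).
print(should_stop([10.0, 9.95, 9.93, 9.92]))  # True: three sub-2% steps in a row
print(should_stop([10.0, 9.95, 8.00, 7.99]))  # False: the large gain resets the counter
```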

---

## Group A — Improve Existing Pallas Kernels (8 runs) — NOT YET RUN

Start from `optimized.py` (hand-written Pallas). Optimization only (8 iterations, edits). The table below shows the XLA baseline and the existing hand-written Pallas latencies for reference; Autocomp has not yet been run on this group.

| # | `prob_id` | XLA Baseline (ms) | Human Pallas (ms) | Human Speedup |
|---|-----------|---:|---:|---:|
| 1 | `1p_Flash_Attention` | 1.49 | 0.62 | 2.4× |
| 2 | `2p_GQA_Attention` | 3.24 | 1.44 | 2.3× |
| 3 | `3p_MLA_Attention` | 4.48 | 4.28 | 1.0× |
| 4 | `4p_Sparse_Attention` | 1.50 | 0.66 | 2.3× |
| 5 | `6p_Paged_Attention` | 1.96 | 1.02 | 1.9× |
| 6 | `7p_Ragged_Paged_Attention` | 191.99 | 0.86 | 222× |
| 7 | `8p_GEMM` | 5.48 | 5.62 | 1.0× |
| 8 | `11p_Megablox_GMM` | 187.04 | 2.82 | 66× |

---

## Group B — Translate + Optimize Priority Kernels (9 runs)

Start from `baseline.py`. Two-phase pipeline (translate → optimize).

All benchmarks ran with the two-step flow (full code gen for translate, edits for optimize). Translation early-stopped on all benchmarks at iter 1–2 (all beam candidates reached translation score >= 9.0), and the remaining translation budget was converted to optimization iterations. Results below are capped at 8 total iterations. The runs originally used 5 + 5 = 10 iterations; the extra 2 yielded <3% additional improvement on the best benchmark.

"Post-Translate" = best Pallas latency from actual translation iterations only (before conversion to optimization). "Post-Optimize" = best Pallas latency after 8 total optimization iterations (including converted ones and the separate optimize phase).

| # | `prob_id` | Baseline (ms) | Post-Translate (ms) | Post-Optimize (ms) | Speedup | Best-of-N @144 | Iterative @144 | Runtime | Notes |
|---|-----------|---:|---:|---:|---:|---|---|---:|---|
| 1 | `5p_Flex_Attention` | 36.365 | 26.441 | **9.572** | **3.80x** | 0/144 correct | 9.05 ms (3.99x), 29/144 | 4.9h | Early-stop @1; opt drove gains |
| 2 | `9p_SwiGLU_MLP` | 16.800 | 26.651 | **16.136** | **1.04x** | 0/144 correct | — | 7.0h | Early-stop @1; translate much worse, opt clawed back |
| 3 | `10p_Sparse_MoE` | 36.549 | 108.660 | **33.838** | **1.08x** | 0/144 correct | — | 5.3h | Early-stop @2; translate 3x worse, opt recovered |
| 4 | `12p_RMSNorm` | 1.206 | 0.884 | **0.859** | **1.40x** | 0.86 ms (1.42x), 10/144 | 0.86 ms (1.42x), 67/144 | 2.2h | Early-stop @1; translation already faster than XLA |
| 5 | `13p_Cross_Entropy` | 14.372 | 15.706 | **11.051** | **1.30x** | — | — | 3.4h | Early-stop @2; translate slightly worse, opt recovered |
| 6 | `14p_Ragged_Dot` | 1.361 | 4.590 | **1.463** | **0.93x** | — | — | 3.2h | Early-stop @1; Pallas never beat XLA |
| 7 | `15p_RetNet_Retention` | 12.910 | 2.720 | **1.950** | **6.62x** | 0/144 correct | 2.63 ms (4.85x), 1/144 | 2.5h | Early-stop @1; translation already 4.7x faster |
| 8 | `16p_Mamba2_SSD` | 29.041 | 125.966 | **7.265** | **4.00x** | 0/144 correct | 0/144 correct | 2.7h | Early-stop @1; translate 4x worse, opt drove all gains |
| 9 | `17p_Triangle_Multiplication` | 14.489 | 24.228 | **7.683** | **1.89x** | — | — | 3.4h | Early-stop @2; translate worse, opt recovered |

**Group B summary:**
- **9 of 9** benchmarks successfully translated from JAX/XLA to Pallas (100% translation success rate)
- **8 of 9** benchmarks improved over XLA baseline (speedup > 1.0x)
- **4 big wins** (>1.5x): RetNet Retention (6.62x), Mamba2 SSD (4.00x), Flex Attention (3.80x), Triangle Multiplication (1.89x) — kernels where XLA materializes large intermediates that Pallas can tile/fuse
- **3 moderate wins**: RMSNorm (1.40x), Cross Entropy (1.30x), Sparse MoE (1.08x)
- **1 marginal**: SwiGLU MLP (1.04x) — XLA already highly optimized for this pattern
- **1 regression**: Ragged Dot (0.93x) — Pallas couldn't match XLA's grouped matmul; most likely a block size tuning issue
- Translation alone produced a correct-and-faster Pallas kernel for only 3 of 9 benchmarks (Flex Attention, RMSNorm, RetNet). For the other 6, the initial Pallas translation was slower than XLA and required optimization iterations to improve.
- **Baseline comparison (4 benchmarks with data for both baselines):** Best-of-N @144 only solves RMSNorm (1.42x); Iterative @144 solves RMSNorm (1.42x), Flex Attention (3.99x), and RetNet (4.85x) but fails on Mamba2 SSD. Autocomp is the only method that produces a correct kernel on all four.
- Full output artifacts: [`output/jaxbench-sweep/`](https://github.com/ucb-bar/autocomp/tree/jaxbench/output/jaxbench-sweep)
- **Total Group B runtime: ~35h** (2–7h per benchmark, median ~3.4h)

---

## Group C — Translate + Optimize KernelBench Fused Ops (33 runs)

Start from `baseline.py`. Two-phase pipeline (translate → optimize) with `TRANSLATE_ITERATIONS = 4`.

| Subgroup | Count | `prob_id` examples |
|----------|------:|---------|
| Matmul fused | 14 | `19k_Matmul_Subtract_Multiply_ReLU`, `30k_Matmul_Scaling_ResidualAdd`, ... |
| Gemm fused | 10 | `20k_Gemm_Multiply_LeakyReLU`, `41k_Gemm_Add_ReLU`, ... |
| Conv2d fused | 5 | `18k_Conv2D_ReLU_BiasAdd`, `39k_Conv2d_GELU_GlobalAvgPool`, ... |
| Conv3d fused | 2 | `25k_Conv3d_GroupNorm_Mean`, `33k_Conv3d_Mish_Tanh` |
| BMM fused | 1 | `26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply` |
| ConvTranspose fused | 1 | `28k_ConvTranspose3d_LayerNorm_GELU_Scaling` |

Full list of 33 `prob_id` values:

```
18k_Conv2D_ReLU_BiasAdd
19k_Matmul_Subtract_Multiply_ReLU
20k_Gemm_Multiply_LeakyReLU
21k_Gemm_Divide_Sum_Scaling
22k_Conv2d_InstanceNorm_Divide
23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp
24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish
25k_Conv3d_GroupNorm_Mean
26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply
27k_Matmul_Mish_Mish
28k_ConvTranspose3d_LayerNorm_GELU_Scaling
29k_Matmul_Swish_Sum_GroupNorm
30k_Matmul_Scaling_ResidualAdd
31k_Gemm_BatchNorm_GELU_ReLU
32k_Gemm_Sigmoid_LogSumExp
33k_Conv3d_Mish_Tanh
34k_Conv2d_Activation_BatchNorm
35k_Gemm_Scaling_Hardtanh_GELU
36k_Matmul_Sigmoid_Sum
37k_Matmul_Swish_Scaling
38k_Matmul_Dropout_Softmax
39k_Conv2d_GELU_GlobalAvgPool
40k_Gemm_GroupNorm_Min_BiasAdd
41k_Gemm_Add_ReLU
42k_Gemm_Max_Subtract_GELU
43k_Gemm_BatchNorm_Scaling_Softmax
44k_Matmul_Divide_GELU
45k_Gemm_GroupNorm_Swish_Multiply_Swish
46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp
47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh
48k_Matmul_BatchNorm_BiasAdd_Divide_Swish
49k_Matmul_AvgPool_GELU_Scale_Max
50k_Matmul_GELU_Softmax
```

---

## Group D — End-to-End Translation of Pallas Benchmarks (8 runs)

Same 8 benchmarks as Group A, but starting from `baseline.py`. Two-phase pipeline (translate → optimize) with `TRANSLATE_ITERATIONS = 4`.

This enables a direct comparison:
- Group A result = Autocomp improving a human-written Pallas kernel
- Group D result = Autocomp writing a Pallas kernel from scratch (JAX baseline only)
- Existing `optimized.py` = human-written reference point

---

## Execution Plan

### Infrastructure

- **TPU**: Single v6e-1 VM, JAX 0.9.2, Python 3.11
- **Runner**: `run_batch.py` — loops over `(prob_type, prob_id)` pairs, two-phase pipeline for baseline groups
- **Parallelism**: TPU evals serialized; LLM calls parallelized per model
- **Output**: Each run gets directories under `output/jaxbench-sweep/` (e.g., `1p_Flash_Attention_pallas/` for Group A, `5p_Flex_Attention_baseline_translate/` and `5p_Flex_Attention_baseline/` for Groups B/C/D)
- **Resume**: `--skip-completed` (default) skips benchmarks whose optimization directory has `candidates-iter-5/`; legacy single-phase runs are auto-detected
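
The resume rule can be illustrated with a short sketch. It assumes the directory layout from the Output bullet and treats the presence of `candidates-iter-5/` inside the optimization directory as the completion marker. The `is_completed` helper is hypothetical, and the real `run_batch.py` logic (including legacy single-phase detection) may differ.

```python
from pathlib import Path
import tempfile


def is_completed(sweep_root, prob_id, prob_type_suffix, marker="candidates-iter-5"):
    """Return True if the optimization directory already contains the marker.

    Hypothetical helper: `prob_type_suffix` is "pallas" for Group A and
    "baseline" for Groups B/C/D, matching the directory names in this plan.
    """
    opt_dir = Path(sweep_root) / f"{prob_id}_{prob_type_suffix}"
    return (opt_dir / marker).is_dir()


# Demonstrate against a throwaway directory tree instead of real outputs.
with tempfile.TemporaryDirectory() as root:
    done = Path(root) / "5p_Flex_Attention_baseline" / "candidates-iter-5"
    done.mkdir(parents=True)
    print(is_completed(root, "5p_Flex_Attention", "baseline"))  # True
    print(is_completed(root, "1p_Flash_Attention", "pallas"))   # False
```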

### Execution Order

1. **Group A** (8 runs) — improve existing Pallas kernels
2. **Group B** (9 runs) — translate priority kernels
3. **Group D** (8 runs) — end-to-end translation comparison
4. **Group C** (33 runs) — translate fused operators
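
Since the runner loops over `(prob_type, prob_id)` pairs, the order above can be sketched as a generator. The `PROB_TYPE` and `GROUP_ORDER` tables mirror this plan; the `ordered_runs` function is hypothetical, and only a few `prob_id` values are shown per group to keep the example short.

```python
# prob_type per group, from the Benchmark Inventory table.
PROB_TYPE = {
    "A": "jaxbench-pallas",
    "B": "jaxbench-baseline",
    "C": "jaxbench-baseline",
    "D": "jaxbench-baseline",
}
GROUP_ORDER = ["A", "B", "D", "C"]  # execution order from this section


def ordered_runs(prob_ids_by_group):
    """Yield (group, prob_type, prob_id) in the planned execution order."""
    for group in GROUP_ORDER:
        for prob_id in prob_ids_by_group.get(group, []):
            yield group, PROB_TYPE[group], prob_id


# Abbreviated example: Group D reuses Group A's prob_ids with a different prob_type.
sample = {
    "A": ["1p_Flash_Attention", "8p_GEMM"],
    "B": ["5p_Flex_Attention"],
    "C": ["18k_Conv2D_ReLU_BiasAdd"],
    "D": ["1p_Flash_Attention", "8p_GEMM"],
}
for run in ordered_runs(sample):
    print(run)
```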

### Time Estimate

With up to 18 candidates/iter (1 plan × 6 code × 3 beam) and reduced eval trials (20 trials), the estimates below assume roughly 0.5 hrs per iteration:

| Group | Runs | Phases | Iters/run | Total iters | Est. time |
|-------|-----:|--------|----------:|------------:|----------:|
| A | 8 | opt only | 8 | 64 | ~32 hrs |
| B | 9 | tr + opt | 4 + 4 | 72 | ~36 hrs |
| D | 8 | tr + opt | 4 + 4 | 64 | ~32 hrs |
| C | 33 | tr + opt | 4 + 4 | 264 | ~132 hrs |
| **Total** | **58** | | | **464** | **~232 hrs (~10 days)** |

*Early stopping may reduce actual time — benchmarks that converge or translate quickly will finish in fewer iterations.*
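
The table's arithmetic can be reproduced with a few lines of Python. The sketch assumes a flat rate of 0.5 hours per iteration, which is what the table's own numbers imply (for example, 32 hrs for Group A's 64 iterations); the dictionary and variable names are illustrative.

```python
# (runs, iterations per run) for each group, taken from the table above.
GROUPS = {"A": (8, 8), "B": (9, 8), "D": (8, 8), "C": (33, 8)}
HOURS_PER_ITER = 0.5  # assumption: rate implied by the table's estimates

total_iters = 0
for group, (runs, iters_per_run) in GROUPS.items():
    iters = runs * iters_per_run
    total_iters += iters
    print(f"Group {group}: {iters} iters, ~{iters * HOURS_PER_ITER:.0f} hrs")

total_hours = total_iters * HOURS_PER_ITER
print(f"Total: {total_iters} iters, ~{total_hours:.0f} hrs (~{total_hours / 24:.1f} days)")
```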

### Success Metrics

| Metric | Definition |
|--------|------------|
| **Speedup** | `baseline_ms / autocomp_ms` for each benchmark |
| **Translation success rate** | Fraction of Groups B/C/D runs producing a correct Pallas kernel |
| **Geomean speedup** | Geometric mean of per-benchmark speedups, computed separately for each group |
| **Head-to-head (Group A vs D)** | Per-benchmark comparison of starting from Pallas vs. from JAX |
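
As a sketch of how the speedup and geomean metrics could be computed, the snippet below applies the definitions to the Group B latencies reported earlier in this plan. The function names are illustrative and not part of Autocomp.

```python
import math

# (baseline_ms, post_optimize_ms) from the Group B results table.
GROUP_B = {
    "5p_Flex_Attention": (36.365, 9.572),
    "9p_SwiGLU_MLP": (16.800, 16.136),
    "10p_Sparse_MoE": (36.549, 33.838),
    "12p_RMSNorm": (1.206, 0.859),
    "13p_Cross_Entropy": (14.372, 11.051),
    "14p_Ragged_Dot": (1.361, 1.463),
    "15p_RetNet_Retention": (12.910, 1.950),
    "16p_Mamba2_SSD": (29.041, 7.265),
    "17p_Triangle_Multiplication": (14.489, 7.683),
}


def speedup(baseline_ms, autocomp_ms):
    """Speedup > 1.0 means the Autocomp kernel is faster than the baseline."""
    return baseline_ms / autocomp_ms


def geomean(values):
    """Geometric mean, computed as the exponential of the mean of logs."""
    values = list(values)
    return math.exp(sum(math.log(v) for v in values) / len(values))


speedups = {pid: speedup(*ms) for pid, ms in GROUP_B.items()}
for pid, s in speedups.items():
    print(f"{pid}: {s:.2f}x")
print(f"Group B geomean: {geomean(speedups.values()):.2f}x")
```

The geometric mean is the usual choice for averaging ratios, since a single large speedup such as RetNet's 6.62x would dominate an arithmetic mean.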

---

## Next Steps

1. [x] Create `run_batch.py` with resume support
2. [x] Verify TPU VM is accessible and JAX 0.9.2 is installed
3. [ ] Run Group A
4. [x] Run Group B
5. [ ] Run Group D
6. [ ] Run Group C
7. [ ] Collect results and compute metrics
