#!/usr/bin/env bash
# Run the 5-kernel subset with Gemini 3.1 Flash Lite.
# Purpose: validate Flash Lite on the subset before committing to the full 50.
#
# Sample budget per method per benchmark = 144 (matches Pro baselines in the paper):
#   - Best-of-N:  --n 144                            -> 144 samples
#   - Iterative:  --num_chains 18 --turns 8          -> 144 samples
#   - Autocomp:   beam 3 * (3 plan * 2 code) = 18 samples per iter;
#                 18 per iter * 8 iters                -> 144 samples
#     (TRANSLATE_ITERATIONS=4 + OPT_ITERATIONS=4 = 8 iters; see run_batch.py)
#
# Output dirs are isolated from the Pro results so both can coexist:
#   output/baselines-flashlite/iterative/{prob_id}
#   output/baselines-flashlite/best_of_n/{prob_id}
#   output/jaxbench-sweep-flashlite/{prob_id}_baseline[_translate]
# Strict mode: stop on the first failing command, on any reference to an
# unset variable, and when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# Work from the autocomp checkout: run_batch.py and the output/ directories
# used later in this script are given as relative paths.
cd /path/to/autocomp

# Exported so that every python process launched by this script inherits it.
# NOTE(review): presumably switches on JAXBench profiling -- confirm against
# whatever reads AUTOCOMP_JAXBENCH_PROFILE.
AUTOCOMP_JAXBENCH_PROFILE=1
export AUTOCOMP_JAXBENCH_PROFILE

# Run-wide constants.
#   MODEL        - model name passed to the baselines via --model and to
#                  run_batch.py as "gcp::$MODEL".
#   OUTBASE      - root for baseline outputs; each baseline writes to
#                  $OUTBASE/<method>/<prob_id>.
#   AUTOCOMP_OUT - value handed to run_batch.py via --output-base.
# Marked readonly so an accidental reassignment later in the script fails
# loudly instead of silently sending a phase to the wrong model or directory.
MODEL="gemini-3.1-flash-lite-preview"
OUTBASE="output/baselines-flashlite"
AUTOCOMP_OUT="output/jaxbench-sweep-flashlite"
readonly MODEL OUTBASE AUTOCOMP_OUT

# Problem ids of the kernel subset. The same list drives all three phases
# (iterative, Autocomp, best-of-N), so it is frozen after definition to keep
# the phases from ever running on different kernel sets.
KERNELS=(
    "12p_RMSNorm"
    "5p_Flex_Attention"
    "15p_RetNet_Retention"
    "16p_Mamba2_SSD"
    "1p_Flash_Attention"
)
readonly KERNELS

# Start-of-run banner: names the subset, the model in use, and the per-method
# sample budget.
cat <<EOF
============================================
5-kernel subset with Gemini 3.1 Flash Lite
Model: $MODEL
Budget: 144 samples/method/benchmark
============================================
EOF

# --- Phase 1: iterative baseline ---
# Runs before the other phases because it is the closest head-to-head with
# Autocomp at the same 144-sample budget: it shows whether Flash Lite can
# iterate on errors at all before the more expensive Autocomp runs start.
# One invocation per kernel; each writes to its own directory under
# $OUTBASE/iterative.
for kernel in "${KERNELS[@]}"; do
    printf '\n>>> Iterative: %s  (18 chains x 8 turns = 144 samples)\n' "$kernel"
    iterative_args=(
        --prob_id "$kernel"
        --prob_type jaxbench-baseline
        --num_chains 18
        --turns 8
        --model "$MODEL"
        --output_dir "$OUTBASE/iterative/$kernel"
    )
    python -m autocomp.baselines.iterative "${iterative_args[@]}"
done

# --- Phase 2: Autocomp ---
# A single run_batch.py invocation receives every kernel at once (each
# array element becomes its own argument after --probs).
printf '\n>>> Autocomp: all 5 kernels\n'
autocomp_args=(
    --probs "${KERNELS[@]}"
    --models "gcp::$MODEL"
    --output-base "$AUTOCOMP_OUT"
)
python run_batch.py "${autocomp_args[@]}"

# --- Phase 3: best-of-N ---
# Cheapest floor, so it runs last. One invocation per kernel, 144 samples
# each, written to a per-kernel directory under $OUTBASE/best_of_n.
for kernel in "${KERNELS[@]}"; do
    printf '\n>>> Best-of-N: %s  (n=144)\n' "$kernel"
    best_of_n_args=(
        --prob_id "$kernel"
        --prob_type jaxbench-baseline
        --n 144
        --model "$MODEL"
        --output_dir "$OUTBASE/best_of_n/$kernel"
    )
    python -m autocomp.baselines.best_of_n "${best_of_n_args[@]}"
done

# Completion banner; under errexit it is reached only if every phase above
# exited successfully.
printf '%s\n' \
    "" \
    "============================================" \
    "DONE: 5-kernel subset with Gemini 3.1 Flash Lite" \
    "============================================"
