#!/usr/bin/env bash
# Run the 5-kernel subset with Gemini 3 Flash
# (Iterative, Iterative+context, Autocomp, and Best-of-N).
# Run sequentially — one at a time on the TPU.
#
# Sample budget per method per benchmark = 144 (matches Pro baselines in the paper):
#   - Iterative:          --num_chains 18 --turns 8          -> 144 samples
#   - Iterative+context:  --num_chains 18 --turns 8          -> 144 samples
#   - Autocomp:           beam 3 * (3 plan * 2 code) * 8 iter = 144 samples
#   - Best-of-N:          --n 144                            -> 144 samples
#
# Output dirs are isolated from the Pro and Flash Lite results:
#   output/baselines-flash/iterative/{prob_id}
#   output/baselines-flash/iterative_context/{prob_id}
#   output/baselines-flash/best_of_n/{prob_id}
#   output/jaxbench-sweep-flash/{prob_id}_baseline[_translate]
# Strict mode: stop on the first failing command, on any unset variable, and
# on a failure in any stage of a pipeline.
set -o errexit -o nounset -o pipefail

# Every relative path used below (run_batch.py, output/...) resolves against
# this directory.
cd /path/to/autocomp

# Exported so the Python processes launched below inherit it.
# NOTE(review): presumably switches on profiling in the JAXBench harness —
# confirm against the consumer of this variable.
AUTOCOMP_JAXBENCH_PROFILE=1
export AUTOCOMP_JAXBENCH_PROFILE

# Model identifier: passed as --model to the baseline modules and as
# "gcp::<model>" to run_batch.py.
MODEL='gemini-3-flash-preview'

# Output roots (relative to the working directory):
#   OUTBASE       - parent of the per-method baseline result directories
#   AUTOCOMP_OUT  - handed to run_batch.py as --output-base
OUTBASE='output/baselines-flash'
AUTOCOMP_OUT='output/jaxbench-sweep-flash'

# Problem ids of the kernel subset. Each per-kernel loop below walks this
# array in order, and the whole array is passed to run_batch.py via --probs.
declare -a KERNELS=(
  '12p_RMSNorm'
  '5p_Flex_Attention'
  '15p_RetNet_Retention'
  '16p_Mamba2_SSD'
  '1p_Flash_Attention'
)

# Start-of-run banner: names the model in use and the per-method sample budget.
printf '%s\n' \
    "============================================" \
    "5-kernel subset with Gemini 3 Flash" \
    "Model: ${MODEL}" \
    "Budget: 144 samples/method/benchmark" \
    "============================================"

# --- Iterative baseline ---
# One invocation per kernel, in KERNELS order. 18 chains x 8 turns gives the
# 144-sample budget. Results go to $OUTBASE/iterative/<prob_id>.
for kernel in "${KERNELS[@]}"; do
    printf '\n>>> Iterative: %s  (18 chains x 8 turns = 144 samples)\n' "$kernel"

    # Build the command as an array so every argument stays a single word.
    iterative_cmd=(
        python -m autocomp.baselines.iterative
        --prob_id "$kernel"
        --prob_type jaxbench-baseline
        --num_chains 18
        --turns 8
        --model "$MODEL"
        --output_dir "${OUTBASE}/iterative/${kernel}"
    )
    "${iterative_cmd[@]}"
done

# --- Iterative+context baseline ---
# Same algorithm and budget as the Iterative baseline (18 chains x 8 turns),
# but with Autocomp's full agent profile (architecture + ISA + code examples
# + rules) prepended. This isolates the effect of context from the effect of
# beam search. Results go to $OUTBASE/iterative_context/<prob_id>.
for kernel in "${KERNELS[@]}"; do
    printf '\n>>> Iterative+context: %s  (18 chains x 8 turns = 144 samples)\n' "$kernel"

    # Differs from the plain Iterative command only in --context/--agent_dir
    # and the output subdirectory.
    iterative_ctx_cmd=(
        python -m autocomp.baselines.iterative
        --prob_id "$kernel"
        --prob_type jaxbench-baseline
        --num_chains 18
        --turns 8
        --model "$MODEL"
        --context full
        --agent_dir built:tpu-v6e
        --output_dir "${OUTBASE}/iterative_context/${kernel}"
    )
    "${iterative_ctx_cmd[@]}"
done

# --- Autocomp ---
# A single run_batch.py invocation covers every kernel in KERNELS; the model
# is passed with a "gcp::" prefix.
# NOTE(review): no beam/plan/code/iteration flags are passed here, so the
# 144-sample budget presumably comes from run_batch.py defaults — confirm.
printf '\n%s\n' ">>> Autocomp: all 5 kernels"

autocomp_cmd=(
    python run_batch.py
    --probs "${KERNELS[@]}"
    --models "gcp::${MODEL}"
    --output-base "$AUTOCOMP_OUT"
)
"${autocomp_cmd[@]}"

# --- Best-of-N baseline ---
# One invocation per kernel, in KERNELS order, each drawing 144 independent
# samples. Results go to $OUTBASE/best_of_n/<prob_id>.
for kernel in "${KERNELS[@]}"; do
    printf '\n>>> Best-of-N: %s  (n=144)\n' "$kernel"

    # Build the command as an array so every argument stays a single word.
    best_of_n_cmd=(
        python -m autocomp.baselines.best_of_n
        --prob_id "$kernel"
        --prob_type jaxbench-baseline
        --n 144
        --model "$MODEL"
        --output_dir "${OUTBASE}/best_of_n/${kernel}"
    )
    "${best_of_n_cmd[@]}"
done

# Completion banner. Under errexit this is reached only if every command
# above exited successfully.
printf '%s\n' \
    "" \
    "============================================" \
    "DONE: 5-kernel subset with Gemini 3 Flash" \
    "============================================"
