#!/usr/bin/env bash
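# Run the remaining 45 kernels of JAXBench's 50-kernel full suite with Gemini 3 Flash
# for the three non-Autocomp baselines (Best-of-N, iterative, iterative+context).
# The 5 kernels that already have results under output/baselines-flash/* are left
# out of KERNELS below. Any kernel whose output directory already contains
# summary.json is also skipped, so the script can be re-run to resume a partial sweep.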
#
# Budget per method per benchmark = 144 samples, matching both the 5-kernel subset
# and the Autocomp sweep already in output/jaxbench-sweep-flash:
#   - Iterative / iterative+context: --num_chains 18 --turns 8
#   - Best-of-N:                     --n 144
#
# Output dirs:
#   output/baselines-flash/iterative/{prob_id}
#   output/baselines-flash/iterative_context/{prob_id}
#   output/baselines-flash/best_of_n/{prob_id}
#
# Dry-run: BASELINES_DRYRUN=1 ./run_50kernel_flash_baselines.sh
# Strict mode: abort on any unhandled error, on use of an unset variable, and
# when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail
# Relative paths used below (output/...) and `python -m autocomp...` module
# lookup are resolved from this working directory.
cd /path/to/autocomp
# Exported so every child python process inherits it.
# NOTE(review): presumably turns on JAXBench profiling inside autocomp — confirm.
AUTOCOMP_JAXBENCH_PROFILE=1
export AUTOCOMP_JAXBENCH_PROFILE

# Model and output root shared by all three phases.
readonly MODEL="gemini-3-flash-preview"
readonly OUTBASE="output/baselines-flash"

# All 50 JAXBench kernels minus the 5 already-completed Flash subset
# (1p, 5p, 12p, 15p and 16p are intentionally absent). Kernels run in the
# order listed here.
KERNELS=(
    2p_GQA_Attention
    3p_MLA_Attention
    4p_Sparse_Attention
    6p_Paged_Attention
    7p_Ragged_Paged_Attention
    8p_GEMM
    11p_Megablox_GMM
    9p_SwiGLU_MLP
    10p_Sparse_MoE
    13p_Cross_Entropy
    14p_Ragged_Dot
    17p_Triangle_Multiplication
    18k_Conv2D_ReLU_BiasAdd
    19k_Matmul_Subtract_Multiply_ReLU
    20k_Gemm_Multiply_LeakyReLU
    21k_Gemm_Divide_Sum_Scaling
    22k_Conv2d_InstanceNorm_Divide
    23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp
    24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish
    25k_Conv3d_GroupNorm_Mean
    26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply
    27k_Matmul_Mish_Mish
    28k_ConvTranspose3d_LayerNorm_GELU_Scaling
    29k_Matmul_Swish_Sum_GroupNorm
    30k_Matmul_Scaling_ResidualAdd
    31k_Gemm_BatchNorm_GELU_ReLU
    32k_Gemm_Sigmoid_LogSumExp
    33k_Conv3d_Mish_Tanh
    34k_Conv2d_Activation_BatchNorm
    35k_Gemm_Scaling_Hardtanh_GELU
    36k_Matmul_Sigmoid_Sum
    37k_Matmul_Swish_Scaling
    38k_Matmul_Dropout_Softmax
    39k_Conv2d_GELU_GlobalAvgPool
    40k_Gemm_GroupNorm_Min_BiasAdd
    41k_Gemm_Add_ReLU
    42k_Gemm_Max_Subtract_GELU
    43k_Gemm_BatchNorm_Scaling_Softmax
    44k_Matmul_Divide_GELU
    45k_Gemm_GroupNorm_Swish_Multiply_Swish
    46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp
    47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh
    48k_Matmul_BatchNorm_BiasAdd_Divide_Swish
    49k_Matmul_AvgPool_GELU_Scale_Max
    50k_Matmul_GELU_Softmax
)
readonly KERNELS

# BASELINES_DRYRUN=1 makes `run` print each command instead of executing it;
# unset, empty, or 0 runs for real. Any other value is rejected, so a typo
# such as BASELINES_DRYRUN=true cannot silently launch the full (expensive)
# sweep when a dry run was intended.
DRYRUN="${BASELINES_DRYRUN:-0}"
case "$DRYRUN" in
    0 | 1) ;;
    *)
        printf 'error: BASELINES_DRYRUN must be 0 or 1, got %q\n' "$DRYRUN" >&2
        exit 2
        ;;
esac
readonly DRYRUN

#######################################
# Execute a command, or only log it when dry-run mode is on.
# Globals:   DRYRUN (read) - "1" enables dry-run logging
# Arguments: $@ - the command and its arguments
# Outputs:   in dry-run mode, "    DRYRUN: <cmd>" on stdout, with every
#            argument shell-quoted (%q) so argument boundaries are preserved
#            and the line can be pasted back into a shell
# Returns:   0 in dry-run mode; otherwise the command's exit status
#######################################
run() {
    if [[ "$DRYRUN" == "1" ]]; then
        local quoted
        printf -v quoted '%q ' "$@"
        # Strip the single trailing space left by the '%q ' format.
        printf '    DRYRUN: %s\n' "${quoted% }"
    else
        "$@"
    fi
}

# "<method>/<kernel>" entries whose run exited non-zero. Failures are recorded
# instead of aborting: with plain `set -e`, one failing kernel would skip every
# remaining kernel and every later phase of the sweep.
FAILED=()

#######################################
# Run one baseline method over every kernel in KERNELS. A kernel is skipped
# when its output directory already contains summary.json.
# Globals:   KERNELS, OUTBASE, MODEL (read); FAILED (appended on failure)
# Arguments: $1 - phase label for the banner, e.g. "1/3"
#            $2 - human-readable method title for the banner
#            $3 - method name; used in log lines and as the output
#                 subdirectory under $OUTBASE
#            $4 - Python module passed to `python -m`
#            $5.. - extra method-specific CLI flags
#######################################
run_phase() {
    local phase=$1 title=$2 method=$3 module=$4
    shift 4
    local k out rc

    echo ""
    echo "### Phase $phase: $title (${#KERNELS[@]} kernels) ###"
    for k in "${KERNELS[@]}"; do
        out="$OUTBASE/$method/$k"
        if [[ -f "$out/summary.json" ]]; then
            echo ">>> $method $k  SKIP (summary.json exists)"
            continue
        fi
        echo ">>> $method $k"
        rc=0
        run python -m "$module" \
            --prob_id "$k" \
            --prob_type jaxbench-baseline \
            "$@" \
            --model "$MODEL" \
            --output_dir "$out" || rc=$?
        if (( rc != 0 )); then
            echo "!!! $method $k FAILED (exit $rc); continuing" >&2
            FAILED+=("$method/$k")
        fi
    done
}

echo "============================================"
echo "Remaining ${#KERNELS[@]}-kernel Flash baseline sweep"
echo "Model: $MODEL"
echo "Budget: 144 samples/method/benchmark"
echo "Methods: iterative, iterative+context, best_of_n"
echo "Dry-run: $DRYRUN"
echo "============================================"

# --- Iterative: 18 chains x 8 turns = 144 samples ---
run_phase 1/3 iterative iterative autocomp.baselines.iterative \
    --num_chains 18 \
    --turns 8

# --- Iterative+context: same 18 x 8 = 144 budget, plus
#     --context full --agent_dir built:tpu-v6e ---
run_phase 2/3 iterative+context iterative_context autocomp.baselines.iterative \
    --num_chains 18 \
    --turns 8 \
    --context full \
    --agent_dir built:tpu-v6e

# --- Best-of-N: 144 independent samples ---
run_phase 3/3 best_of_n best_of_n autocomp.baselines.best_of_n \
    --n 144

echo ""
echo "============================================"
if (( ${#FAILED[@]} > 0 )); then
    echo "DONE with ${#FAILED[@]} failed run(s):"
    printf '    %s\n' "${FAILED[@]}"
    echo "============================================"
    exit 1
fi
echo "DONE"
echo "============================================"
