#!/usr/bin/env bash
# Run best_of_n on the remaining 45 kernels of the JAXBench Flash suite using
# tpu-node-2, in parallel with the iterative/iterative_context sweep running on
# tpu-node.
#
# Skips any kernel that already has output/baselines-flash/best_of_n/{k}/summary.json.
# Strict mode: stop on the first failing command, on any reference to an unset
# variable, and when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# Relative paths used by this script (e.g. OUTBASE) resolve against this
# checkout.
cd /path/to/autocomp

# Exported so that child processes see them. Both are assigned unconditionally,
# so a value inherited from the caller's environment is overridden.
AUTOCOMP_TPU_NAME=tpu-node-2
# NOTE(review): presumably switches on profiling in the benchmark — confirm.
AUTOCOMP_JAXBENCH_PROFILE=1
export AUTOCOMP_TPU_NAME AUTOCOMP_JAXBENCH_PROFILE

# Model identifier and the root directory for per-kernel outputs.
MODEL='gemini-3-flash-preview'
OUTBASE='output/baselines-flash'

# Kernel ids to sweep. Array order is the order in which they are processed,
# so entries are appended in exactly the sequence they should run.
KERNELS=()

# Ids whose numeric prefix is followed by "p".
KERNELS+=(
    2p_GQA_Attention
    3p_MLA_Attention
    4p_Sparse_Attention
    6p_Paged_Attention
    7p_Ragged_Paged_Attention
    8p_GEMM
    11p_Megablox_GMM
    9p_SwiGLU_MLP
    10p_Sparse_MoE
    13p_Cross_Entropy
    14p_Ragged_Dot
    17p_Triangle_Multiplication
)

# Ids whose numeric prefix is followed by "k".
KERNELS+=(
    18k_Conv2D_ReLU_BiasAdd
    19k_Matmul_Subtract_Multiply_ReLU
    20k_Gemm_Multiply_LeakyReLU
    21k_Gemm_Divide_Sum_Scaling
    22k_Conv2d_InstanceNorm_Divide
    23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp
    24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish
    25k_Conv3d_GroupNorm_Mean
    26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply
    27k_Matmul_Mish_Mish
    28k_ConvTranspose3d_LayerNorm_GELU_Scaling
    29k_Matmul_Swish_Sum_GroupNorm
    30k_Matmul_Scaling_ResidualAdd
    31k_Gemm_BatchNorm_GELU_ReLU
    32k_Gemm_Sigmoid_LogSumExp
    33k_Conv3d_Mish_Tanh
    34k_Conv2d_Activation_BatchNorm
    35k_Gemm_Scaling_Hardtanh_GELU
    36k_Matmul_Sigmoid_Sum
    37k_Matmul_Swish_Scaling
    38k_Matmul_Dropout_Softmax
    39k_Conv2d_GELU_GlobalAvgPool
    40k_Gemm_GroupNorm_Min_BiasAdd
    41k_Gemm_Add_ReLU
    42k_Gemm_Max_Subtract_GELU
    43k_Gemm_BatchNorm_Scaling_Softmax
    44k_Matmul_Divide_GELU
    45k_Gemm_GroupNorm_Swish_Multiply_Swish
    46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp
    47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh
    48k_Matmul_BatchNorm_BiasAdd_Divide_Swish
    49k_Matmul_AvgPool_GELU_Scale_Max
    50k_Matmul_GELU_Softmax
)

# Samples requested per kernel. Defined once so the banner and the --n flag
# passed to best_of_n cannot drift apart.
N_SAMPLES=144

#######################################
# Run best_of_n once for every kernel in KERNELS that is not finished yet.
# A kernel is treated as finished when
# "$OUTBASE/best_of_n/<kernel>/summary.json" exists; such kernels are skipped,
# so re-running the script resumes where a previous run stopped.
# Globals:   KERNELS, OUTBASE, MODEL, AUTOCOMP_TPU_NAME, N_SAMPLES (all read)
# Arguments: none
# Outputs:   progress lines on stdout; a failure notice on stderr
# Returns:   0 when every kernel was skipped or ran successfully; otherwise
#            the exit status of the first failing run (later kernels are not
#            attempted and the DONE banner is not printed)
#######################################
run_best_of_n_sweep() {
    local rule="============================================"
    local kernel out_dir status

    printf '%s\n' \
        "$rule" \
        "best_of_n sweep on $AUTOCOMP_TPU_NAME" \
        "Model: $MODEL | Budget: n=$N_SAMPLES" \
        "$rule"

    for kernel in "${KERNELS[@]}"; do
        out_dir="$OUTBASE/best_of_n/$kernel"
        if [[ -f "$out_dir/summary.json" ]]; then
            echo ">>> best_of_n $kernel  SKIP (summary.json exists)"
            continue
        fi

        echo ">>> best_of_n $kernel"
        # Capture the status explicitly rather than leaning on errexit: the
        # sweep still stops at the first failure, but it now says which kernel
        # failed and behaves the same even where `set -e` is not in effect.
        status=0
        python -m autocomp.baselines.best_of_n \
            --prob_id "$kernel" \
            --prob_type jaxbench-baseline \
            --n "$N_SAMPLES" \
            --model "$MODEL" \
            --output_dir "$out_dir" || status=$?
        if (( status != 0 )); then
            echo ">>> best_of_n $kernel  FAILED (exit $status); aborting sweep" >&2
            return "$status"
        fi
    done

    printf '%s\n' "$rule" "DONE" "$rule"
}

# Plain call (not `... || exit`) so errexit, when enabled, stays active inside
# the function and a non-zero return becomes the script's exit status.
run_best_of_n_sweep
