#!/bin/bash

# ---------------- Base Params ----------------
# Baseline hyperparameters. Each sweep below varies exactly one of these while
# holding the rest at the values given here.
readonly BASE_EPOCHS=30
readonly BASE_BS=64
readonly BASE_LR=1e-4
# NOTE(review): BASE_MIN_LR is not referenced anywhere in this script; the
# launched runs derive min_lr as lr * 0.1 instead. Kept for reference.
readonly BASE_MIN_LR=1e-5
readonly BASE_WD=1e-5
readonly BASE_HIDDEN=256
readonly BASE_LAYERS=8
readonly BASE_PROJ=128
readonly BASE_CHIRAL="Kernel"
# Passed to the trainer as the flag --use_<BASE_REG>.
readonly BASE_REG="qr"

# Per-run logs are written here as <prefix>.log. Fail fast if the directory
# cannot be created: otherwise every job's log redirection would fail and the
# runs would never start, with no error at the top level.
readonly OUTDIR="/home/data/HCT/res/moleculenet/logs"
if ! mkdir -p -- "$OUTDIR"; then
    echo "Error: cannot create log directory $OUTDIR" >&2
    exit 1
fi

# GPUs are assigned to jobs round-robin; DEVICE_IDX is the next slot to use
# (mutable: advanced after every launch).
readonly -a CUDA_DEVICES=(0 1 2 3 4 6)
DEVICE_IDX=0

# ---------------- Check argument ----------------
# Datasets that have a search space in the SEARCH_* tables. An unknown name
# would otherwise launch only the baseline and silently sweep nothing.
KNOWN_DATASETS=(bbbp sider clintox freesolv bace)

if (( $# < 1 )); then
    echo "Usage: $0 <dataset>" >&2
    echo "Available datasets: ${KNOWN_DATASETS[*]}" >&2
    exit 1
fi

DATASET=$1

is_known=0
for known in "${KNOWN_DATASETS[@]}"; do
    if [[ "$DATASET" == "$known" ]]; then
        is_known=1
        break
    fi
done
if (( ! is_known )); then
    echo "Unknown dataset: $DATASET" >&2
    echo "Available datasets: ${KNOWN_DATASETS[*]}" >&2
    exit 1
fi
unset is_known known

echo "=== Running parameter search for dataset: $DATASET ==="

# ---------------- Parameter Search Space (per dataset) ----------------
# Each SEARCH_* table maps a dataset name to a space-separated list of values
# to try for that one hyperparameter.
declare -A SEARCH_BS SEARCH_LR SEARCH_WD SEARCH_HIDDEN SEARCH_LAYERS SEARCH_PROJ

# Register the value lists for one dataset.
# Usage: define_space <dataset> <bs> <lr> <wd> <hidden> <layers> <proj>
# Every argument after <dataset> is a space-separated list; quote multi-value
# lists so each one arrives as a single argument.
define_space() {
    local name=$1
    SEARCH_BS[$name]=$2
    SEARCH_LR[$name]=$3
    SEARCH_WD[$name]=$4
    SEARCH_HIDDEN[$name]=$5
    SEARCH_LAYERS[$name]=$6
    SEARCH_PROJ[$name]=$7
}

#            dataset   bs    lr            wd        hidden  layers  proj
define_space bbbp      "32"  "1e-3 3e-4"   "1e-4"    "128"   "4"     "256"   # 2039 molecules
define_space sider     "32"  "3e-4"        "1e-4"    "128"   "4"     "256"   # 1427 molecules
define_space clintox   "32"  "1e-3 3e-4"   "1e-4"    "128"   "4"     "256"   # 1478 molecules
define_space freesolv  "16"  "1e-3 3e-4"   "0 1e-4"  "128"   "4 12"  "256"   # 642 molecules
define_space bace      "32"  "1e-3 3e-4"   "0 1e-4"  "128"   "4"     "256"   # 1513 molecules


# ---------------- Launch Experiments ----------------
#######################################
# Launch one training run in the background on the next GPU (round-robin).
# Skips the run when every hyperparameter equals its BASE_* value, because
# that configuration is launched separately as the baseline.
# Globals:
#   BASE_* (read), OUTDIR (read), CUDA_DEVICES (read), DEVICE_IDX (read/written)
# Arguments:
#   $1 dataset, $2 batch size, $3 learning rate, $4 weight decay,
#   $5 hidden dim, $6 number of layers, $7 projection dim
# Outputs:
#   Progress messages on stdout; the run's stdout/stderr go to
#   $OUTDIR/<prefix>.log.
# Returns:
#   0 when launched or skipped; 1 if min_lr could not be computed.
#######################################
run_exp() {
    local dataset=$1
    local bs=$2
    local lr=$3
    local wd=$4
    local hidden=$5
    local layers=$6
    local proj=$7
    local min_lr prefix device

    # Skip: avoid running the all-base case. Quoted right-hand sides make these
    # literal string comparisons rather than glob-pattern matches.
    if [[ "$bs" == "$BASE_BS" &&
          "$lr" == "$BASE_LR" &&
          "$wd" == "$BASE_WD" &&
          "$hidden" == "$BASE_HIDDEN" &&
          "$layers" == "$BASE_LAYERS" &&
          "$proj" == "$BASE_PROJ" ]]; then
        echo "[Skip] $dataset uses all base parameters. (Avoid duplicate)"
        return 0
    fi

    # min_lr = lr * 0.1. lr is passed as an argument instead of being spliced
    # into Python source, so a malformed value fails cleanly. On failure, do not
    # launch: an empty min_lr would drop out of the command line.
    if ! min_lr=$(python3 -c 'import sys; print(float(sys.argv[1]) * 0.1)' "$lr"); then
        echo "[Error] could not compute min_lr from lr='$lr'; not launching." >&2
        return 1
    fi

    prefix="${dataset}_ep${BASE_EPOCHS}_bs${bs}_lr${lr}_wd${wd}_h${hidden}_L${layers}_p${proj}"
    device="cuda:${CUDA_DEVICES[$DEVICE_IDX]}"

    echo "Launching: $prefix on $device"

    python3 main_moleculenet.py \
        --epochs "$BASE_EPOCHS" --bs "$bs" --lr "$lr" --min_lr "$min_lr" \
        --weight_decay "$wd" --hidden_dim "$hidden" --num_layers "$layers" \
        --proj_dim "$proj" --chiral_encoder "$BASE_CHIRAL" "--use_${BASE_REG}" \
        --dataset "$dataset" \
        --device "$device" \
        > "$OUTDIR/${prefix}.log" 2>&1 &

    # Advance to the next GPU slot, wrapping around the device list.
    DEVICE_IDX=$(( (DEVICE_IDX + 1) % ${#CUDA_DEVICES[@]} ))
}


# ---------------- Launch experiments for the specified dataset ----------------
# One-at-a-time sweep: launch the baseline, then for each hyperparameter launch
# one run per value in its SEARCH_* list with every other hyperparameter held
# at its BASE_* value. All runs are started in the background concurrently;
# the final `wait` blocks until every one of them has exited.
dataset=$DATASET

# for dataset in bbbp sider clintox freesolv bace; do
echo "==== Searching params for $dataset ===="

# ====== 0. Run BASELINE once ======
# run_exp skips the all-base configuration, so the baseline is launched here
# explicitly; its log prefix is tagged with _BASE_.
echo "[Baseline] Launching base config for $dataset"

# min_lr = BASE_LR * 0.1; the value is passed as an argument rather than
# spliced into Python source. If it cannot be computed, skip the baseline
# instead of launching with an empty --min_lr.
if min_lr_base=$(python3 -c 'import sys; print(float(sys.argv[1]) * 0.1)' "$BASE_LR"); then
    prefix_base="${dataset}_BASE_ep${BASE_EPOCHS}_bs${BASE_BS}_lr${BASE_LR}_wd${BASE_WD}_h${BASE_HIDDEN}_L${BASE_LAYERS}_p${BASE_PROJ}"
    device_base="cuda:${CUDA_DEVICES[$DEVICE_IDX]}"
    echo "Launching: $prefix_base on $device_base"

    python3 main_moleculenet.py \
        --epochs "$BASE_EPOCHS" --bs "$BASE_BS" --lr "$BASE_LR" --min_lr "$min_lr_base" \
        --weight_decay "$BASE_WD" --hidden_dim "$BASE_HIDDEN" --num_layers "$BASE_LAYERS" \
        --proj_dim "$BASE_PROJ" --chiral_encoder "$BASE_CHIRAL" "--use_${BASE_REG}" \
        --dataset "$dataset" \
        --device "$device_base" \
        > "$OUTDIR/${prefix_base}.log" 2>&1 &

    DEVICE_IDX=$(( (DEVICE_IDX + 1) % ${#CUDA_DEVICES[@]} ))
else
    echo "[Error] could not compute min_lr from BASE_LR='$BASE_LR'; baseline not launched." >&2
fi

# ====== 1–6. Parameter search: BS, LR, WD, HIDDEN, LAYERS, PROJ ======
# Each SEARCH_* entry is a space-separated list, so the loop expansions below
# are deliberately left unquoted to split it into individual values.

# 1. Batch Size
for bs in ${SEARCH_BS[$dataset]}; do
    run_exp "$dataset" "$bs" "$BASE_LR" "$BASE_WD" "$BASE_HIDDEN" "$BASE_LAYERS" "$BASE_PROJ"
done

# 2. LR
for lr in ${SEARCH_LR[$dataset]}; do
    run_exp "$dataset" "$BASE_BS" "$lr" "$BASE_WD" "$BASE_HIDDEN" "$BASE_LAYERS" "$BASE_PROJ"
done

# 3. Weight Decay
for wd in ${SEARCH_WD[$dataset]}; do
    run_exp "$dataset" "$BASE_BS" "$BASE_LR" "$wd" "$BASE_HIDDEN" "$BASE_LAYERS" "$BASE_PROJ"
done

# 4. Hidden dim
for hidden in ${SEARCH_HIDDEN[$dataset]}; do
    run_exp "$dataset" "$BASE_BS" "$BASE_LR" "$BASE_WD" "$hidden" "$BASE_LAYERS" "$BASE_PROJ"
done

# 5. Layers
for layers in ${SEARCH_LAYERS[$dataset]}; do
    run_exp "$dataset" "$BASE_BS" "$BASE_LR" "$BASE_WD" "$BASE_HIDDEN" "$layers" "$BASE_PROJ"
done

# 6. Proj dim
for proj in ${SEARCH_PROJ[$dataset]}; do
    run_exp "$dataset" "$BASE_BS" "$BASE_LR" "$BASE_WD" "$BASE_HIDDEN" "$BASE_LAYERS" "$proj"
done
# done

# Block until every background run has exited; only then is everything done.
wait
echo "All jobs finished for $dataset!"