#!/bin/bash

# Model and optimizer used for every run in this sweep.
# MODEL choices: bert-base, bert-large, opt-1.3b
MODEL="bert-base"
OPTIMIZER="nanoadam"

# Task    | Train Size | Steps per Epoch (train split, BS=32)
# --------|------------|-------------------------------------
# CoLA    | 8,551      | 267
# SST-2   | 67,349     | 2,104
# MRPC    | 3,668      | 115
# STS-B   | 7,000      | 219
# QQP     | 364,292    | 11,401
# MNLI    | 392,702    | 12,263
# QNLI    | 104,743    | 3,274
# RTE     | 2,490      | 78
# WNLI    | 635        | 20
# GLUE training-set size per task, keyed by task name.
# The sweep loop iterates over the keys of this array, so only tasks with an
# active (uncommented) entry are run. Uncomment a line to enable that task.
declare -A datasets=()
datasets[cola]=8551
# datasets[sst2]=67349
# datasets[mrpc]=3668
# datasets[stsb]=7000
# datasets[qqp]=364292
# datasets[mnli]=392702
# datasets[qnli]=104743
# datasets[rte]=2490
# datasets[wnli]=635

# Hyperparameters
SEED=42
BatchSize=32
# Number of logging/eval points per epoch; run_training divides the
# per-epoch step count by this to obtain the logging step interval.
logging_per_epoch=8

# Learning rates to sweep (one run per value).
# Alternative values kept for reference:
#   (4e-5 5e-5 6e-5 7e-5 8e-5 9e-5 10e-5 11e-5 12e-5) 7e-6 5e-6 4e-6 3e-6 2e-6 1e-6
declare -a LRs=("2e-4")

# Density values to sweep (one run per value).
# Alternative values kept for reference: (0.01 0.03 0.05 0.07)
declare -a DENSITY_VALUES=("0.1")

# Weights & Biases settings, forwarded to glue.py as
# --wandb_project / --wandb_group.
WANDB_PROJECT="nanoadam"
WANDB_GROUP="huggingface"

#######################################
# Run one glue.py training job for a single (density, LR, task) combination.
#
# Globals (read only):
#   datasets, BatchSize, logging_per_epoch, MODEL, OPTIMIZER, SEED,
#   WANDB_PROJECT, WANDB_GROUP
# Arguments:
#   $1 - density value, forwarded to glue.py as --k
#   $2 - learning rate, forwarded to glue.py as --lr
#   $3 - GLUE task name; must be a key of the `datasets` array
# Outputs:
#   Two progress lines on stdout; diagnostics on stderr.
# Returns:
#   1 if the task has no usable size in `datasets` or the output directory
#   cannot be created; otherwise the exit status of the python process.
#######################################
run_training() {
    local DENSITY="$1"
    local LR="$2"
    local task_name="$3"

    # Look up the training-set size. An empty key is checked first because an
    # empty subscript on an associative array is itself an error in bash.
    local train_size=""
    if [[ -n "$task_name" ]]; then
        train_size="${datasets[$task_name]:-}"
    fi
    if [[ ! "$train_size" =~ ^[0-9]+$ ]]; then
        echo "run_training: no dataset size configured for task '${task_name}'" >&2
        return 1
    fi

    local WANDB_JOB_TYPE="glue_${task_name}"

    # Steps between logging/eval points: steps-per-epoch / logging_per_epoch.
    # Integer division truncates to 0 for very small datasets, so every
    # interval is clamped to at least 1 (a 0-step interval is presumably
    # rejected or meaningless in glue.py -- confirm).
    local logging_step=$(( train_size / (BatchSize * logging_per_epoch) ))
    if (( logging_step < 1 )); then
        logging_step=1
    fi
    local mask_interval=$(( logging_step / 5 ))
    if (( mask_interval < 1 )); then
        mask_interval=1
    fi
    local density_interval=$logging_step

    # Print the task and logging step
    echo "for $task_name, logging_step = $logging_step, mask_interval = $mask_interval"
    echo "Running with MODEL = ${MODEL}, OPTIMIZER = ${OPTIMIZER}, LR = ${LR}, DENSITY = ${DENSITY}, TASK = ${task_name}"
    # "maskintervel" (sic): spelling left unchanged so run names are not altered.
    local WANDB_NAME="${task_name}_${MODEL}_${OPTIMIZER}_${DENSITY}_LR${LR}_maskintervel${mask_interval}"

    # Create output directory
    local OUTPUT_DIR="./results/hf_glue_${task_name}_${MODEL}_${OPTIMIZER}_${DENSITY}"
    if ! mkdir -p "$OUTPUT_DIR"; then
        echo "run_training: cannot create output directory '${OUTPUT_DIR}'" >&2
        return 1
    fi

    # Arguments are collected in an array so every value is passed as exactly
    # one word, regardless of its content.
    # NOTE(review): --learning_rate is fixed at 0.0001 while the swept value
    # goes to --lr; presumably glue.py uses --lr for the custom optimizer --
    # confirm that the fixed --learning_rate is intentional.
    local -a train_args=(
        --num_train_epochs 5
        --optimizer_name "$OPTIMIZER"
        --logging_strategy steps
        --logging_steps "$logging_step"
        --eval_strategy steps
        --eval_steps "$logging_step"
        --task_name "$task_name"
        --do_train
        --do_eval
        --do_predict
        --max_seq_length 128
        --per_device_train_batch_size "$BatchSize"
        --learning_rate 0.0001
        --overwrite_output_dir
        --save_strategy no
        --save_total_limit 1
        --bf16
        --bf16_full_eval

        --model_name_or_path "$MODEL"
        --output_dir "$OUTPUT_DIR"
        --wandb_project "$WANDB_PROJECT"
        --wandb_group "$WANDB_GROUP"
        --wandb_job_type "$WANDB_JOB_TYPE"
        --wandb_name "$WANDB_NAME"

        --seed "$SEED"
        --lr "$LR"
        --k "$DENSITY"

        --weight_decay 0
        --beta1 0.9
        --beta2 0.999
        --eps 1e-8
        --dynamic_log_step "$logging_step"
        --mask_interval "$mask_interval"
        --density_interval "$density_interval"
        --dynamic_density
    )

    # Run the training script; its exit status becomes the function's status.
    python glue.py "${train_args[@]}"
}

# Loop
# Sweep every configured task x density x learning rate, one run at a time.
# run_training blocks until its python process exits, so runs execute strictly
# sequentially. A failed run does not stop the sweep; failures are collected,
# reported on stderr at the end, and make the script exit non-zero.
failed_runs=()
for task_name in "${!datasets[@]}"; do
    for DENSITY in "${DENSITY_VALUES[@]}"; do
        for LR in "${LRs[@]}"; do
            if ! run_training "$DENSITY" "$LR" "$task_name"; then
                failed_runs+=("task=${task_name} density=${DENSITY} lr=${LR}")
            fi
        done
    done
done

if (( ${#failed_runs[@]} > 0 )); then
    echo "${#failed_runs[@]} training run(s) failed:" >&2
    printf '  %s\n' "${failed_runs[@]}" >&2
    exit 1
fi

echo "All training runs completed."
