#!/bin/bash

# ---------------------------------------------------------------------------
# Model and optimizer
# ---------------------------------------------------------------------------
# MODEL accepts a model name (bert-base, bert-large, opt-1.3b) or a local
# checkpoint directory, e.g.
#   /p/home/jusers/zhou17/juwels/.cache/huggingface/hub/facebook/opt-1.3b
MODEL="bert-base"
OPTIMIZER="microadam"

# ---------------------------------------------------------------------------
# GLUE tasks
# ---------------------------------------------------------------------------
# Task    | Train Size | Steps per Epoch with BS=32(Train)
# --------|------------|-------------------------
# CoLA    | 8,551      | 267
# SST-2   | 67,349     | 2,104
# MRPC    | 3,668      | 115
# STS-B   | 7,000      | 219
# QQP     | 364,292    | 11,401
# MNLI    | 392,702    | 12,263
# QNLI    | 104,743    | 3,274
# RTE     | 2,490      | 78
# WNLI    | 635        | 20
#
# Training-set size per task, keyed by task name. The sweep runs every key
# present here, and run_training derives its logging interval from the size.
# Uncomment an entry to add that task to the sweep.
declare -A datasets=(
    [cola]=8551
    # [sst2]=67349
    # [mrpc]=3668
    # [stsb]=7000
    # [qqp]=364292
    # [mnli]=392702
    # [qnli]=104743
    # [rte]=2490
    # [wnli]=635
)

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
SEED=42
QUANT_BLOCK_SIZE=100000
NGRADS=10

# Learning rates to sweep; each is forwarded to glue.py as --lr.
# Other candidate values: 7e-5 5e-5 3e-5 1e-5 7e-6 5e-6 3e-6 1e-6
# NOTE(review): run_training also passes a fixed --learning_rate 0.0001, which
# the original author notes is ignored in favour of --lr — confirm in glue.py.
LRs=(4e-5)

BatchSize=32

# Density as a fraction, not a percentage (0.01 == 1%); passed as --k.
DENSITY=0.01

# Approximate number of logging/evaluation points per training epoch.
logging_per_epoch=8

# ---------------------------------------------------------------------------
# Weights & Biases
# ---------------------------------------------------------------------------
WANDB_PROJECT="nanoadam"
WANDB_GROUP="huggingface"

#######################################
# Launch one glue.py fine-tuning run for a single (task, learning rate) pair.
# Globals (read):
#   datasets, BatchSize, logging_per_epoch, MODEL, OPTIMIZER, DENSITY, SEED,
#   QUANT_BLOCK_SIZE, NGRADS, WANDB_PROJECT, WANDB_GROUP
# Arguments:
#   $1 - learning rate, forwarded to glue.py as --lr
#   $2 - GLUE task name; must be a key of `datasets`
# Outputs:
#   Progress lines on stdout, errors on stderr. Creates the directory
#   ./results/hf_glue_<task>_<model label>_<optimizer>_<density>.
# Returns:
#   1 if the task has no configured train size; non-zero if the output
#   directory cannot be created; otherwise the exit status of glue.py.
#######################################
run_training() {
    local LR="$1"
    local task_name="$2"

    local train_size="${datasets[$task_name]:-}"
    if [[ -z "$train_size" ]]; then
        printf 'error: no train size configured for task "%s"\n' "$task_name" >&2
        return 1
    fi

    # Aim for ~logging_per_epoch log/eval points per epoch. Integer division
    # yields 0 when an epoch is shorter than one interval. Hugging Face
    # TrainingArguments rejects a steps interval of 0 (assuming glue.py is
    # built on it — confirm), so clamp to at least 1.
    local logging_step=$(( train_size / (BatchSize * logging_per_epoch) ))
    if (( logging_step < 1 )); then
        logging_step=1
    fi

    # MODEL may be a model name or a filesystem path. Only its last path
    # component goes into run/directory names, so a path does not create
    # nested result directories or slashes in the W&B run name. The full
    # MODEL value is still what gets loaded.
    local model_label="${MODEL%/}"
    model_label="${model_label##*/}"

    local WANDB_JOB_TYPE="glue_${task_name}"
    local WANDB_NAME="${task_name}_${model_label}_${OPTIMIZER}_LR${LR}_${DENSITY}"
    # NOTE(review): LR is not part of this path. When LRs holds several values,
    # each run overwrites the previous one's outputs (--overwrite_output_dir).
    local OUTPUT_DIR="./results/hf_glue_${task_name}_${model_label}_${OPTIMIZER}_${DENSITY}"

    echo "for $task_name, logging_step = $logging_step"
    echo "Running with MODEL = ${MODEL}, OPTIMIZER = ${OPTIMIZER}, DENSITY = ${DENSITY}, TASK = ${task_name}"

    mkdir -p -- "$OUTPUT_DIR" || return

    # --learning_rate is pinned to 1e-4. Per the config comments, the
    # effective rate is taken from --lr instead (confirm in glue.py).
    python glue.py \
        --num_train_epochs 5 \
        --optimizer_name "$OPTIMIZER" \
        --logging_strategy steps \
        --logging_steps "$logging_step" \
        --eval_strategy steps \
        --eval_steps "$logging_step" \
        --task_name "$task_name" \
        --do_train \
        --do_eval \
        --do_predict \
        --max_seq_length 128 \
        --per_device_train_batch_size "$BatchSize" \
        --learning_rate 0.0001 \
        --overwrite_output_dir \
        --save_strategy epoch \
        --save_total_limit 1 \
        --bf16 \
        --bf16_full_eval \
        --save_only_model \
        \
        --model_name_or_path "$MODEL" \
        --output_dir "$OUTPUT_DIR" \
        --wandb_project "$WANDB_PROJECT" \
        --wandb_group "$WANDB_GROUP" \
        --wandb_job_type "$WANDB_JOB_TYPE" \
        --wandb_name "$WANDB_NAME" \
        \
        --seed "$SEED" \
        --lr "$LR" \
        --quant_block_size "$QUANT_BLOCK_SIZE" \
        --ngrads "$NGRADS" \
        --k "$DENSITY" \
        \
        --weight_decay 0 \
        --beta1 0.9 \
        --beta2 0.999 \
        --eps 1e-8 \
        --dynamic_log_step "$logging_step"
}

# Sweep every configured task against every learning rate. Runs execute in
# the foreground, one after another. A failed run does not stop the sweep.
# Failures are collected and reported at the end, and the script then exits
# non-zero so whatever launched it can detect the problem.
failed_runs=()
for task_name in "${!datasets[@]}"; do
    for LR in "${LRs[@]}"; do
        if ! run_training "$LR" "$task_name"; then
            failed_runs+=("task=${task_name} LR=${LR}")
        fi
    done
done

if (( ${#failed_runs[@]} > 0 )); then
    printf 'Training failed for %d run(s):\n' "${#failed_runs[@]}" >&2
    printf '  %s\n' "${failed_runs[@]}" >&2
    exit 1
fi

echo "All training runs completed."
