#!/bin/bash

# These are set via environment variables by the SBATCH script.
# If they're not set, fall back to single-process defaults.
NUM_GPUS=${NUM_GPUS:-1}
NUM_NODES=${NUM_NODES:-1}
BATCH_SIZE_PER_GPU=${BATCH_SIZE_PER_GPU:-8}
TOTAL_BATCH_SIZE=${TOTAL_BATCH_SIZE:-128}

#######################################
# Compute how many micro-batches to accumulate so that
#   gpus * per_gpu * steps
# approximates the requested total batch size.
# Arguments:
#   $1 - requested total (effective) batch size
#   $2 - number of GPUs (processes)
#   $3 - per-GPU micro-batch size
# Outputs:
#   floor(total / (gpus * per_gpu)) on stdout (may be 0);
#   a warning on stderr when that division is inexact and the result is >= 1.
# Returns:
#   0 on success; 1 if any argument is not a positive integer.
#######################################
compute_grad_acc_steps() {
    local total=$1 gpus=$2 per_gpu=$3
    local arg
    for arg in "$total" "$gpus" "$per_gpu"; do
        # Bash arithmetic would fail with "division by 0" on a zero value and
        # silently treat non-numeric strings as (usually unset → 0) variable
        # names, so reject anything that is not a positive integer up front.
        if [[ ! "$arg" =~ ^[1-9][0-9]*$ ]]; then
            echo "Error: batch sizes and GPU count must be positive integers (got '$arg')" >&2
            return 1
        fi
    done

    local per_step=$(( gpus * per_gpu ))
    local steps=$(( total / per_step ))
    # Integer division truncates; make the resulting batch-size change visible.
    if (( steps >= 1 && total % per_step != 0 )); then
        echo "Warning: TOTAL_BATCH_SIZE=$total is not a multiple of NUM_GPUS*BATCH_SIZE_PER_GPU=$per_step;" \
            "effective batch size will be $(( per_step * steps ))" >&2
    fi
    echo "$steps"
}

GRADIENT_ACC_STEPS=$(compute_grad_acc_steps "$TOTAL_BATCH_SIZE" "$NUM_GPUS" "$BATCH_SIZE_PER_GPU") || exit 1

# Ensure minimum of 1 step (happens when one optimizer step across all GPUs
# already exceeds the requested total batch size).
if (( GRADIENT_ACC_STEPS < 1 )); then
    GRADIENT_ACC_STEPS=1
    echo "Warning: Calculated gradient accumulation < 1, setting to 1" >&2
    echo "Total effective batch size will be $((NUM_GPUS * BATCH_SIZE_PER_GPU))"
else
    echo "Using gradient accumulation steps: $GRADIENT_ACC_STEPS"
    echo "Total effective batch size will be $((NUM_GPUS * BATCH_SIZE_PER_GPU * GRADIENT_ACC_STEPS))"
fi

# Key hyperparameters (several are embedded in the experiment/run names below)
LORA_RANK=64
LORA_ALPHA=16
LORA_MODULE_FRAC=0.25
# LORA_SELECT_STRATEGY accepts either a predefined strategy name or a
# comma-separated list of module types, for example:
#   LORA_SELECT_STRATEGY=mod_type_dec    # predefined strategy
#   LORA_SELECT_STRATEGY=q_proj          # single module type
#   LORA_SELECT_STRATEGY=q_proj,o_proj   # multiple module types
LORA_SELECT_STRATEGY=q_proj,o_proj
LR=1e-4
EPOCHS=2

# Canonicalize the strategy list so equivalent selections yield the same run
# name: strip whitespace (e.g. "q_proj, o_proj" would otherwise put a space
# into the run name and the CLI value), drop empty entries, and sort
# bytewise (LC_ALL=C) so the order does not depend on the machine's locale.
LORA_SELECT_STRATEGY=$(
    printf '%s\n' "${LORA_SELECT_STRATEGY//[[:space:]]/}" \
        | tr ',' '\n' \
        | sed '/^$/d' \
        | LC_ALL=C sort \
        | paste -sd, -
)

# Names
PROJECT_NAME="metamathqa_lora"
EXP_NAME="llama3.2_1b/r${LORA_RANK}_lr${LR}_${EPOCHS}ep"
RUN_NAME="mf_${LORA_MODULE_FRAC}_strat_${LORA_SELECT_STRATEGY}"

echo "Training model using $NUM_GPUS GPUs, $BATCH_SIZE_PER_GPU batch size per GPU, $GRADIENT_ACC_STEPS gradient accumulation steps"

# Offset the rendezvous port by the SLURM job id (29500..30499), presumably so
# concurrent jobs sharing a node don't collide; outside SLURM this is 29500.
MAIN_PORT=$((29500 + (${SLURM_JOB_ID:-0} % 1000)))
echo "Using main_process_port=$MAIN_PORT"

# Options consumed by `accelerate launch` itself. Built as an array so every
# value is passed as exactly one argument, even if it contains spaces.
# NOTE(review): when NUM_NODES > 1, accelerate also needs a per-node
# --machine_rank and a --main_process_ip; they are not passed here — confirm
# the SBATCH wrapper / accelerate config supplies them before going multi-node.
launch_args=(
    --main_process_port "$MAIN_PORT"
    --mixed_precision bf16
    --num_machines "$NUM_NODES"
    --num_processes "$NUM_GPUS"
)

# Options forwarded to src/train.py.
# NOTE(review): the tag "mathmathqa_math" looks like a typo of "metamathqa";
# it is kept as-is so new runs stay grouped with existing W&B runs — confirm.
train_args=(
    --model_name_or_path "meta-llama/Llama-3.2-1B"
    --use_lora True
    --lora_module_frac "$LORA_MODULE_FRAC"
    --lora_select_strategy "$LORA_SELECT_STRATEGY"
    --lora_rank "$LORA_RANK"
    --lora_alpha "$LORA_ALPHA"
    --lora_dropout 0.0
    --dataset_mixer_list "meta-math/MetaMathQA" "1.0"
    --dataset_mixer_list_splits "train"
    --dataset_transform_fn "sft_metamathqa_tokenize_and_truncate_v1" "sft_metamathqa_filter_v1"
    --max_seq_length 1024
    --per_device_train_batch_size "$BATCH_SIZE_PER_GPU"
    --gradient_accumulation_steps "$GRADIENT_ACC_STEPS"
    --learning_rate "$LR"
    --lr_scheduler_type "cosine"
    --warmup_ratio 0.1
    --weight_decay 0.0
    --num_train_epochs "$EPOCHS"
    --clip_grad_norm 1.0
    --output_dir "./output/$PROJECT_NAME/$EXP_NAME"
    --logging_steps 1
    --with_tracking
    --report_to wandb
    --wandb_project_name "$PROJECT_NAME"
    --wandb_entity "nikhil_ghosh"
    --exp_name "$EXP_NAME"
    --run_name "$RUN_NAME"
    --checkpointing_steps 500
    --tags "mathmathqa_math"
    --use_flash_attn True
)

accelerate launch "${launch_args[@]}" src/train.py "${train_args[@]}"