#!/bin/bash

# Batch-size settings. The SBATCH script may export any of these; values that
# are unset or empty fall back to the defaults below.
NUM_GPUS=${NUM_GPUS:-1}
BATCH_SIZE_PER_GPU=${BATCH_SIZE_PER_GPU:-8}
TOTAL_BATCH_SIZE=${TOTAL_BATCH_SIZE:-128}

#######################################
# Compute the number of gradient-accumulation steps needed to reach a target
# total batch size: ceil(total / (gpus * per_gpu)), clamped to at least 1.
# Arguments:
#   $1 - target total batch size (integer, no leading zeros)
#   $2 - number of GPUs (positive integer)
#   $3 - batch size per GPU (positive integer)
# Outputs:
#   the step count on stdout; a diagnostic on stderr for invalid input
# Returns:
#   0 on success, 1 if any argument is not a usable integer
#######################################
compute_gradient_acc_steps() {
  local total=$1 gpus=$2 per_gpu=$3
  # Leading zeros are rejected because bash arithmetic would read them as octal.
  local -r positive_re='^[1-9][0-9]*$'
  local -r integer_re='^(0|-?[1-9][0-9]*)$'

  if [[ ! $gpus =~ $positive_re ]]; then
    printf 'error: NUM_GPUS must be a positive integer (got "%s")\n' "$gpus" >&2
    return 1
  fi
  if [[ ! $per_gpu =~ $positive_re ]]; then
    printf 'error: BATCH_SIZE_PER_GPU must be a positive integer (got "%s")\n' "$per_gpu" >&2
    return 1
  fi
  if [[ ! $total =~ $integer_re ]]; then
    printf 'error: TOTAL_BATCH_SIZE must be an integer (got "%s")\n' "$total" >&2
    return 1
  fi

  local per_step=$(( gpus * per_gpu ))
  # Adding (per_step - 1) before dividing rounds the quotient up.
  local steps=$(( (total + per_step - 1) / per_step ))
  # A target smaller than one optimizer step still needs one accumulation step.
  if (( steps < 1 )); then
    steps=1
  fi
  printf '%s\n' "$steps"
}

# Abort before launching anything if the configuration cannot produce a valid
# step count (this also rules out a division by zero below).
GRADIENT_ACC_STEPS=$(compute_gradient_acc_steps \
  "$TOTAL_BATCH_SIZE" "$NUM_GPUS" "$BATCH_SIZE_PER_GPU") || exit 1

# Samples processed per forward/backward pass across all GPUs.
step_batch=$((NUM_GPUS * BATCH_SIZE_PER_GPU))

echo "→ NUM_GPUS=$NUM_GPUS  BATCH_SIZE_PER_GPU=$BATCH_SIZE_PER_GPU"
echo "→ TARGET_TOTAL_BATCH_SIZE=$TOTAL_BATCH_SIZE"
# The effective batch can exceed the target when the target is not a multiple
# of step_batch, because the step count is rounded up.
echo "→ Using GRADIENT_ACC_STEPS=$GRADIENT_ACC_STEPS  (effective batch = $((step_batch*GRADIENT_ACC_STEPS)))"

# Key hyperparameters for run naming
LORA_RANK=64
LORA_ALPHA=16

# Overridable from the environment; unset or empty values use these defaults.
LORA_SELECT_STRATEGY=${LORA_SELECT_STRATEGY:-k_proj,o_proj}
LR=${LR:-1e-4}
EPOCHS=${EPOCHS:-2}

#######################################
# Put a comma-separated list into canonical form: whitespace removed, empty
# entries dropped, entries sorted byte-wise, joined with commas.
# Arguments:
#   $1 - comma-separated list, e.g. "v_proj, k_proj"
# Outputs:
#   the canonical list on stdout, e.g. "k_proj,v_proj"
#######################################
canonicalize_strategy_list() {
  # printf with a quoted argument keeps the value from being word-split or
  # glob-expanded. LC_ALL=C makes the order independent of the caller's locale.
  printf '%s' "$1" \
    | tr -d '[:space:]' \
    | tr ',' '\n' \
    | sed '/^$/d' \
    | LC_ALL=C sort \
    | paste -sd ',' -
}

# Sort the strategy list so the same set of modules always yields the same
# run name regardless of the order in which it was given.
LORA_SELECT_STRATEGY=$(canonicalize_strategy_list "$LORA_SELECT_STRATEGY")
if [[ -z "$LORA_SELECT_STRATEGY" ]]; then
  printf 'error: LORA_SELECT_STRATEGY contains no module names\n' >&2
  exit 1
fi

# Create a run-name-friendly version by replacing commas with +
LORA_STRAT_NAME=${LORA_SELECT_STRATEGY//,/+}

# Names
PROJECT_NAME="metamathqa_lora"
EXP_NAME="llama3.2_3b/r${LORA_RANK}_lr${LR}_${EPOCHS}ep"
RUN_NAME="strat_${LORA_STRAT_NAME}"

# Report the parallelism settings this launch will use.
echo "Training model using $NUM_GPUS GPUs, $BATCH_SIZE_PER_GPU batch size per GPU, $GRADIENT_ACC_STEPS gradient accumulation steps"

# Derive the main-process port from the Slurm job id: its last three decimal
# digits (0 when SLURM_JOB_ID is unset or empty) are added to 29500, giving a
# port in 29500-30499. Presumably this keeps concurrent jobs on one node from
# sharing a port.
MAIN_PORT=$(( ${SLURM_JOB_ID:-0} % 1000 + 29500 ))
echo "Using main_process_port=$MAIN_PORT"

# Run src/train.py through `accelerate launch` with bf16 mixed precision and
# one process per GPU, passing the LoRA, dataset, optimizer and tracking
# settings computed above.
#
# Every variable expansion is double-quoted so each value reaches the command
# as exactly one argument, with no word splitting or glob expansion.
#
# NOTE(review): the tag "mathmathqa_math" looks like a typo for
# "metamathqa_math". It is left unchanged here so runs keep the same tag as
# earlier ones; confirm before renaming.
accelerate launch \
    --main_process_port "$MAIN_PORT" \
    --mixed_precision bf16 \
    --num_processes "$NUM_GPUS" \
    src/train.py \
    --model_name_or_path "meta-llama/Llama-3.2-3B" \
    --use_lora True \
    --lora_select_strategy "$LORA_SELECT_STRATEGY" \
    --lora_rank "$LORA_RANK" \
    --lora_alpha "$LORA_ALPHA" \
    --lora_dropout 0.0 \
    --dataset_mixer_list "meta-math/MetaMathQA" "1.0" \
    --dataset_mixer_list_splits "train" \
    --dataset_transform_fn "sft_metamathqa_tokenize_and_truncate_v1" "sft_metamathqa_filter_v1" \
    --max_seq_length 1024 \
    --per_device_train_batch_size "$BATCH_SIZE_PER_GPU" \
    --gradient_accumulation_steps "$GRADIENT_ACC_STEPS" \
    --learning_rate "$LR" \
    --lr_scheduler_type "cosine" \
    --warmup_ratio 0.1 \
    --weight_decay 0.0 \
    --num_train_epochs "$EPOCHS" \
    --clip_grad_norm 1.0 \
    --output_dir "./output/$PROJECT_NAME/$EXP_NAME" \
    --logging_steps 1 \
    --with_tracking \
    --report_to wandb \
    --wandb_project_name "$PROJECT_NAME" \
    --wandb_entity "nikhil_ghosh" \
    --exp_name "$EXP_NAME" \
    --run_name "$RUN_NAME" \
    --checkpointing_steps 500 \
    --tags "mathmathqa_math" \
    --use_flash_attn True