#!/bin/bash
#
# Grid-search launcher for yield_ft_ds.py fine-tuning on Slurm.
#
# For every combination of learning rate, per-device batch size and MLP
# learning-rate multiplier listed below, this script renders an sbatch job
# script, submits it, and appends the resulting Slurm job ID to $LOG_FILE.
#
# Environment:
#   SUBMIT_DELAY  seconds to wait between consecutive submissions (default 60)
#
# Each job writes Slurm output to <job_name>.log and its training log to
# <job_name>_ds.log. These are relative paths; with sbatch's default working
# directory they resolve against the directory this sweep is launched from.

set -euo pipefail

die() { printf 'error: %s\n' "$*" >&2; exit 1; }

#######################################
# Print the sbatch script for one grid point to stdout.
# Arguments:
#   $1 - job name (also used for the .log / _ds.log file names)
#   $2 - learning rate passed as --lr
#   $3 - per-device train batch size
#   $4 - MLP learning-rate multiplier
# Outputs: the complete job script on stdout.
#######################################
render_job_script() {
  local job_name=$1 lr=$2 bs=$3 mlp_mult=$4
  # Unquoted delimiter: ${...} is expanded now (grid values), while \$...
  # is escaped so it is evaluated later, inside the running job.
  cat <<EOF
#!/bin/bash
#SBATCH -p AI4Phys
#SBATCH --job-name=${job_name}
#SBATCH --output=${job_name}.log
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:2

source activate llama

# Set unique master port for each job
MASTER_ADDR=\$(scontrol show hostname "\$SLURM_JOB_NODELIST" | head -n1)
MASTER_PORT=\$((RANDOM % 10000 + 20000))

export MASTER_ADDR MASTER_PORT

echo "Using MASTER_ADDR: \$MASTER_ADDR"
echo "Using MASTER_PORT: \$MASTER_PORT"

srun deepspeed --launcher SLURM \
  --master_addr "\$MASTER_ADDR" \
  --master_port "\$MASTER_PORT" \
  yield_ft_ds.py \
  --pretrained_model_path '/mnt/hwfile/ai4chem/share/step1_llama3_8b_0916_yearly_pistachio_ep3' \
  --lora_adapter_path "/mnt/hwfile/ai4chem/chenjianpeng/train_regression/llama_ep3_1115-18/lora_adapter" \
  --yield_predictor_path "/mnt/hwfile/ai4chem/chenjianpeng/train_regression/llama_ep3_1115-18/predictor.pt" \
  --num_epoch 200 \
  --lr ${lr} \
  --data_path '/mnt/petrelfs/handong/llama/train_regression/data4regression' \
  --data_name 'suzuki_miyaura_fg_changes_60' \
  --per_device_train_batch_size ${bs} \
  --save_root '/mnt/hwfile/ai4chem/handong/suzuki_miyaura_fg_changes_60' \
  --gradient_accumulation_steps 1 \
  --use_lora 1 \
  --log_file "${job_name}_ds.log" \
  --deepspeed_config 'yield_ft_ds_config.json' \
  --mlp_lr_multiplier ${mlp_mult} \
  --project_name 'grid_search_suzuki_miyaura_fg_changes_60'
EOF
}

# Parameter ranges
LR_VALUES=(1e-4)
BATCH_SIZES=(2 4)
MLP_MULT_VALUES=(0.1 1 100)

# Log file
LOG_FILE="param_sweep_log.txt"

# Throttle between submissions (seconds); overridable from the environment.
SUBMIT_DELAY=${SUBMIT_DELAY:-60}
delay_re='^[0-9]+([.][0-9]+)?$'
[[ "$SUBMIT_DELAY" =~ $delay_re ]] \
  || die "SUBMIT_DELAY must be a non-negative number, got '$SUBMIT_DELAY'"

# Render each job into a private temp file instead of a fixed ./temp_run.sh,
# so concurrent sweeps cannot clobber each other. The trap removes the file
# on any exit path, including interruption. sbatch copies the script at
# submission, so overwriting or deleting it afterwards does not affect
# queued jobs.
job_script=$(mktemp "${TMPDIR:-/tmp}/param_sweep_job.XXXXXX") \
  || die "mktemp failed"
trap 'rm -f -- "$job_script"' EXIT

echo "Starting parameter sweep at $(date)" > "$LOG_FILE"

first_submission=1
for lr in "${LR_VALUES[@]}"; do
  for bs in "${BATCH_SIZES[@]}"; do
    for mlp_mult in "${MLP_MULT_VALUES[@]}"; do

      # Create unique job name
      job_name="yield_ft_lr${lr}_bs${bs}_mlp${mlp_mult}"

      # Wait only *between* submissions: not before the first one,
      # and not after the last.
      if (( first_submission )); then
        first_submission=0
      else
        echo "Waiting ${SUBMIT_DELAY} seconds before next submission..."
        sleep "$SUBMIT_DELAY"
      fi

      render_job_script "$job_name" "$lr" "$bs" "$mlp_mult" > "$job_script"

      # --parsable prints "jobid" or "jobid;cluster" instead of the
      # human-readable "Submitted batch job N". Without this check, a
      # failed submission would be logged with an empty JobID.
      if ! job_id=$(sbatch --parsable "$job_script"); then
        echo "${job_name}: SUBMISSION FAILED, lr=${lr}, bs=${bs}, mlp_mult=${mlp_mult}" >> "$LOG_FILE"
        printf 'Failed to submit job %s\n' "$job_name" >&2
        continue
      fi
      job_id=${job_id%%;*}

      echo "${job_name}: JobID=${job_id}, lr=${lr}, bs=${bs}, mlp_mult=${mlp_mult}" >> "$LOG_FILE"
      echo "Submitted job ${job_name} (JobID=${job_id})"
    done
  done
done

echo "Parameter sweep completed at $(date)" >> "$LOG_FILE"