#!/bin/bash
#
# SLURM batch script: policy-gradient RL training of GSAI-ML/LLaDA-8B-Instruct
# on the "countdown" dataset via b1/run_train.py, launched with `accelerate`
# on one node with 4 H100 GPUs.
#
# NOTE: SLURM opens the -o/-e files below before this script starts and does
# not create missing directories. ./logs (relative to the job's working
# directory, the submit directory by default) must therefore exist before
# `sbatch`. Otherwise the job cannot write its stdout/stderr and typically
# fails immediately. %A expands to the job (array master) ID.

#SBATCH --job-name=dLLM_RL
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=10
#SBATCH --mem=200G
#SBATCH --time=48:00:00
#SBATCH --partition=gpu_sxm
#SBATCH --qos=sxm
#SBATCH --gres=gpu:h100:4
#SBATCH --account=a_eecs_ds
#SBATCH -o ./logs/%A.output
#SBATCH -e ./logs/%A.error

# Compiler / CUDA toolchain from the cluster's module system.
module --ignore_cache load gcc/12.3.0
module --ignore_cache load cuda/12.2.0

# Root directory for every artifact this job writes (caches, W&B files, logs).
# Defaults to ./data; set BASE_DATA in the environment to relocate it.
export BASE_DATA="${BASE_DATA:-data}"
echo "Saving to $BASE_DATA"

export VAR_DATA="$BASE_DATA/var_diff"

# Point the Hugging Face dataset/hub caches and the W&B run directory under
# BASE_DATA.
export HF_DATASETS_CACHE="$BASE_DATA/cache_hugg"
export HF_HOME="$BASE_DATA/cache_hugg"
export HF_HUB_CACHE="$BASE_DATA/cache_hugg"
export WANDB_DIR="$BASE_DATA/wandb"

# Directory for the per-run training log. Derived from VAR_DATA so the two
# cannot drift apart. The path is quoted so a BASE_DATA containing spaces is
# not split into several directories. Abort early if it cannot be created:
# otherwise the (up to 48h) run would proceed with no log file.
export LOGDIR="$VAR_DATA/logs"
mkdir -p -- "$LOGDIR" || {
    echo "Cannot create log directory: $LOGDIR" >&2
    exit 1
}

export WANDB_PROJECT=b1

# Run configuration.
MODEL_NAME=GSAI-ML/LLaDA-8B-Instruct
DATASET="countdown"
RUN_NAME=wll_NP_${DATASET}
NUM_ITER=12 # number of policy gradient inner updates iterations
RL_RUN_NAME=${RUN_NAME}

# GPUs to use: honour the device list SLURM exported for this allocation
# (hard-coding 0-3 could select GPUs belonging to another job when devices
# are not cgroup-isolated). Fall back to the first four devices when nothing
# is set. Launch one process per listed device, so --num_processes always
# matches the visible GPUs.
GPU_IDS="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
IFS=, read -ra gpu_list <<< "$GPU_IDS"
NUM_PROCESSES=${#gpu_list[@]}

# Rendezvous port for the accelerate main process. Override it when several
# jobs may share a node, to avoid an "address already in use" collision.
MAIN_PROCESS_PORT="${MAIN_PROCESS_PORT:-12349}"

# The script's exit status is the pipeline's. Without pipefail that would be
# tee's status, so a crashed training run would still be reported as successful.
set -o pipefail

CUDA_VISIBLE_DEVICES="$GPU_IDS" accelerate launch \
    --config_file "b1/accelerate.yaml" \
    --num_processes "$NUM_PROCESSES" \
    --main_process_port "$MAIN_PROCESS_PORT" \
    b1/run_train.py \
    --config "b1/train.yaml" \
    --model_path "$MODEL_NAME" \
    --num_iterations "$NUM_ITER" \
    --dataset "$DATASET" \
    --trainer_type wll_d1_neg \
    --run_name "$RL_RUN_NAME" \
    --wandb_project "$WANDB_PROJECT" \
    --output_dir "checkpoints/${RL_RUN_NAME}" \
    2>&1 | tee -a "$LOGDIR/$RUN_NAME.log"