#!/bin/bash
# Weighted likelihood using the d1 likelihood calculation + clipping + possibly pi_ref
set -euo pipefail

# Root directory for all run artifacts (HF caches, W&B files, logs,
# checkpoints). Override by exporting BASE_DATA before running this script.
export BASE_DATA="${BASE_DATA:-/home//d1/data}"
echo "Saving to $BASE_DATA"

export VAR_DATA="$BASE_DATA/var_diff"

# Point all Hugging Face caches at a single directory under BASE_DATA.
export HF_DATASETS_CACHE="$BASE_DATA/cache_hugg"
export HF_HOME="$BASE_DATA/cache_hugg"
export HF_HUB_CACHE="$BASE_DATA/cache_hugg"
export WANDB_DIR="$BASE_DATA/wandb"

export LOGDIR="$BASE_DATA/var_diff/logs"
# Fail fast: without this directory the log redirect below cannot be opened.
mkdir -p -- "$LOGDIR" || { echo "cannot create log dir: $LOGDIR" >&2; exit 1; }


export WANDB_PROJECT=var-diff-v2

MODEL_NAME=GSAI-ML/LLaDA-8B-Instruct
DATASET="gsm8k"
RUN_NAME="wll_P_${DATASET}"
NUM_ITER=12 # number of inner policy-gradient update iterations
RL_RUN_NAME="${RUN_NAME}"
LOG_FILE="$LOGDIR/$RUN_NAME.log"

# Start a 4-process training run on GPUs 0-3 in the background under nohup
# (survives terminal hangup); stdout and stderr both go to LOG_FILE.
# Paths under wd1/ are relative, so run this from the directory containing wd1/.
CUDA_VISIBLE_DEVICES=0,1,2,3 nohup accelerate launch \
    --config_file "wd1/accelerate.yaml" \
    --num_processes 4 \
    --main_process_port 12349 wd1/run_train.py \
    --config "wd1/train.yaml" \
    --model_path "$MODEL_NAME" \
    --num_iterations "$NUM_ITER" \
    --dataset "$DATASET" \
    --trainer_type wll_d1_pos_only \
    --run_name "$RL_RUN_NAME" \
    --max_steps 8000 \
    --wandb_project "$WANDB_PROJECT" \
    --output_dir "$VAR_DATA/checkpoints/${RL_RUN_NAME}" > "$LOG_FILE" 2>&1 &

# Report the background PID and log location so the run can be tracked/killed.
echo "Launched $RL_RUN_NAME (PID $!); logging to $LOG_FILE"