#!/bin/bash
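# Example Slurm header (optional). Bash treats the #SBATCH lines as ordinary
# comments, so the script also runs locally via `bash <script>` unchanged.
# Note: create the logs/ directory before `sbatch`, otherwise Slurm cannot
# open the --output/--error files below.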
#SBATCH -p ds4ai
#SBATCH -c 16
#SBATCH -t 4:00:00
#SBATCH --mem=32gb
#SBATCH --gres=gpu:1
#SBATCH --output=logs/dllm_%j.out
#SBATCH --error=logs/dllm_%j.err

set -euo pipefail

# Hugging Face cache root. Respect a value already exported by the caller;
# otherwise fall back to the lab scratch directory.
export HF_HOME=${HF_HOME:-/scratch/tgeng_lab/sun/init_weights}

# Every setting below can be overridden from the environment, e.g.
#   STEPS=256 DLLM_TYPE=dream bash <this script>
# Datasets used with this script:
#   openai/gsm8k | HuggingFaceH4/MATH-500 | Maxwell-Jia/AIME_2024
PATH_TO_DATA=${PATH_TO_DATA:-"HuggingFaceH4/MATH-500"}
PATH_TO_MODEL=${PATH_TO_MODEL:-"GSAI-ML/LLaDA-8B-Instruct"}
DLLM_TYPE=${DLLM_TYPE:-"llada"}   # llada | dream
OUTPUT_DIR=${OUTPUT_DIR:-./output}

# Dataset slice and solver prompt selection.
# NOTE(review): the default end index 1319 is the size of the gsm8k test split,
# not MATH-500 (500 rows) or AIME_2024; presumably dllm_main.py clamps an
# out-of-range end index — confirm, or set END_DATA_IDX per dataset.
START_DATA_IDX=${START_DATA_IDX:-0}
END_DATA_IDX=${END_DATA_IDX:-1319}
SOLVER_PROMPT_IDX=${SOLVER_PROMPT_IDX:-1}

# Common decoding settings
STEPS=${STEPS:-128}
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-128}
TEMPERATURE=${TEMPERATURE:-0.0}

# LLaDA specific
LLADA_BLOCK_LENGTH=${LLADA_BLOCK_LENGTH:-32}
LLADA_CFG_SCALE=${LLADA_CFG_SCALE:-0.0}
LLADA_REMASKING=${LLADA_REMASKING:-low_confidence}
LLADA_MASK_ID=${LLADA_MASK_ID:-126336}

# DREAM specific
DREAM_TOP_P=${DREAM_TOP_P:-0.95}
DREAM_ALG=${DREAM_ALG:-entropy}
DREAM_ALG_TEMP=${DREAM_ALG_TEMP:-0.0}

# Build the argument list as an array so every value is passed as exactly one
# word, even if it is empty or contains spaces/glob characters.
# Both the LLaDA and the Dream flags are always passed; presumably dllm_main.py
# ignores the set that does not match --dllm_type — confirm on the Python side.
dllm_args=(
  --dataset "$PATH_TO_DATA"
  --model_name_or_path "$PATH_TO_MODEL"
  --dllm_type "$DLLM_TYPE"
  --output_dir "$OUTPUT_DIR"
  --start_data_idx "$START_DATA_IDX"
  --end_data_idx "$END_DATA_IDX"
  --solver_prompt_idx "$SOLVER_PROMPT_IDX"
  --steps "$STEPS"
  --max_new_tokens "$MAX_NEW_TOKENS"
  --temperature "$TEMPERATURE"
  --llada_block_length "$LLADA_BLOCK_LENGTH"
  --llada_cfg_scale "$LLADA_CFG_SCALE"
  --llada_remasking "$LLADA_REMASKING"
  --llada_mask_id "$LLADA_MASK_ID"
  --dream_top_p "$DREAM_TOP_P"
  --dream_alg "$DREAM_ALG"
  --dream_alg_temp "$DREAM_ALG_TEMP"
)

# GPU 0 is only a fallback: when CUDA_VISIBLE_DEVICES is already set in the
# environment, that value is kept. `-u` keeps Python's stdout unbuffered so
# progress shows up in the log files promptly.
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0} \
  python -u dllm_main.py "${dllm_args[@]}"

