#!/bin/bash
# Example Slurm header (optional). Bash treats #SBATCH lines as ordinary comments, so they are harmless (and may be removed) when running locally.
#SBATCH -p ds4ai
#SBATCH -c 16
#SBATCH -t 24:00:00
#SBATCH --mem=48gb
#SBATCH --gres=gpu:1
#SBATCH --output=logs/dllm_lseek_%j.out
#SBATCH --error=logs/dllm_lseek_%j.err

# Strict mode: abort on command failure, on use of an unset variable, and on
# a failure in any stage of a pipeline.
set -o errexit -o nounset -o pipefail

# Every setting below follows the same rule: a non-empty value already present
# in the environment is kept; otherwise (unset or empty) the default shown is
# assigned. Override any of them with e.g. `STEPS=64 bash this_script.sh`.

# Exported so that child processes inherit it
# (presumably the Hugging Face cache root -- confirm).
: "${HF_HOME:=/scratch/tgeng_lab/sun/init_weights}"
export HF_HOME

# ---- Core params ----
: "${PATH_TO_DATA:=openai/gsm8k}"
: "${PATH_TO_MODEL:=GSAI-ML/LLaDA-8B-Instruct}"
: "${DLLM_TYPE:=llada}"                 # llada | dream

# ---- Original generation ----
: "${STEPS:=128}"
: "${MAX_NEW_TOKENS:=128}"
: "${TEMPERATURE:=0.0}"

# LLaDA specific
: "${LLADA_BLOCK_LENGTH:=32}"
: "${LLADA_CFG_SCALE:=0.0}"
: "${LLADA_REMASKING:=low_confidence}"
: "${LLADA_MASK_ID:=126336}"

# DREAM specific
: "${DREAM_TOP_P:=0.95}"
: "${DREAM_ALG:=entropy}"
: "${DREAM_ALG_TEMP:=0.0}"

# ---- Latent-Seek params ----
: "${LR:=0.03}"
: "${K:=0.1}"
: "${MAX_NUM_STEPS:=10}"
: "${REWARD_THRESHOLD:=-0.2}"           # the default itself is negative: -0.2
: "${START_INDEX_IN_ANSWER:=0}"

# Tail gen overrides
: "${TAIL_STEPS:=128}"
: "${TAIL_BLOCK_LENGTH:=32}"

# Reward selection
: "${REWARD_TYPE:=self_confidence}"     # dllm_verifier | self_confidence
: "${CONF_MEASURE:=top1}"               # top1 | gap
: "${CONF_AGGREGATOR:=mean}"            # mean | min

# ---- Launch ----
# Output location, data slice and prompt index. Overridable from the
# environment like the other settings in this script; the defaults equal the
# values that were previously hard-coded on the command line.
OUTPUT_DIR=${OUTPUT_DIR:-./output}
START_DATA_IDX=${START_DATA_IDX:-0}
END_DATA_IDX=${END_DATA_IDX:-20}
SOLVER_PROMPT_IDX=${SOLVER_PROMPT_IDX:-0}

# Collect the arguments in an array with every expansion quoted, so a value
# that contains whitespace or glob characters reaches Python as exactly one
# argument instead of being word-split or pathname-expanded by the shell.
# Both the LLaDA and the DREAM option groups are always passed, regardless of
# DLLM_TYPE.
run_args=(
  --dataset "$PATH_TO_DATA"
  --model_name_or_path "$PATH_TO_MODEL"
  --dllm_type "$DLLM_TYPE"
  --output_dir "$OUTPUT_DIR"
  --start_data_idx "$START_DATA_IDX"
  --end_data_idx "$END_DATA_IDX"
  --solver_prompt_idx "$SOLVER_PROMPT_IDX"
  # Original generation
  --steps "$STEPS"
  --max_new_tokens "$MAX_NEW_TOKENS"
  --temperature "$TEMPERATURE"
  # LLaDA specific
  --llada_block_length "$LLADA_BLOCK_LENGTH"
  --llada_cfg_scale "$LLADA_CFG_SCALE"
  --llada_remasking "$LLADA_REMASKING"
  --llada_mask_id "$LLADA_MASK_ID"
  # DREAM specific
  --dream_top_p "$DREAM_TOP_P"
  --dream_alg "$DREAM_ALG"
  --dream_alg_temp "$DREAM_ALG_TEMP"
  # Latent-Seek
  --lr "$LR"
  --k "$K"
  --max_num_steps "$MAX_NUM_STEPS"
  --reward_threshold "$REWARD_THRESHOLD"
  --start_index_in_answer "$START_INDEX_IN_ANSWER"
  # Tail generation overrides
  --tail_steps "$TAIL_STEPS"
  --tail_block_length "$TAIL_BLOCK_LENGTH"
  # Reward selection
  --reward_type "$REWARD_TYPE"
  --conf_measure "$CONF_MEASURE"
  --conf_aggregator "$CONF_AGGREGATOR"
)

# CUDA_VISIBLE_DEVICES falls back to device 0 only when it is unset or empty
# in the environment; the assignment is scoped to this one command.
# `python -u` keeps stdout/stderr unbuffered.
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" \
  python -u dllm_lseek_main.py "${run_args[@]}"
