#!/usr/bin/env bash
#
# Evaluation sweep: launches eval.py under torchrun once for every
# (model, dataset, steps) combination listed below, strictly one run at a
# time, in the order model -> dataset -> steps.
#
# eval.py is referenced by a relative path, so run this script from the
# directory that contains it. Strict mode is on: the first failing launch
# aborts the remainder of the sweep.
set -euo pipefail

# GPUs made visible to every launch. Two devices are listed, matching
# --nproc_per_node 2 below; keep the two in sync when changing either.
export CUDA_VISIBLE_DEVICES=3,4

# Master port for the first launch; bumped after every run (see loop body).
PORT=22112

# Passed to eval.py as --output_dir (relative to the working directory).
OUT=eval_results

# Directory holding one entry per model name; "$ROOT/<model>" is passed to
# eval.py as --model_path.
ROOT=/mnt/data/shared/shparashar/lift/SFT/sft_output

# Sweep axes. Arrays (rather than space-separated strings) so iteration does
# not rely on unquoted word splitting.
models=(vanilla argmax3 mix_less)
datasets=(gsm8k math)
steps=(128 256 512)

for m in "${models[@]}"; do
  for d in "${datasets[@]}"; do
    for s in "${steps[@]}"; do
      printf '%s\n' "model=$m dataset=$d steps=$s port=$PORT"

      # The same value is used for both --gen_length and --diffusion_steps.
      torchrun --nproc_per_node 2 --master_port "$PORT" eval.py \
        --dataset "$d" --batch_size 8 --gen_length "$s" --block_length 32 \
        --diffusion_steps "$s" \
        --output_dir "$OUT" --model_path "$ROOT/$m"

      # Give each launch its own port -- presumably so a new run never
      # collides with a port still held by the previous one (TODO confirm).
      PORT=$((PORT + 1))
    done
  done
done
