#!/bin/bash

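# Script to generate 32 samples per question on the GSM8K dataset.
# Intended to use mini_batch_size=4 and n_itr=8 to avoid OOM (4 * 8 = 32 total samples).
# NOTE(review): neither mini_batch_size nor n_itr is passed in --model_args below;
# presumably they come from eval_llada.py defaults or a task config picked up via
# --include_path — confirm before relying on the 32-sample count.
# Samples are saved with --log_samples so they can be evaluated later.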

# Optional: uncomment to set the PyTorch CUDA allocator config, which reduces
# memory fragmentation (currently disabled).
#export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

set -euo pipefail

# GPU to run on. Override per run, e.g. `GPU_ID=0 ./this_script.sh`;
# defaults to 7, the device this script has always used.
gpu_id="${GPU_ID:-7}"

# Timestamped output directory so repeated runs don't overwrite each other.
output_path="./outputs/gsm8k_eval_$(date +%Y%m%d_%H%M%S)"

# Model settings forwarded to the llada_dist model in eval_llada.py.
# NOTE(review): the header mentions mini_batch_size=4 / n_itr=8 (32 samples),
# but neither is set here — confirm eval_llada.py's defaults provide them.
model_args="model_path=GSAI-ML/LLaDA-8B-Base,gen_length=256,parallel_tokens=2,block_length=32,temperature=0.7"

# Build the argument list as an array so every value stays one word.
eval_args=(
  --tasks gsm8k
  --include_path .
  --model llada_dist
  --model_args "$model_args"
  --log_samples
  --output_path "$output_path"
)

printf 'Running GSM8K eval on GPU %s; results -> %s\n' "$gpu_id" "$output_path" >&2

CUDA_VISIBLE_DEVICES="$gpu_id" accelerate launch eval_llada.py "${eval_args[@]}"