#!/bin/bash

# Root of the experiment data tree; replace the placeholder before running.
BASE_DATA="YOUR_BASE_DATA_PATH"
# Sub-tree used by this project (checkpoints and eval results live below it).
VAR_DATA="${BASE_DATA}/var_diff"

# The three HF_* cache variables all point at one shared directory.
HF_HOME="${BASE_DATA}/cache_hugg"
HF_HUB_CACHE="${HF_HOME}"
HF_DATASETS_CACHE="${HF_HOME}"
WANDB_DIR="${BASE_DATA}/wandb"

# Replace the placeholder before running.
LOGDIR="YOUR_LOGDIR_PATH"
WANDB_PROJECT="var-diff-v2"

# Export everything in one go so child processes inherit the settings.
export BASE_DATA VAR_DATA
export HF_DATASETS_CACHE HF_HOME HF_HUB_CACHE WANDB_DIR
export LOGDIR WANDB_PROJECT

# Launch configuration
# Model identifier handed to eval/eval.py through --model_path.
MODEL_PATH="GSAI-ML/LLaDA-8B-Instruct"

# GPU ids intended for evaluation; the two values below are derived from it.
GPU_IDS=(0 1 2 3)
NUM_GPUS="${#GPU_IDS[@]}"

# Comma-join the ids (0,1,2,3): emit "id," per element, then drop the
# trailing comma.
printf -v GPU_LIST '%s,' "${GPU_IDS[@]}"
GPU_LIST="${GPU_LIST%,}"

# NOTE(review): the evaluation loops choose their own port before every
# launch, so confirm whether this initial value is still needed.
MASTER_PORT=29411

# Task name -> base path under which that task's "checkpoint-<N>" entries
# are looked up. Add an entry here for every additional task.
declare -A BASE_CKPT_DIRS=(
  [gsm8k]="${VAR_DATA}/checkpoints/wll_NP_gsm8k"
  [countdown]="${VAR_DATA}/checkpoints/wll_NP_countdown"
  [math]="${VAR_DATA}/checkpoints/wll_NP_math"
  [sudoku]="${VAR_DATA}/checkpoints/d1_us_sudoku"
)

# Task name -> checkpoint numbers to evaluate, stored as one string per task.
# The evaluation loops word-split the string, so list several checkpoints
# separated by whitespace (e.g. "1000 2000"). Replace the placeholders.
declare -A CHECKPOINTS=(
  [gsm8k]="YOUR_CHECKPOINT_NUMBERS_HERE"
  [math]="YOUR_CHECKPOINT_NUMBERS_HERE"
  [countdown]="YOUR_CHECKPOINT_NUMBERS_HERE"
  [sudoku]="YOUR_CHECKPOINT_NUMBERS_HERE"
)


#######################################
# Print the --batch_size value for a generation length: 8 when the length
# is 512, 16 for every other length.
# Arguments: $1 - generation length (integer)
# Outputs:   batch size on stdout
#######################################
eval_batch_size() {
  if [ "$1" -eq 512 ]; then
    printf '%s\n' 8
  else
    printf '%s\n' 16
  fi
}

#######################################
# Run eval/eval.py for every task in TASKS, every checkpoint configured for
# that task, and every length in GEN_LENGTHS.
# Globals (read): TASKS, GEN_LENGTHS, BASE_CKPT_DIRS, CHECKPOINTS, VAR_DATA,
#                 MODEL_PATH, GPU_LIST, NUM_GPUS
# Outputs: progress lines on stdout; warnings on stderr
# Returns: 0; a failed launch is reported and the sweep carries on, so one
#          broken checkpoint does not cancel the remaining evaluations.
#######################################
run_eval_sweep() {
  local task ckpt_base ckpt_list ckpt_num ckpt_path gen_length
  local master_port batch_size
  # Ports below 1024 need elevated privileges to bind on typical Linux
  # hosts, so the rendezvous port is drawn from 20000-29999 instead.
  local -r port_min=20000 port_span=10000

  for task in "${TASKS[@]}"; do
    ckpt_base="${BASE_CKPT_DIRS[$task]}"
    ckpt_list="${CHECKPOINTS[$task]}"

    # A task missing from either table would otherwise be skipped silently
    # (or evaluated from a path starting at "/checkpoint-...").
    if [[ -z "$ckpt_base" || -z "$ckpt_list" ]]; then
      printf 'WARNING: no checkpoint configuration for task "%s"; skipping\n' \
        "$task" >&2
      continue
    fi

    # Unquoted on purpose: the checkpoint list is a whitespace-separated
    # string and must be split into individual numbers.
    # shellcheck disable=SC2086
    for ckpt_num in $ckpt_list; do
      ckpt_path="$ckpt_base/checkpoint-$ckpt_num"

      for gen_length in "${GEN_LENGTHS[@]}"; do
        # Fresh random port per launch so back-to-back runs do not collide
        # on a port that is still being released.
        master_port=$((port_min + RANDOM % port_span))
        echo "Using MASTER_PORT=$master_port"
        batch_size=$(eval_batch_size "$gen_length")

        echo "Evaluating $task @ checkpoint $ckpt_num, gen_length=$gen_length"

        # GPU_LIST / NUM_GPUS come from the configuration section; the
        # fallbacks reproduce the previous hard-coded 4-GPU launch.
        if ! CUDA_VISIBLE_DEVICES="${GPU_LIST:-0,1,2,3}" \
          python3 -m torch.distributed.run \
          --nproc_per_node="${NUM_GPUS:-4}" \
          --master_port="$master_port" \
          eval/eval.py \
          --dataset "$task" \
          --batch_size "$batch_size" \
          --gen_length "$gen_length" \
          --output_dir "${VAR_DATA}/eval_results" \
          --model_path "$MODEL_PATH" \
          --checkpoint_path "$ckpt_path"; then
          printf 'WARNING: evaluation failed for %s @ checkpoint %s, gen_length=%s\n' \
            "$task" "$ckpt_num" "$gen_length" >&2
        fi
      done
    done
  done
}

# Sweep 1: sudoku and countdown at generation length 512.
TASKS=("sudoku" "countdown")
GEN_LENGTHS=(512)
run_eval_sweep

# Sweep 2: gsm8k and math at generation length 256.
TASKS=("gsm8k" "math")
GEN_LENGTHS=(256)
run_eval_sweep

printf '%s\n' 'All evaluations completed!'

# Hand the shared results directory to the parsing script. Being the last
# command, its exit status becomes the exit status of this script.
# NOTE(review): judging by its name the script reports accuracies - confirm.
results_dir="${VAR_DATA}/eval_results"
python3 eval/parse_and_get_acc.py --directory "$results_dir"