#!/bin/bash
# Run the DiSPaT self-play fine-tuning pipeline for one dataset/split:
# iteration-0 sample generation + training from a finetuned AnoLLM model,
# MAX_ITERATIONS further rounds, then evaluation and result aggregation.
#
# Environment overrides:
#   TRAIN_GPUS      CUDA_VISIBLE_DEVICES for generation/training (default: 0)
#   INFERENCE_GPUS  CUDA_VISIBLE_DEVICES for evaluation          (default: 0)
#
# pipefail matters here: the final `get_results.py | tee` pipeline would
# otherwise report tee's exit status and hide a failing python run.
set -euo pipefail

export VLLM_USE_FLASHINFER=0
export VLLM_USE_V1=0
export VLLM_ATTENTION_BACKEND=FLASH_ATTN

n_splits=1
setting=semi_supervised
dataset="ecoli"
DATA_DIR="dataset"
EXP_BASE_DIR="exp"
DISPAT_EXP_BASE_DIR="exp_dispat"

TRAIN_GPUS="${TRAIN_GPUS:-0}"
INFERENCE_GPUS="${INFERENCE_GPUS:-0}"
n_permutations=21

MODEL="smol-360"
# NOTE(review): the run names built later hard-code "lr5e-05" and
# "smolLM360"; changing LR or MODEL here does not update them.
LR=5e-5
MAX_STEPS=2000
BATCH_SIZE=16
EVAL_BATCH_SIZE=$((BATCH_SIZE * 2))
BETA=0.1
EPSILON=0.02
MAX_ITERATIONS=3
F_DIVERGENCE="identity"
GPU_MEMORY_UTILIZATION=0.7

echo "=========================================="
echo "Running DiSPaT on $dataset"
echo "=========================================="

# Locate the finetuned AnoLLM checkpoint for this split; iteration 0 samples
# from it and uses it as the reference model.
split_idx=0
exp_dir="$EXP_BASE_DIR/$dataset/$setting/split$n_splits/split$split_idx"
finetuned_model_dir="$exp_dir/models"

if [[ ! -d "$finetuned_model_dir" ]]; then
    echo "Warning: Finetuned model not found at $finetuned_model_dir. Skipping" >&2
    exit 1
fi

# Collect anollm_* model directories (excluding *_data dirs) in sorted order so
# the pick is deterministic; plain `find` returns directory-listing order.
# Reading via process substitution avoids a `| head` that could SIGPIPE the
# producer under pipefail.
mapfile -t model_candidates < <(
    find "$finetuned_model_dir" -maxdepth 1 -type d -name "anollm_*" ! -name "*_data" \
        | LC_ALL=C sort
)
finetuned_model="${model_candidates[0]:-}"

if (( ${#model_candidates[@]} > 1 )); then
    echo "Note: ${#model_candidates[@]} candidate models in $finetuned_model_dir; using $finetuned_model" >&2
fi

if [[ -z "$finetuned_model" || ! -f "$finetuned_model/config.json" ]]; then
    echo "Warning: Valid model not found in $finetuned_model_dir. Skipping" >&2
    exit 1
fi

# Name under which the iteration-0 DiSPaT model is expected to be saved.
# NOTE(review): "lr5e-05" and "smolLM360" are hard-coded instead of derived
# from LR / MODEL; keep them in sync by hand if those settings change.
iter0_run_name="anollm_lr5e-05_standard_smolLM360_iter0_steps${MAX_STEPS}_eps${EPSILON}_beta${BETA}_f${F_DIVERGENCE}_spin"
# One sample directory (keyed by the iteration-0 run name) is passed to every
# generation/training call; each call distinguishes itself via --iteration.
dispat_data_dir="dispat_data/$dataset/split$split_idx/$iter0_run_name"
dispat_exp_dir="$DISPAT_EXP_BASE_DIR/360M/$dataset/split$split_idx"

echo "DiSPaT Data Dir: $dispat_data_dir"
echo "DiSPaT Exp Dir: $dispat_exp_dir"

# Iteration 0: sample from the finetuned AnoLLM model ...
echo "Iteration 0: Generating samples..."
CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" python DiSPaT/generate_dispat_samples_vllm.py \
    --dataset "$dataset" \
    --setting "$setting" \
    --data_dir "$DATA_DIR" \
    --base_model_dir "$finetuned_model" \
    --model "$MODEL" \
    --binning standard \
    --n_splits "$n_splits" \
    --split_idx "$split_idx" \
    --output_dir "$dispat_data_dir" \
    --iteration 0 \
    --n_target_features 4 \
    --use_normal_generation \
    --generation_temperature 1.0 \
    --gpu_memory_utilization "$GPU_MEMORY_UTILIZATION"

# ... then train with that same model as the reference.
echo "Iteration 0: Training DiSPaT model..."
CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" python DiSPaT/train_dispat.py \
    --dataset "$dataset" \
    --setting "$setting" \
    --data_dir "$DATA_DIR" \
    --reference_model_dir "$finetuned_model" \
    --spin_data_dir "$dispat_data_dir" \
    --exp_dir "$dispat_exp_dir" \
    --model "$MODEL" \
    --lr "$LR" \
    --binning standard \
    --n_splits "$n_splits" \
    --split_idx "$split_idx" \
    --beta "$BETA" \
    --epsilon "$EPSILON" \
    --f_divergence_type "$F_DIVERGENCE" \
    --max_steps "$MAX_STEPS" \
    --batch_size "$BATCH_SIZE" \
    --eval_n_permutations "$n_permutations" \
    --eval_batch_size "$EVAL_BATCH_SIZE" \
    --iteration 0

# Iterations 1..MAX_ITERATIONS: each round generates samples from, and trains
# against, the model produced by the previous round.
for (( ITER = 1; ITER <= MAX_ITERATIONS; ITER++ )); do
    PREV_ITER=$((ITER - 1))
    # NOTE(review): run-name template duplicated from iteration 0 with the
    # same hard-coded "lr5e-05"/"smolLM360"; keep all copies in sync.
    prev_run_name="anollm_lr5e-05_standard_smolLM360_iter${PREV_ITER}_steps${MAX_STEPS}_eps${EPSILON}_beta${BETA}_f${F_DIVERGENCE}_spin"
    prev_model_dir="$dispat_exp_dir/models/$prev_run_name"

    # A missing previous model means the chain is broken; abort instead of
    # breaking out and evaluating an incomplete run as if it had succeeded.
    if [[ ! -f "$prev_model_dir/config.json" ]]; then
        echo "Error: Previous iteration model not found at $prev_model_dir" >&2
        exit 1
    fi

    echo "Iteration $ITER: Generating samples..."
    CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" python DiSPaT/generate_dispat_samples_vllm.py \
        --dataset "$dataset" \
        --setting "$setting" \
        --data_dir "$DATA_DIR" \
        --base_model_dir "$prev_model_dir" \
        --model "$MODEL" \
        --binning standard \
        --n_splits "$n_splits" \
        --split_idx "$split_idx" \
        --output_dir "$dispat_data_dir" \
        --iteration "$ITER" \
        --n_target_features 4 \
        --use_normal_generation \
        --generation_temperature 1.0 \
        --gpu_memory_utilization "$GPU_MEMORY_UTILIZATION"

    echo "Iteration $ITER: Training DiSPaT model..."
    CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" python DiSPaT/train_dispat.py \
        --dataset "$dataset" \
        --setting "$setting" \
        --data_dir "$DATA_DIR" \
        --reference_model_dir "$prev_model_dir" \
        --spin_data_dir "$dispat_data_dir" \
        --exp_dir "$dispat_exp_dir" \
        --model "$MODEL" \
        --lr "$LR" \
        --binning standard \
        --n_splits "$n_splits" \
        --split_idx "$split_idx" \
        --beta "$BETA" \
        --epsilon "$EPSILON" \
        --f_divergence_type "$F_DIVERGENCE" \
        --max_steps "$MAX_STEPS" \
        --batch_size "$BATCH_SIZE" \
        --eval_n_permutations "$n_permutations" \
        --eval_batch_size "$EVAL_BATCH_SIZE" \
        --iteration "$ITER"
done

echo "Evaluating DiSPaT model..."
final_iter=$MAX_ITERATIONS
final_run_name="anollm_lr5e-05_standard_smolLM360_iter${final_iter}_steps${MAX_STEPS}_eps${EPSILON}_beta${BETA}_f${F_DIVERGENCE}_spin"
final_model_dir="$dispat_exp_dir/models/$final_run_name"

# Do not evaluate/report when the last iteration's model is missing; that
# would present an incomplete run as a finished one.
if [[ ! -f "$final_model_dir/config.json" ]]; then
    echo "Error: Final iteration model not found at $final_model_dir" >&2
    exit 1
fi

CUDA_VISIBLE_DEVICES="$INFERENCE_GPUS" python DiSPaT/evaluate_dispat.py \
    --dataset "$dataset" \
    --exp_dir "$dispat_exp_dir" \
    --model "$MODEL" \
    --binning standard \
    --n_splits "$n_splits" \
    --split_idx "$split_idx" \
    --setting "$setting" \
    --batch_size "$EVAL_BATCH_SIZE" \
    --n_permutations "$n_permutations"

# NOTE(review): this log dir uses a <dataset>/<setting>/split<n_splits> layout,
# whereas dispat_exp_dir is <dataset>/split<split_idx>; confirm get_results.py
# finds the results it expects. The directory is not created elsewhere in this
# script, so make sure it exists before tee writes into it.
results_log_dir="$DISPAT_EXP_BASE_DIR/360M/$dataset/$setting/split$n_splits"
mkdir -p -- "$results_log_dir"

# pipefail so a failing get_results.py is not masked by tee's exit status.
set -o pipefail
python -u evaluate/get_results.py \
    --dataset "$dataset" \
    --exp_base_dir "$DISPAT_EXP_BASE_DIR/360M" \
    --n_splits "$n_splits" \
    --setting "$setting" | tee "$results_log_dir/evaluate.log"

echo "DiSPaT experiments for $dataset completed!"
