#!/bin/bash
#
# DiSPaT f-divergence sweep.
#
# For every f-divergence type and every data split this script runs the
# DiSPaT generate/train iterations on top of an already fine-tuned model,
# evaluates the result and finally aggregates the per-split results.
# It is expected to be started from the repository root, because all script
# and data paths below are relative.

# -e: stop at the first failing command.
# -u: treat a misspelled/unset setting as an error instead of silently passing
#     an empty command-line argument to the Python scripts.
set -eu

# vLLM switches; exported so that the Python child processes inherit them.
export VLLM_USE_FLASHINFER=0
export VLLM_USE_V1=0
export VLLM_ATTENTION_BACKEND=FLASH_ATTN

# --- Data / experiment layout ------------------------------------------------
n_splits=5                          # number of splits iterated over (0 .. n_splits-1)
setting=semi_supervised
dataset="vifd"
DATA_DIR="dataset"                  # passed as --data_dir
EXP_BASE_DIR="exp"                  # root of the fine-tuned (reference) models
DISPAT_EXP_BASE_DIR="exp_dispat"    # root of the DiSPaT outputs

# --- Hardware ----------------------------------------------------------------
TRAIN_GPUS="0"                      # CUDA_VISIBLE_DEVICES for generation and training
INFERENCE_GPUS="0"                  # CUDA_VISIBLE_DEVICES for evaluation
n_permutations=21                   # passed as --eval_n_permutations / --n_permutations

# --- Model / optimisation settings ------------------------------------------
MODEL="smol-360"
LR=5e-5
MAX_STEPS=2000
BATCH_SIZE=16
EVAL_BATCH_SIZE=$((BATCH_SIZE * 2)) # evaluation uses twice the training batch size
BETA=0.1                            # passed as --beta to train_dispat.py
EPSILON=0.02                        # passed as --epsilon to train_dispat.py
MAX_ITERATIONS=3                    # refinement iterations run after iteration 0
F_DIVERGENCE_TYPES=("identity" "kl" "reverse_kl" "squared_hellinger")
GPU_MEMORY_UTILIZATION=0.7          # passed as --gpu_memory_utilization to the sample generator

echo "=========================================="
echo "Running DiSPaT f-Divergence Experiments on $dataset"
# [*] joins the array into one word; [@] inside a quoted string would expand
# to several arguments (ShellCheck SC2145).
echo "Testing f-divergence types: ${F_DIVERGENCE_TYPES[*]}"
echo "=========================================="

#######################################
# Build the run name that identifies the DiSPaT data/model of one iteration.
# Globals:   MAX_STEPS, EPSILON, BETA, F_DIVERGENCE (read)
# Arguments: $1 - iteration index
# Outputs:   the run name on stdout (no trailing newline)
# NOTE(review): the "lr5e-05" and "smolLM360" parts are hard-coded instead of
# being derived from $LR / $MODEL; presumably they have to match the directory
# names written by DiSPaT/train_dispat.py -- confirm before changing either
# setting.
#######################################
dispat_run_name() {
    local iter=$1
    printf '%s' "anollm_lr5e-05_standard_smolLM360_iter${iter}_steps${MAX_STEPS}_eps${EPSILON}_beta${BETA}_f${F_DIVERGENCE}_spin"
}

for F_DIVERGENCE in "${F_DIVERGENCE_TYPES[@]}"; do
    echo ""
    echo "=========================================="
    echo "Testing f-divergence = $F_DIVERGENCE"
    echo "=========================================="

    # Root of everything produced for this divergence type.
    # NOTE(review): "360M" is hard-coded and not derived from $MODEL.
    results_root="$DISPAT_EXP_BASE_DIR/360M/fdiv_${F_DIVERGENCE}"

    for ((split_idx = 0; split_idx < n_splits; split_idx++)); do
        exp_dir="$EXP_BASE_DIR/$dataset/$setting/split$n_splits/split$split_idx"
        finetuned_model_dir="$exp_dir/models"

        if [[ ! -d "$finetuned_model_dir" ]]; then
            echo "Warning: Finetuned model not found at $finetuned_model_dir. Skipping split $split_idx"
            continue
        fi

        # Pick the fine-tuned model directory (anollm_*, excluding *_data).
        # The candidates are sorted so that the choice does not depend on the
        # file system's directory order when more than one directory matches.
        finetuned_model=$(find "$finetuned_model_dir" -maxdepth 1 -type d -name "anollm_*" ! -name "*_data" \
            | LC_ALL=C sort \
            | head -n 1)

        if [[ -z "$finetuned_model" || ! -f "$finetuned_model/config.json" ]]; then
            echo "Warning: Valid model not found in $finetuned_model_dir. Skipping split $split_idx"
            continue
        fi

        # All iterations share the data directory named after iteration 0.
        dispat_data_dir="dispat_data/$dataset/split$split_idx/$(dispat_run_name 0)"
        dispat_exp_dir="$results_root/$dataset/split$split_idx"

        echo "DiSPaT Data Dir: $dispat_data_dir"
        echo "DiSPaT Exp Dir: $dispat_exp_dir"

        # Iteration 0 starts from the fine-tuned model; every later iteration
        # starts from the model produced by the iteration before it.
        reference_model="$finetuned_model"
        for ((iter = 0; iter <= MAX_ITERATIONS; iter++)); do
            if ((iter > 0)); then
                reference_model="$dispat_exp_dir/models/$(dispat_run_name "$((iter - 1))")"
                if [[ ! -f "$reference_model/config.json" ]]; then
                    echo "Error: Previous iteration model not found at $reference_model"
                    break
                fi
            fi

            echo "Iteration $iter: Generating samples..."
            CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" python DiSPaT/generate_dispat_samples_vllm.py \
                --dataset "$dataset" \
                --setting "$setting" \
                --data_dir "$DATA_DIR" \
                --base_model_dir "$reference_model" \
                --model "$MODEL" \
                --binning standard \
                --n_splits "$n_splits" \
                --split_idx "$split_idx" \
                --output_dir "$dispat_data_dir" \
                --iteration "$iter" \
                --n_target_features 4 \
                --use_normal_generation \
                --generation_temperature 1.0 \
                --gpu_memory_utilization "$GPU_MEMORY_UTILIZATION"

            echo "Iteration $iter: Training DiSPaT model..."
            CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" python DiSPaT/train_dispat.py \
                --dataset "$dataset" \
                --setting "$setting" \
                --data_dir "$DATA_DIR" \
                --reference_model_dir "$reference_model" \
                --spin_data_dir "$dispat_data_dir" \
                --exp_dir "$dispat_exp_dir" \
                --model "$MODEL" \
                --lr "$LR" \
                --binning standard \
                --n_splits "$n_splits" \
                --split_idx "$split_idx" \
                --beta "$BETA" \
                --epsilon "$EPSILON" \
                --f_divergence_type "$F_DIVERGENCE" \
                --max_steps "$MAX_STEPS" \
                --batch_size "$BATCH_SIZE" \
                --eval_n_permutations "$n_permutations" \
                --eval_batch_size "$EVAL_BATCH_SIZE" \
                --iteration "$iter"
        done

        echo "Evaluating DiSPaT model..."
        # The evaluation script only receives the experiment directory, so it
        # is still run when the last model is missing; the warning makes an
        # aborted iteration chain visible next to the evaluation output.
        final_model_dir="$dispat_exp_dir/models/$(dispat_run_name "$MAX_ITERATIONS")"
        if [[ ! -f "$final_model_dir/config.json" ]]; then
            echo "Warning: Expected final model not found at $final_model_dir"
        fi

        CUDA_VISIBLE_DEVICES="$INFERENCE_GPUS" python DiSPaT/evaluate_dispat.py \
            --dataset "$dataset" \
            --exp_dir "$dispat_exp_dir" \
            --model "$MODEL" \
            --binning standard \
            --n_splits "$n_splits" \
            --split_idx "$split_idx" \
            --setting "$setting" \
            --batch_size "$EVAL_BATCH_SIZE" \
            --n_permutations "$n_permutations"
    done

    # tee cannot create missing parent directories, and a failing tee would
    # abort the whole sweep, so make sure the log directory exists first.
    log_dir="$results_root/$dataset/$setting/split$n_splits"
    mkdir -p -- "$log_dir"

    # Without pipefail the pipeline's status is tee's, which would hide a
    # failure of get_results.py. The option is enabled in a subshell so that
    # it does not affect any other pipeline of the script.
    (
        set -o pipefail
        python -u evaluate/get_results.py \
            --dataset "$dataset" \
            --exp_base_dir "$results_root" \
            --n_splits "$n_splits" \
            --setting "$setting" | tee "$log_dir/evaluate.log"
    )
done

echo "f-Divergence experiments for $dataset completed!"
