#!/bin/bash
#
# Run the DiSPaT pipeline on the 20news datasets with the smol-360 model.
#
# For every dataset in DATASETS the script:
#   1. looks for a finetuned base model (a directory named anollm_*, not
#      ending in _data, that contains config.json) under
#      $EXP_BASE_DIR/<dataset>/<setting>/split<n_splits>/split<split_idx>/models
#      and skips the dataset when none is found;
#   2. runs iterations 0..MAX_ITERATIONS, each one calling
#      DiSPaT/generate_dispat_samples_vllm.py and then DiSPaT/train_dispat.py;
#   3. runs DiSPaT/evaluate_dispat.py and evaluate/get_results.py, copying the
#      latter's output into evaluate.log.
#
# All paths are relative to the current working directory.
#
# Environment:
#   TRAIN_GPUS      CUDA_VISIBLE_DEVICES for generation and training (default: 0)
#   INFERENCE_GPUS  CUDA_VISIBLE_DEVICES for evaluation             (default: 0)

# -e: stop on the first failing command; -u: reject unset variables;
# pipefail: a failing get_results.py must not be hidden by `tee` succeeding.
set -euo pipefail

# vLLM settings exported to every Python child process.
# NOTE(review): presumably these disable FlashInfer and the V1 engine and pick
# the FlashAttention backend -- confirm against the installed vLLM version.
export VLLM_USE_FLASHINFER=0
export VLLM_USE_V1=0
export VLLM_ATTENTION_BACKEND=FLASH_ATTN

n_splits=1
setting=semi_supervised
DATA_DIR="dataset"
EXP_BASE_DIR="exp"
DISPAT_EXP_BASE_DIR="exp_dispat"

TRAIN_GPUS="${TRAIN_GPUS:-0}"
INFERENCE_GPUS="${INFERENCE_GPUS:-0}"
n_permutations=21

MODEL="smol-360"
LR=5e-5
MAX_STEPS=2000
BATCH_SIZE=16
EVAL_BATCH_SIZE=$((BATCH_SIZE * 2))
BETA=0.1
EPSILON=0.02
MAX_ITERATIONS=3
F_DIVERGENCE="identity"
GPU_MEMORY_UTILIZATION=0.7

DATASETS=("20news-0" "20news-1" "20news-2" "20news-3" "20news-4" "20news-5")

#######################################
# Print a three-line banner around a message.
# Arguments: $1 - message
# Outputs:   banner on stdout
#######################################
print_banner() {
    echo "=========================================="
    echo "$1"
    echo "=========================================="
}

#######################################
# Build the run name used for the model/data directory of one iteration.
# NOTE(review): the "lr5e-05" and "smolLM360" fragments are hard-coded and are
# not derived from LR / MODEL; presumably they must match the names the
# training script writes -- update them by hand if LR or MODEL change.
# Globals:   MAX_STEPS, EPSILON, BETA, F_DIVERGENCE (read)
# Arguments: $1 - iteration number
# Outputs:   run name on stdout
#######################################
dispat_run_name() {
    local iteration=$1
    printf '%s\n' "anollm_lr5e-05_standard_smolLM360_iter${iteration}_steps${MAX_STEPS}_eps${EPSILON}_beta${BETA}_f${F_DIVERGENCE}_spin"
}

#######################################
# Print the first anollm_* directory (excluding *_data) directly inside a
# models directory, or nothing when there is none. With several candidates
# the choice follows find's output order.
# Arguments: $1 - models directory
# Outputs:   directory path (no trailing newline) on stdout
#######################################
find_finetuned_model() {
    local models_dir=$1
    local first_match=""
    # Reading one line from a process substitution is the pipefail-safe
    # equivalent of `find ... | head -1`; `|| true` covers the empty result.
    IFS= read -r first_match < <(
        find "$models_dir" -maxdepth 1 -type d -name "anollm_*" ! -name "*_data"
    ) || true
    printf '%s' "$first_match"
}

#######################################
# Run sample generation for one iteration.
# Globals:   TRAIN_GPUS, setting, DATA_DIR, MODEL, n_splits,
#            GPU_MEMORY_UTILIZATION (read)
# Arguments: $1 - dataset, $2 - split index, $3 - base model directory,
#            $4 - output directory, $5 - iteration number
#######################################
generate_samples() {
    local dataset=$1 split_idx=$2 base_model_dir=$3 output_dir=$4 iteration=$5
    CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" python DiSPaT/generate_dispat_samples_vllm.py \
        --dataset "$dataset" \
        --setting "$setting" \
        --data_dir "$DATA_DIR" \
        --base_model_dir "$base_model_dir" \
        --model "$MODEL" \
        --binning standard \
        --n_splits "$n_splits" \
        --split_idx "$split_idx" \
        --output_dir "$output_dir" \
        --iteration "$iteration" \
        --n_target_features 4 \
        --use_normal_generation \
        --generation_temperature 1.0 \
        --gpu_memory_utilization "$GPU_MEMORY_UTILIZATION"
}

#######################################
# Run DiSPaT training for one iteration.
# Globals:   TRAIN_GPUS, setting, DATA_DIR, MODEL, LR, n_splits, BETA, EPSILON,
#            F_DIVERGENCE, MAX_STEPS, BATCH_SIZE, n_permutations,
#            EVAL_BATCH_SIZE (read)
# Arguments: $1 - dataset, $2 - split index, $3 - reference model directory,
#            $4 - generated-sample directory, $5 - experiment directory,
#            $6 - iteration number
#######################################
train_dispat() {
    local dataset=$1 split_idx=$2 reference_model_dir=$3
    local spin_data_dir=$4 exp_dir=$5 iteration=$6
    CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" python DiSPaT/train_dispat.py \
        --dataset "$dataset" \
        --setting "$setting" \
        --data_dir "$DATA_DIR" \
        --reference_model_dir "$reference_model_dir" \
        --spin_data_dir "$spin_data_dir" \
        --exp_dir "$exp_dir" \
        --model "$MODEL" \
        --lr "$LR" \
        --binning standard \
        --n_splits "$n_splits" \
        --split_idx "$split_idx" \
        --beta "$BETA" \
        --epsilon "$EPSILON" \
        --f_divergence_type "$F_DIVERGENCE" \
        --max_steps "$MAX_STEPS" \
        --batch_size "$BATCH_SIZE" \
        --eval_n_permutations "$n_permutations" \
        --eval_batch_size "$EVAL_BATCH_SIZE" \
        --iteration "$iteration"
}

#######################################
# Run generation, training and evaluation for a single dataset.
# Globals:   EXP_BASE_DIR, DISPAT_EXP_BASE_DIR, setting, n_splits,
#            MAX_ITERATIONS, INFERENCE_GPUS, MODEL, EVAL_BATCH_SIZE,
#            n_permutations (read)
# Arguments: $1 - dataset name
# Returns:   0 when the dataset is skipped or completes; a failing child
#            command aborts the script through `set -e`.
#######################################
run_dataset() {
    local dataset=$1
    local split_idx=0
    local exp_dir="$EXP_BASE_DIR/$dataset/$setting/split$n_splits/split$split_idx"
    local finetuned_model_dir="$exp_dir/models"

    echo ""
    print_banner "Running DiSPaT on $dataset"

    if [[ ! -d "$finetuned_model_dir" ]]; then
        echo "Warning: Finetuned model not found at $finetuned_model_dir. Skipping $dataset" >&2
        return 0
    fi

    local finetuned_model
    finetuned_model=$(find_finetuned_model "$finetuned_model_dir")

    if [[ -z "$finetuned_model" || ! -f "$finetuned_model/config.json" ]]; then
        echo "Warning: Valid model not found in $finetuned_model_dir. Skipping $dataset" >&2
        return 0
    fi

    # Every iteration reads and writes the same sample directory, which is
    # named after the iteration-0 run.
    local dispat_root="$DISPAT_EXP_BASE_DIR/360M"
    local dispat_data_dir
    dispat_data_dir="dispat_data/$dataset/split$split_idx/$(dispat_run_name 0)"
    local dispat_exp_dir="$dispat_root/$dataset/split$split_idx"

    echo "DiSPaT Data Dir: $dispat_data_dir"
    echo "DiSPaT Exp Dir: $dispat_exp_dir"

    # Iteration 0 starts from the finetuned model; iteration k > 0 starts from
    # the model written by iteration k-1.
    local base_model_dir="$finetuned_model"
    local iteration
    for ((iteration = 0; iteration <= MAX_ITERATIONS; iteration++)); do
        if ((iteration > 0)); then
            base_model_dir="$dispat_exp_dir/models/$(dispat_run_name "$((iteration - 1))")"
            if [[ ! -f "$base_model_dir/config.json" ]]; then
                # Stop iterating but still fall through to evaluation below.
                echo "Error: Previous iteration model not found at $base_model_dir" >&2
                break
            fi
        fi

        echo "Iteration $iteration: Generating samples..."
        generate_samples "$dataset" "$split_idx" "$base_model_dir" \
            "$dispat_data_dir" "$iteration"

        echo "Iteration $iteration: Training DiSPaT model..."
        train_dispat "$dataset" "$split_idx" "$base_model_dir" \
            "$dispat_data_dir" "$dispat_exp_dir" "$iteration"
    done

    echo "Evaluating DiSPaT model..."
    CUDA_VISIBLE_DEVICES="$INFERENCE_GPUS" python DiSPaT/evaluate_dispat.py \
        --dataset "$dataset" \
        --exp_dir "$dispat_exp_dir" \
        --model "$MODEL" \
        --binning standard \
        --n_splits "$n_splits" \
        --split_idx "$split_idx" \
        --setting "$setting" \
        --batch_size "$EVAL_BATCH_SIZE" \
        --n_permutations "$n_permutations"

    # tee opens the log before get_results.py starts, so the directory has to
    # exist up front.
    local results_dir="$dispat_root/$dataset/$setting/split$n_splits"
    mkdir -p "$results_dir"
    python -u evaluate/get_results.py \
        --dataset "$dataset" \
        --exp_base_dir "$dispat_root" \
        --n_splits "$n_splits" \
        --setting "$setting" | tee "$results_dir/evaluate.log"
}

#######################################
# Entry point: process every dataset in DATASETS in order.
# Globals: DATASETS (read)
#######################################
main() {
    print_banner "Running DiSPaT on 20news datasets"

    local dataset
    for dataset in "${DATASETS[@]}"; do
        run_dataset "$dataset"
    done

    echo "DiSPaT experiments for 20news datasets completed!"
}

main "$@"
