#!/bin/bash

# Experiment parameters (consumed by run_task_training below)
BATCH_SIZE=16
LR=0.0001
PATIENCE=10
EPOCHS=50
FOLD=1                     # Fixed fold 1
GPU=1                      # GPU device number
MATCHED=false
USE_DEMOGRAPHICS=false
USE_TRIPLET=true           # Enable triplet format
CROSS_EVAL=""              # Set to "matched_to_full" or "full_to_matched" if needed

# UMSE-specific parameters
D_MODEL=256                # Model dimension
VARIABLES_NUM=49           # Number of variables (from config)
NUM_LAYERS=1               # Number of transformer layers
NUM_HEADS=4                # Number of attention heads
N_MODALITY=2               # Number of modalities (EHR, CXR)
BOTTLENECKS_N=1            # Number of bottlenecks for MBT
MAX_EHR_LEN=500            # Maximum EHR sequence length
DROPOUT=0.2                # Dropout rate

# Seeds configuration
SEEDS=(42 123 1234)

# Tasks configuration (run_task_training accepts "phenotype" or "mortality")
TASKS=("phenotype")

# Derive the seed count and fold from the configuration above so this banner
# stays accurate when SEEDS or FOLD are edited.
echo "Starting ${#SEEDS[@]}-seed training for UMSE model on fold $FOLD..."

#######################################
# Train the UMSE model once per seed in SEEDS for a single task, then
# aggregate the per-seed results with ../collect_seed_statistics.py.
# Globals:   SEEDS, FOLD, GPU and the hyperparameter / flag variables (read)
# Arguments: $1 - task name: "phenotype" (25 classes) or "mortality" (1 class)
# Outputs:   progress to stdout; errors and warnings to stderr
# Returns:   0 on success. Exits the whole script with status 1 if the task
#            is unknown or any seed's training run fails. A failure of the
#            statistics step only prints a warning.
#######################################
run_task_training() {
    local task=$1
    local num_classes seed
    local -a cmd stats_cmd

    echo "Starting training for task: $task"

    # Task-specific output size. The input width is passed separately via
    # --variables_num ($VARIABLES_NUM); no per-task input dimension is forwarded.
    case "$task" in
        phenotype) num_classes=25 ;;
        mortality) num_classes=1 ;;
        *)
            # Fail fast rather than launching runs with an empty/stale --num_classes.
            echo "Error: unknown task '$task' (expected 'phenotype' or 'mortality')" >&2
            exit 1
            ;;
    esac

    for seed in "${SEEDS[@]}"; do
        echo "Training with seed $seed for task $task on fold $FOLD..."

        # Build the command as an array so each argument is passed verbatim
        # (no eval, no re-splitting of values).
        cmd=(
            python ../main.py
            --model umse
            --mode train
            --task "$task"
            --fold "$FOLD"
            --batch_size "$BATCH_SIZE"
            --lr "$LR"
            --patience "$PATIENCE"
            --epochs "$EPOCHS"
            --dropout "$DROPOUT"
            --seed "$seed"
            --d_model "$D_MODEL"
            --variables_num "$VARIABLES_NUM"
            --num_layers "$NUM_LAYERS"
            --num_heads "$NUM_HEADS"
            --n_modality "$N_MODALITY"
            --bottlenecks_n "$BOTTLENECKS_N"
            --max_ehr_len "$MAX_EHR_LEN"
            --num_classes "$num_classes"
            --gpu "$GPU"
        )

        # Boolean switches are appended only when enabled.
        if [[ "$MATCHED" == "true" ]]; then
            cmd+=(--matched)
        fi
        if [[ "$USE_DEMOGRAPHICS" == "true" ]]; then
            cmd+=(--use_demographics)
        fi
        if [[ "$USE_TRIPLET" == "true" ]]; then
            cmd+=(--use_triplet)
        fi
        if [[ -n "$CROSS_EVAL" ]]; then
            cmd+=(--cross_eval "$CROSS_EVAL")
        fi

        echo "Running command: ${cmd[*]}"
        if "${cmd[@]}"; then
            echo "Seed $seed for task $task completed successfully!"
        else
            echo "Error: Seed $seed for task $task failed!" >&2
            exit 1
        fi
    done

    echo "All ${#SEEDS[@]} seeds completed for task $task on fold $FOLD!"
    echo "Collecting and aggregating statistics..."

    # Unlike the training command, the statistics script receives the boolean
    # settings as explicit "true"/"false" values.
    stats_cmd=(
        python ../collect_seed_statistics.py
        --experiment_dir ../experiments
        --model umse
        --task "$task"
        --fold "$FOLD"
        --seeds "${SEEDS[@]}"
        --batch_size "$BATCH_SIZE"
        --lr "$LR"
        --patience "$PATIENCE"
        --epochs "$EPOCHS"
        --dropout "$DROPOUT"
        --d_model "$D_MODEL"
        --variables_num "$VARIABLES_NUM"
        --num_layers "$NUM_LAYERS"
        --num_heads "$NUM_HEADS"
        --n_modality "$N_MODALITY"
        --bottlenecks_n "$BOTTLENECKS_N"
        --max_ehr_len "$MAX_EHR_LEN"
        --num_classes "$num_classes"
        --matched "$MATCHED"
        --use_demographics "$USE_DEMOGRAPHICS"
        --use_triplet "$USE_TRIPLET"
    )

    echo "Attempting statistics collection for multi-seed experiment..."
    # Best-effort: training already succeeded, so a failure here only warns.
    if "${stats_cmd[@]}"; then
        echo "Statistics collection completed successfully for task $task!"
    else
        echo "Warning: Statistics collection failed for task $task, but training completed." >&2
    fi
}

# Main execution: print the active configuration, then train every task in TASKS.
echo "UMSE Multi-seed Training Script"
echo "==============================="
echo "Configuration:"
echo "  Batch Size: $BATCH_SIZE"
echo "  Learning Rate: $LR"
echo "  Epochs: $EPOCHS"
echo "  Patience: $PATIENCE"
echo "  Fold: $FOLD"
echo "  GPU: $GPU"
echo "  Seeds: ${SEEDS[*]}"
echo "  Tasks: ${TASKS[*]}"
echo "  D Model: $D_MODEL"
echo "  Variables Num: $VARIABLES_NUM"
echo "  Num Layers: $NUM_LAYERS"
echo "  Num Heads: $NUM_HEADS"
echo "  N Modality: $N_MODALITY"
echo "  Bottlenecks N: $BOTTLENECKS_N"
echo "  Max EHR Len: $MAX_EHR_LEN"
echo "  Dropout: $DROPOUT"
echo "  Matched: $MATCHED"
echo "  Use Demographics: $USE_DEMOGRAPHICS"
echo "  Use Triplet: $USE_TRIPLET"
echo "  Cross Eval: $CROSS_EVAL"
echo ""

# Run training for each task. run_task_training exits the script when a
# training run fails, so reaching the summary below means every run succeeded.
for TASK in "${TASKS[@]}"; do
    run_task_training "$TASK"
done

echo ""
echo "All tasks completed successfully!"
# Only mention the triplet format when it was actually enabled.
if [[ "$USE_TRIPLET" == "true" ]]; then
    echo "UMSE training with triplet format completed for all seeds and tasks."
else
    echo "UMSE training completed for all seeds and tasks."
fi