#!/bin/bash

# UTDE multi-seed training configuration.
# Every task in TASKS is trained once per seed in SEEDS on the fixed FOLD.

# Define experiment parameters
BATCH_SIZE=16
LR=0.0001
PATIENCE=10
EPOCHS=50
FOLD=1                     # Fixed fold 1
GPU=1                      # GPU device number
MATCHED=true
USE_DEMOGRAPHICS=false
CROSS_EVAL=""              # Set to "matched_to_full" or "full_to_matched" if needed

# UTDE specific parameters
EHR_ENCODER=transformer    # Options: 'lstm', 'transformer'
CXR_ENCODER=resnet50       # Options: 'resnet50', 'vit_b_16'
EMBED_DIM=256              # Hidden dimension for both EHR and CXR
EMBED_TIME=64              # Time embedding dimension
NUM_HEADS=4                # Multi-head attention heads
TT_MAX=500                 # Maximum time steps
CROSS_LAYERS=1             # Cross-modal transformer layers
DROPOUT=0.2                # Dropout rate
PRETRAINED=true            # Use pretrained CXR encoder

# EHR Encoder specific parameters
EHR_NUM_LAYERS=1           # Passed to main.py only when EHR_ENCODER=lstm
EHR_BIDIRECTIONAL=true     # For LSTM
EHR_HIDDEN_DIM=256         # Passed to main.py only when EHR_ENCODER=lstm
EHR_N_HEAD=4               # For Transformer

# Seeds configuration
SEEDS=(42 123 1234)

# Tasks configuration
TASKS=("mortality")

# The banner is derived from SEEDS/FOLD so it cannot drift from the
# configuration above (the seed count is not hard-coded).
echo "Starting ${#SEEDS[@]}-seed training for UTDE model on fold $FOLD..."

# Function to run training for a specific task
#######################################
# Train the UTDE model on one task for every seed in SEEDS, then try to
# aggregate the per-seed results with ../collect_seed_statistics.py.
# Globals (read):
#   SEEDS, FOLD, GPU, BATCH_SIZE, LR, PATIENCE, EPOCHS, EHR_ENCODER,
#   CXR_ENCODER, EMBED_DIM, EMBED_TIME, NUM_HEADS, TT_MAX, CROSS_LAYERS,
#   DROPOUT, PRETRAINED, MATCHED, USE_DEMOGRAPHICS, CROSS_EVAL,
#   EHR_NUM_LAYERS, EHR_BIDIRECTIONAL, EHR_HIDDEN_DIM, EHR_N_HEAD
# Arguments:
#   $1 - task name: "phenotype" or "mortality"
# Outputs:
#   Progress messages on stdout; errors and warnings on stderr.
# Returns:
#   0 once every seed has trained (statistics collection is best-effort).
#   Exits the whole script with status 1 on an unknown task or a failed seed.
#######################################
run_task_training() {
    local TASK=$1
    local NUM_CLASSES INPUT_DIM SEED
    local -a cmd stats_cmd

    echo "Starting training for task: $TASK"

    # Set task-specific parameters. Unknown tasks are rejected up front;
    # otherwise --input_dim/--num_classes would be passed with no value.
    case "$TASK" in
        phenotype)
            NUM_CLASSES=25
            INPUT_DIM=49
            ;;
        mortality)
            NUM_CLASSES=1
            INPUT_DIM=49
            ;;
        *)
            echo "Error: unknown task '$TASK' (expected 'phenotype' or 'mortality')" >&2
            exit 1
            ;;
    esac

    for SEED in "${SEEDS[@]}"; do
        echo "Training with seed $SEED for task $TASK on fold $FOLD..."

        # Build the command as an array: every value stays exactly one
        # argument and nothing is re-parsed by the shell (no eval).
        cmd=(
            python ../main.py
            --model utde
            --mode train
            --task "$TASK"
            --fold "$FOLD"
            --batch_size "$BATCH_SIZE"
            --lr "$LR"
            --patience "$PATIENCE"
            --epochs "$EPOCHS"
            --seed "$SEED"
            --ehr_encoder "$EHR_ENCODER"
            --cxr_encoder "$CXR_ENCODER"
            --embed_dim "$EMBED_DIM"
            --embed_time "$EMBED_TIME"
            --num_heads "$NUM_HEADS"
            --tt_max "$TT_MAX"
            --cross_layers "$CROSS_LAYERS"
            --dropout "$DROPOUT"
            --input_dim "$INPUT_DIM"
            --num_classes "$NUM_CLASSES"
            --gpu "$GPU"
        )

        # Add EHR encoder specific parameters
        if [[ "$EHR_ENCODER" == "lstm" ]]; then
            cmd+=(--ehr_num_layers "$EHR_NUM_LAYERS" --ehr_hidden_dim "$EHR_HIDDEN_DIM")
            if [[ "$EHR_BIDIRECTIONAL" == "true" ]]; then
                cmd+=(--ehr_bidirectional)
            fi
        elif [[ "$EHR_ENCODER" == "transformer" ]]; then
            cmd+=(--ehr_n_head "$EHR_N_HEAD")
        fi

        # Add UTDE specific parameters
        if [[ "$PRETRAINED" == "true" ]]; then
            cmd+=(--pretrained)
        fi

        # Add conditional parameters (boolean switches only when enabled)
        if [[ "$MATCHED" == "true" ]]; then
            cmd+=(--matched)
        fi

        if [[ "$USE_DEMOGRAPHICS" == "true" ]]; then
            cmd+=(--use_demographics)
        fi

        if [[ -n "$CROSS_EVAL" ]]; then
            cmd+=(--cross_eval "$CROSS_EVAL")
        fi

        # Log a shell-quoted, copy-pasteable form of the command
        printf 'Running command:'
        printf ' %q' "${cmd[@]}"
        printf '\n'

        if "${cmd[@]}"; then
            echo "Seed $SEED for task $TASK completed successfully!"
        else
            echo "Error: Seed $SEED for task $TASK failed!" >&2
            exit 1
        fi
    done

    echo "All ${#SEEDS[@]} seeds completed for task $TASK on fold $FOLD!"
    echo "Collecting and aggregating statistics..."

    # Collect statistics for multi-seed experiment. This is best-effort:
    # a failure only warns, because the training runs already succeeded.
    echo "Attempting statistics collection for multi-seed experiment..."
    stats_cmd=(
        python ../collect_seed_statistics.py
        --experiment_dir ../experiments
        --model utde
        --task "$TASK"
        --fold "$FOLD"
        --seeds "${SEEDS[@]}"
        --batch_size "$BATCH_SIZE"
        --lr "$LR"
        --patience "$PATIENCE"
        --epochs "$EPOCHS"
        --ehr_encoder "$EHR_ENCODER"
        --cxr_encoder "$CXR_ENCODER"
        --embed_dim "$EMBED_DIM"
        --embed_time "$EMBED_TIME"
        --num_heads "$NUM_HEADS"
        --tt_max "$TT_MAX"
        --cross_layers "$CROSS_LAYERS"
        --dropout "$DROPOUT"
        --input_dim "$INPUT_DIM"
        --num_classes "$NUM_CLASSES"
        --pretrained "$PRETRAINED"
        --matched "$MATCHED"
        --use_demographics "$USE_DEMOGRAPHICS"
    )

    if "${stats_cmd[@]}"; then
        echo "Statistics collection completed successfully for task $TASK!"
    else
        echo "Warning: Statistics collection failed for task $TASK, but training completed." >&2
    fi
}

# Main execution: print the effective configuration, then train every task
# in TASKS. run_task_training exits the script on the first failed seed, so
# reaching the final messages means every task's training runs succeeded.
echo "UTDE Multi-seed Training Script"
echo "==============================="
echo "Configuration:"
echo "  Batch Size: $BATCH_SIZE"
echo "  Learning Rate: $LR"
echo "  Epochs: $EPOCHS"
echo "  Patience: $PATIENCE"
echo "  Fold: $FOLD"
echo "  GPU: $GPU"
echo "  Seeds: ${SEEDS[*]}"
echo "  Tasks: ${TASKS[*]}"
echo "  EHR Encoder: $EHR_ENCODER"
echo "  CXR Encoder: $CXR_ENCODER"
echo "  Embed Dim: $EMBED_DIM"
echo "  Embed Time: $EMBED_TIME"
echo "  Num Heads: $NUM_HEADS"
echo "  TT Max: $TT_MAX"
echo "  Cross Layers: $CROSS_LAYERS"
echo "  Dropout: $DROPOUT"
echo "  Pretrained: $PRETRAINED"
echo "  Matched: $MATCHED"
echo "  Use Demographics: $USE_DEMOGRAPHICS"
if [[ -n "$CROSS_EVAL" ]]; then
    echo "  Cross Eval: $CROSS_EVAL"
fi
echo "==============================="

# Run training for each task (quoted so a task name is always one argument)
for TASK in "${TASKS[@]}"; do
    run_task_training "$TASK"
done

echo "All tasks completed successfully!"
echo "Training results are saved in the experiments directory."