#!/bin/bash

# ---------------------------------------------------------------------------
# Experiment configuration. Every seed in SEEDS is trained on FOLD for every
# task in TASKS (see the loop at the bottom of the file).
# ---------------------------------------------------------------------------

# Define experiment parameters
BATCH_SIZE=16
LR=0.0001
PATIENCE=10
EPOCHS=50
DROPOUT=0.2
FOLD=1                     # Fixed fold 1
GPU=5                      # GPU device number
MATCHED=false              # "true" selects the matched experiment directory and passes --matched
USE_DEMOGRAPHICS=false     # "true" passes --use_demographics
CROSS_EVAL=""              # Set to "matched_to_full" or "full_to_matched" if needed

# MedFuse specific parameters
EHR_ENCODER=lstm           # Options: 'lstm', 'transformer', 'drfuse'
CXR_ENCODER=resnet50       # Options: 'resnet50', 'vit_b_16', 'medfuse_cxr'
DIM=256                    # Feature dimension
LAYERS=1                   # Number of LSTM layers
VISION_BACKBONE=resnet50   # Vision backbone
FUSION_TYPE=lstm           # Fusion type: 'early', 'late', 'uni', 'lstm'
PRETRAINED=true            # Use pretrained models
DRFUSE_ENCODER=false       # Use DrFuse encoder components

# Seeds to train with (one full training run per seed)
SEEDS=(42 123 1234)

# Tasks configuration
TASKS=("los")

# Banner is derived from the settings above so it cannot drift from them.
echo "Starting ${#SEEDS[@]}-seed training for MedFuse model on fold ${FOLD}..."

# Print the experiment log directory for a task.
#
# Arguments:
#   $1 - task name (interpolated verbatim into the path)
#   $2 - matched flag; exactly "true" selects the matched variant, any other
#        value (including empty/missing) selects the full variant
# Outputs:
#   ../experiments_<task>_matched/medfuse  or  ../experiments_<task>_full/medfuse
get_log_dir() {
    local task_name=$1
    local subset=full

    if [[ "$2" == "true" ]]; then
        subset=matched
    fi

    printf '%s\n' "../experiments_${task_name}_${subset}/medfuse"
}

# Train the MedFuse model on one task once per seed in SEEDS.
#
# Globals (read):
#   MATCHED, FOLD, SEEDS, GPU, BATCH_SIZE, LR, PATIENCE, EPOCHS, DROPOUT,
#   EHR_ENCODER, CXR_ENCODER, DIM, LAYERS, VISION_BACKBONE, FUSION_TYPE,
#   PRETRAINED, DRFUSE_ENCODER, USE_DEMOGRAPHICS, CROSS_EVAL
# Arguments:
#   $1 - task name: "phenotype", "mortality" or "los"
# Outputs:
#   Progress messages (including the shell-quoted command) on stdout;
#   error messages on stderr.
# Exits:
#   1 (terminating the whole script) on an unknown task or when any seed's
#   training command returns non-zero.
run_task_training() {
    local TASK=$1
    local LOG_DIR NUM_CLASSES VISION_NUM_CLASSES INPUT_DIM LABELS_SET SEED
    local -a cmd

    echo "Starting training for task: $TASK"

    # Log directory depends on the task and on the global MATCHED flag.
    LOG_DIR=$(get_log_dir "$TASK" "$MATCHED")
    echo "Using log directory: $LOG_DIR"

    # Task-specific output sizes, EHR input dimension and label set.
    case "$TASK" in
        phenotype)
            NUM_CLASSES=25
            VISION_NUM_CLASSES=25
            INPUT_DIM=498
            LABELS_SET=phenotype
            ;;
        mortality)
            NUM_CLASSES=1
            VISION_NUM_CLASSES=1
            INPUT_DIM=498
            LABELS_SET=mortality
            ;;
        los)
            NUM_CLASSES=7
            VISION_NUM_CLASSES=7
            # NOTE(review): 49 here vs 498 for the other tasks - confirm intended.
            INPUT_DIM=49
            LABELS_SET=los
            ;;
        *)
            echo "Error: Unknown task $TASK" >&2
            exit 1
            ;;
    esac

    for SEED in "${SEEDS[@]}"; do
        echo "Training with seed $SEED for task $TASK on fold $FOLD..."

        # Build the command as an argv array so every value (e.g. a log
        # directory containing spaces) reaches main.py as exactly one argument;
        # no string concatenation + eval re-splitting/globbing.
        cmd=(python ../main.py
            --model medfuse
            --mode train
            --task "$TASK"
            --fold "$FOLD"
            --batch_size "$BATCH_SIZE"
            --lr "$LR"
            --patience "$PATIENCE"
            --epochs "$EPOCHS"
            --dropout "$DROPOUT"
            --seed "$SEED"
            --ehr_encoder "$EHR_ENCODER"
            --cxr_encoder "$CXR_ENCODER"
            --dim "$DIM"
            --layers "$LAYERS"
            --input_dim "$INPUT_DIM"
            --num_classes "$NUM_CLASSES"
            --vision_backbone "$VISION_BACKBONE"
            --vision_num_classes "$VISION_NUM_CLASSES"
            --labels_set "$LABELS_SET"
            --fusion_type "$FUSION_TYPE"
            --log_dir "$LOG_DIR"
            --gpu "$GPU")

        # MedFuse specific switches
        if [[ "$PRETRAINED" == "true" ]]; then
            cmd+=(--pretrained)
        fi

        if [[ "$DRFUSE_ENCODER" == "true" ]]; then
            cmd+=(--drfuse_encoder)
        fi

        # Conditional data/evaluation switches
        if [[ "$MATCHED" == "true" ]]; then
            cmd+=(--matched)
        fi

        if [[ "$USE_DEMOGRAPHICS" == "true" ]]; then
            cmd+=(--use_demographics)
        fi

        if [[ -n "$CROSS_EVAL" ]]; then
            cmd+=(--cross_eval "$CROSS_EVAL")
        fi

        # Log the command shell-quoted so it can be copy-pasted and re-run.
        printf 'Running command:'
        printf ' %q' "${cmd[@]}"
        printf '\n'

        if "${cmd[@]}"; then
            echo "Seed $SEED for task $TASK completed successfully!"
        else
            echo "Error: Seed $SEED for task $TASK failed!" >&2
            exit 1
        fi
    done

    echo "All ${#SEEDS[@]} seeds completed for task $TASK on fold $FOLD!"
}

# Run training for each configured task. run_task_training exits the script
# with status 1 on the first unknown task or failed seed, so reaching the
# final messages means every task/seed combination succeeded.
for TASK in "${TASKS[@]}"; do
    run_task_training "$TASK"
done

echo "All tasks completed successfully!"
echo "Training results are saved in the appropriate log directory based on task and matched status."