#!/bin/bash

# Banner printed once, before any configuration is set up.
echo "Starting 3-seed training for SMIL model on fold 1..."

# ---------------------------------------------------------------------------
# Run selection
# ---------------------------------------------------------------------------
TASKS=("mortality")                  # Tasks configuration
SEEDS=(42 123 1234)                  # Seeds configuration (including 42)
FOLD=1                               # Fixed fold 1
GPU=2                                # GPU device number

# ---------------------------------------------------------------------------
# General experiment parameters
# ---------------------------------------------------------------------------
EPOCHS=50
PATIENCE=10
BATCH_SIZE=16
LR=0.0001
DROPOUT=0.2

# Boolean-style switches, compared as the strings "true"/"false" later on.
PRETRAINED="true"
MATCHED="false"
USE_DEMOGRAPHICS="false"

# Leave empty to disable; set to "matched_to_full" or "full_to_matched" if needed.
CROSS_EVAL=

# ---------------------------------------------------------------------------
# SMIL-specific parameters
# ---------------------------------------------------------------------------
CXR_ENCODER="resnet50"               # Options: "resnet50", "vit_b_16"
EHR_ENCODER="transformer"            # Options: "lstm", "transformer"
HIDDEN_DIM=256                       # Hidden dimension for SMIL
N_CLUSTERS=10                        # Number of clusters for CXR k-means
MC_SIZE=10                           # Monte Carlo size
INNER_LOOP=2                         # Number of inner loops for meta-learning
LR_INNER=0.0005106366481617694       # Inner learning rate
TEMPERATURE=2.1659838411477903       # Knowledge distillation temperature
ALPHA=0.1460970829212307             # Feature distillation weight
BETA=0.1829896904844842              # EHR mean distillation weight

# ---------------------------------------------------------------------------
# EHR encoder parameters
# ---------------------------------------------------------------------------
# Used when EHR_ENCODER="lstm"
EHR_BIDIRECTIONAL="true"
EHR_NUM_LAYERS=1

# Used when EHR_ENCODER="transformer"
MAX_LEN=500
EHR_N_LAYERS=1
EHR_N_HEAD=4

#######################################
# Sets the task-dependent model dimensions.
# Globals:   NUM_CLASSES, INPUT_DIM (written)
# Arguments: $1 - task name: "phenotype", "mortality" or "los"
# Returns:   0 on success; 1 if the task is not recognised, in which case
#            the globals are left untouched.
#######################################
set_task_params() {
    case "$1" in
        phenotype)
            NUM_CLASSES=25
            INPUT_DIM=49          # Note: SMIL uses 49 for this config
            ;;
        mortality)
            NUM_CLASSES=1
            INPUT_DIM=49
            ;;
        los)
            NUM_CLASSES=10
            INPUT_DIM=49
            ;;
        *)
            return 1
            ;;
    esac
}

#######################################
# Assembles the training invocation as an argv array, so that every
# configuration value reaches python as exactly one argument and is never
# re-parsed by the shell (which is what the former `eval` of a string did).
# Globals:   the configuration variables set at the top of this script plus
#            NUM_CLASSES / INPUT_DIM (read); TRAIN_CMD (written, array)
# Arguments: $1 - task name
#            $2 - random seed
#######################################
build_train_cmd() {
    local task=$1
    local seed=$2

    TRAIN_CMD=(
        python ../main.py
        --model smil
        --mode train
        --task "$task"
        --fold "$FOLD"
        --batch_size "$BATCH_SIZE"
        --lr "$LR"
        --patience "$PATIENCE"
        --epochs "$EPOCHS"
        --dropout "$DROPOUT"
        --seed "$seed"
        --hidden_dim "$HIDDEN_DIM"
        --input_dim "$INPUT_DIM"
        --num_classes "$NUM_CLASSES"
        --ehr_encoder "$EHR_ENCODER"
        --cxr_encoder "$CXR_ENCODER"
        --inner_loop "$INNER_LOOP"
        --lr_inner "$LR_INNER"
        --mc_size "$MC_SIZE"
        --alpha "$ALPHA"
        --beta "$BETA"
        --temperature "$TEMPERATURE"
        --n_clusters "$N_CLUSTERS"
        --gpu "$GPU"
    )

    # Add encoder-specific parameters
    if [[ "$EHR_ENCODER" == "lstm" ]]; then
        TRAIN_CMD+=(--ehr_num_layers "$EHR_NUM_LAYERS")
        if [[ "$EHR_BIDIRECTIONAL" == "true" ]]; then
            TRAIN_CMD+=(--ehr_bidirectional)
        fi
    elif [[ "$EHR_ENCODER" == "transformer" ]]; then
        TRAIN_CMD+=(
            --ehr_n_head "$EHR_N_HEAD"
            --ehr_n_layers "$EHR_N_LAYERS"
            --max_len "$MAX_LEN"
        )
    fi

    # Add conditional parameters
    if [[ "$PRETRAINED" == "true" ]]; then
        TRAIN_CMD+=(--pretrained)
    fi

    if [[ "$MATCHED" == "true" ]]; then
        TRAIN_CMD+=(--matched)
    fi

    if [[ "$USE_DEMOGRAPHICS" == "true" ]]; then
        TRAIN_CMD+=(--use_demographics)
    fi

    if [[ -n "$CROSS_EVAL" ]]; then
        TRAIN_CMD+=(--cross_eval "$CROSS_EVAL")
    fi
}

# The data-type label depends only on MATCHED, so compute it once up front.
DATA_TYPE="matched"
if [[ "$MATCHED" == "false" ]]; then
    DATA_TYPE="full"
fi

# Train every configured task with every configured seed; stop at the first
# failing run with exit status 1.
for TASK in "${TASKS[@]}"
do
    echo "Training SMIL model for task: $TASK"

    # Set task-specific parameters. Fail fast on an unknown task instead of
    # launching python with missing or stale --num_classes / --input_dim.
    if ! set_task_params "$TASK"; then
        echo "Error: Unknown task '$TASK' (expected phenotype, mortality or los)" >&2
        exit 1
    fi

    # Check if CXR k-means centers are available (advisory only: a missing
    # file produces a warning and training is still attempted).
    echo "Checking CXR k-means centers availability for fold $FOLD..."
    CXR_MEAN_FILE="../models/smil/cxr_mean/cxr_mean_fold${FOLD}_${DATA_TYPE}_${CXR_ENCODER}_${N_CLUSTERS}clusters.npy"
    if [[ ! -f "$CXR_MEAN_FILE" ]]; then
        echo "Warning: CXR k-means file not found: $CXR_MEAN_FILE"
        echo "Please run CXR k-means computation first:"
        echo "cd ../models/smil && ./compute_cxr_kmeans.sh --task $TASK --folds $FOLD --data_type $DATA_TYPE --cxr_encoder $CXR_ENCODER --n_clusters $N_CLUSTERS --gpu $GPU"
        echo "Continuing without k-means check (assuming file will be created during training)..."
    else
        echo "CXR k-means centers are available!"
    fi

    for SEED in "${SEEDS[@]}"
    do
        echo "Training with seed $SEED for task $TASK on fold $FOLD..."

        build_train_cmd "$TASK" "$SEED"

        echo "Running command: ${TRAIN_CMD[*]}"
        # Run the array directly and branch on its exit status.
        if "${TRAIN_CMD[@]}"; then
            echo "Seed $SEED for task $TASK completed successfully!"
        else
            echo "Error: Seed $SEED for task $TASK failed!" >&2
            exit 1
        fi
    done

    # Report the real seed count rather than a hard-coded "3".
    echo "All ${#SEEDS[@]} seeds completed for task $TASK on fold $FOLD!"
    echo "Task $TASK training and evaluation completed for fold $FOLD with ${#SEEDS[@]} seeds!"
    echo "=================================="
done

# Closing summary: one printf emits the three status lines, one per argument.
printf '%s\n' \
    "All tasks completed successfully!" \
    "Results can be found in experiments/smil/[task_name]/" \
    "Each seed experiment creates a separate checkpoint and log directory"