#!/bin/bash

# ============================================================================
# MMTM (Multimodal Transfer Module) TRAINING SCRIPT
# ============================================================================

# ----------------------------------------------------------------------------
# Model / task selection
# ----------------------------------------------------------------------------
MODEL="mmtm"
TASK="los"                      # one of: phenotype | mortality | los
GPU=3                           # GPU index forwarded to main.py as --gpu

# ----------------------------------------------------------------------------
# Experiment switches ("true" enables; anything else disables)
# ----------------------------------------------------------------------------
PRETRAINED=true                 # "true" adds --pretrained
USE_DEMOGRAPHICS=false          # "true" adds --use_demographics
CROSS_EVAL=""                   # "" = standard; or "matched_to_full" / "full_to_matched"
MATCHED=true                    # "true" = matched cohort (--matched); else full cohort

# ----------------------------------------------------------------------------
# Optimisation settings (values taken from mmtm.yaml)
# ----------------------------------------------------------------------------
BATCH_SIZE=16
EPOCHS=50
PATIENCE=10
LEARNING_RATE=0.0001
DROPOUT=0.2

# One training run is launched per seed listed here.
SEEDS=(42 123 1234)

# ----------------------------------------------------------------------------
# Architecture (values taken from mmtm.yaml)
# ----------------------------------------------------------------------------
DIM=256                         # hidden dimension
INPUT_DIM=49                    # EHR input dimension
NUM_CLASSES=7                   # los default; run_training overrides it per task

EHR_ENCODER="transformer"       # lstm | transformer
CXR_ENCODER="resnet50"          # resnet50 | vit_b_16

# EHR encoder hyper-parameters; only the set matching EHR_ENCODER is passed on.
EHR_NUM_LAYERS=1                # lstm only
EHR_BIDIRECTIONAL=true          # lstm only; "true" adds --ehr_bidirectional
EHR_N_HEAD=4                    # transformer only
EHR_N_LAYERS_TRANS=1            # transformer only (passed as --ehr_n_layers)

# MMTM fusion
MMTM_RATIO=4                    # MMTM compression ratio
LAYER_AFTER=-1                  # layer at which MMTM fusion is applied (-1 = all layers)

# ----------------------------------------------------------------------------
# Data
# ----------------------------------------------------------------------------
DATA_PAIRS="paired_ehr_cxr"

# ----------------------------------------------------------------------------
# Loss / alignment (values taken from mmtm.yaml)
# ----------------------------------------------------------------------------
ALIGN=0.0                       # 0.0 disables the alignment loss
USE_LABEL_WEIGHTS=false         # "true" adds --use_label_weights + --label_weight_method
LABEL_WEIGHT_METHOD="balanced"  # balanced | inverse | sqrt_inverse | log_inverse | custom

#######################################
# Build the directory passed to main.py via --log_dir.
# The path is relative, so it resolves against main.py's working directory
# (run_training changes into BASE_DIR before launching main.py).
# Arguments:
#   $1 - task name (e.g. los, mortality, phenotype)
#   $2 - "true" for the matched cohort; any other value selects "full"
#   $3 - optional model name (default: mmtm, the previously hard-coded value)
# Outputs:
#   Writes the log directory path to stdout.
#######################################
get_log_dir() {
    # Lowercase locals avoid shadowing the script-level TASK/MATCHED/MODEL.
    local task=$1
    local matched=$2
    local model=${3:-mmtm}
    local cohort="full"

    if [ "$matched" = "true" ]; then
        cohort="matched"
    fi

    echo "../experiments_${task}_${cohort}/${model}"
}

# ============================================================================
# SCRIPT IMPLEMENTATION - GENERALLY NO NEED TO MODIFY BELOW THIS LINE
# ============================================================================

#######################################
# Compose the per-configuration directory that holds this script's own
# logs (the training_*.log file and each run's seed*/output.log).
# Globals:
#   BASE_DIR (read) - must be set before this function is called
# Arguments:
#   $1 model   $2 task   $3 use_demographics ("true" => demo)
#   $4 cross_eval (empty => "standard")
#   $5 matched ("true" => matched)   $6 pretrained ("true" => pretrained)
# Outputs:
#   Writes the directory path to stdout.
#######################################
generate_results_dir() {
    local model=$1 task=$2 use_demographics=$3
    local cross_eval=$4 matched=$5 pretrained=$6
    local demo_tag cohort_tag weights_tag eval_tag

    case "$use_demographics" in
        true) demo_tag="demo" ;;
        *)    demo_tag="no_demo" ;;
    esac

    case "$matched" in
        true) cohort_tag="matched" ;;
        *)    cohort_tag="full" ;;
    esac

    case "$pretrained" in
        true) weights_tag="pretrained" ;;
        *)    weights_tag="no_pretrained" ;;
    esac

    # An empty cross-eval mode means the standard (non-cross) evaluation.
    eval_tag=${cross_eval:-standard}

    local leaf="${model}_${task}-${demo_tag}-${eval_tag}-${cohort_tag}-${weights_tag}"
    printf '%s\n' "${BASE_DIR}/../experiments/${model}/${task}/lightning_logs/${leaf}"
}

# Absolute directory containing this script. All derived paths are anchored
# here; if it cannot be resolved, BASE_DIR would be empty and RESULTS_DIR
# would silently point under the filesystem root, so abort instead.
BASE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" || {
    echo "Error: cannot resolve script directory" >&2
    exit 1
}
RESULTS_DIR=$(generate_results_dir "$MODEL" "$TASK" "$USE_DEMOGRAPHICS" "$CROSS_EVAL" "$MATCHED" "$PRETRAINED")
LOG_FILE="${RESULTS_DIR}/training_$(date +%Y%m%d_%H%M%S).log"

# Create results directory. log() appends to LOG_FILE inside it, so a failure
# here would make every later log line fail to persist; fail fast instead.
if ! mkdir -p "$RESULTS_DIR"; then
    echo "Error: cannot create results directory: $RESULTS_DIR" >&2
    exit 1
fi

#######################################
# Print a timestamped message to stdout and append the same line to LOG_FILE.
# Globals:
#   LOG_FILE (read)
# Arguments:
#   $1 - message text (only the first argument is logged)
# Outputs:
#   "[YYYY-mm-dd HH:MM:SS] <message>" on stdout and in LOG_FILE.
#######################################
log() {
    local stamp
    stamp=$(date '+%Y-%m-%d %H:%M:%S')
    printf '%s\n' "[${stamp}] $1" | tee -a "$LOG_FILE"
}

#######################################
# Launch one MMTM training run of ../main.py for a given seed and fold.
# The run's stdout/stderr are written to
# ${RESULTS_DIR}/seed<seed>_fold<fold>/output.log.
# Globals (read):
#   MODEL TASK GPU BATCH_SIZE EPOCHS PATIENCE LEARNING_RATE DROPOUT INPUT_DIM
#   NUM_CLASSES DIM EHR_ENCODER CXR_ENCODER DATA_PAIRS MMTM_RATIO LAYER_AFTER
#   ALIGN EHR_NUM_LAYERS EHR_BIDIRECTIONAL EHR_N_HEAD EHR_N_LAYERS_TRANS
#   USE_LABEL_WEIGHTS LABEL_WEIGHT_METHOD PRETRAINED MATCHED USE_DEMOGRAPHICS
#   CROSS_EVAL RESULTS_DIR BASE_DIR
# Arguments:
#   $1 - random seed
#   $2 - fold number
# Returns:
#   0 if the training command exited with status 0, 1 otherwise.
#######################################
run_training() {
    local seed=$1
    local fold=$2

    log "Starting MMTM training - Seed: $seed, Fold: $fold"

    # Output size per task; unrecognised tasks fall back to NUM_CLASSES.
    local num_classes
    case "$TASK" in
        mortality) num_classes=1 ;;
        phenotype) num_classes=25 ;;
        los)       num_classes=7 ;;
        *)         num_classes=$NUM_CLASSES ;;
    esac

    # Determine log directory based on task and matched status. Arguments are
    # quoted so an empty TASK cannot shift MATCHED into the task position.
    local log_dir
    log_dir=$(get_log_dir "$TASK" "$MATCHED" "$MODEL")
    log "Using log directory: $log_dir"

    # Build training command based on mmtm.yaml parameters
    local cmd=(
        "python" "../main.py"
        "--model" "$MODEL"
        "--mode" "train"
        "--task" "$TASK"
        "--fold" "$fold"
        "--gpu" "$GPU"
        "--batch_size" "$BATCH_SIZE"
        "--epochs" "$EPOCHS"
        "--patience" "$PATIENCE"
        "--lr" "$LEARNING_RATE"
        "--dropout" "$DROPOUT"
        "--input_dim" "$INPUT_DIM"
        "--num_classes" "$num_classes"
        "--dim" "$DIM"
        "--ehr_encoder" "$EHR_ENCODER"
        "--cxr_encoder" "$CXR_ENCODER"
        "--data_pairs" "$DATA_PAIRS"
        "--mmtm_ratio" "$MMTM_RATIO"
        "--layer_after" "$LAYER_AFTER"
        "--align" "$ALIGN"
        "--seed" "$seed"
        "--log_dir" "$log_dir"
    )

    # Encoder-specific hyper-parameters; boolean options are bare flags.
    case "$EHR_ENCODER" in
        lstm)
            cmd+=("--ehr_num_layers" "$EHR_NUM_LAYERS")
            if [ "$EHR_BIDIRECTIONAL" = "true" ]; then
                cmd+=("--ehr_bidirectional")
            fi
            ;;
        transformer)
            cmd+=("--ehr_n_head" "$EHR_N_HEAD" "--ehr_n_layers" "$EHR_N_LAYERS_TRANS")
            ;;
    esac

    # Label-weight options are only passed when label weighting is enabled.
    if [ "$USE_LABEL_WEIGHTS" = "true" ]; then
        cmd+=("--use_label_weights" "--label_weight_method" "$LABEL_WEIGHT_METHOD")
    fi

    # Conditional boolean switches (passed as bare flags).
    [ "$PRETRAINED" = "true" ] && cmd+=("--pretrained")
    [ "$MATCHED" = "true" ] && cmd+=("--matched")
    [ "$USE_DEMOGRAPHICS" = "true" ] && cmd+=("--use_demographics")
    [ -n "$CROSS_EVAL" ] && cmd+=("--cross_eval" "$CROSS_EVAL")

    # Create experiment directory for this seed
    local exp_dir="${RESULTS_DIR}/seed${seed}_fold${fold}"
    mkdir -p "$exp_dir"

    # Shell-quote the command so the logged line can be re-run verbatim.
    local cmd_str
    printf -v cmd_str '%q ' "${cmd[@]}"
    log "Training command: ${cmd_str% }"
    log "Results will be saved to: $exp_dir"

    local start_time end_time duration status=0
    start_time=$(date +%s)
    # Run inside a subshell: ../main.py and the relative log dir need cwd to be
    # BASE_DIR, but that cd must not leak into the rest of the script.
    ( cd "$BASE_DIR" && "${cmd[@]}" ) > "$exp_dir/output.log" 2>&1 || status=$?
    end_time=$(date +%s)
    duration=$((end_time - start_time))

    if [ "$status" -ne 0 ]; then
        log "Training failed after ${duration}s"
        log "Check logs at: $exp_dir/output.log"
        return 1
    fi

    log "Training completed successfully in ${duration}s"

    # Extract and log the last four final-metric lines from the run output.
    if [ -f "$exp_dir/output.log" ]; then
        log "Final training metrics:"
        grep -E "(overall/PRAUC|overall/AUROC|overall/ACC|overall/F1)" "$exp_dir/output.log" | tail -4 | while read -r line; do
            log "  $line"
        done
    fi
    return 0
}

#######################################
# Entry point: log the configuration, run training for every seed/fold pair,
# then log a summary of this invocation's results.
# Globals (read):
#   SEEDS RESULTS_DIR LOG_FILE and the configuration variables echoed below.
# Arguments:
#   None used.
#######################################
main() {
    log "Starting MMTM Training"
    log "Configuration: MODEL=$MODEL, TASK=$TASK, USE_DEMOGRAPHICS=$USE_DEMOGRAPHICS, CROSS_EVAL=$CROSS_EVAL, PRETRAINED=$PRETRAINED"
    log "Results will be saved to: $RESULTS_DIR"
    log "Log file: $LOG_FILE"
    log "Training parameters: BATCH_SIZE=$BATCH_SIZE, EPOCHS=$EPOCHS, LR=$LEARNING_RATE, DROPOUT=$DROPOUT"
    log "Architecture: DIM=$DIM, INPUT_DIM=$INPUT_DIM, NUM_CLASSES=$NUM_CLASSES"
    log "Encoders: EHR=$EHR_ENCODER, CXR=$CXR_ENCODER"
    log "MMTM Fusion: RATIO=$MMTM_RATIO, LAYER_AFTER=$LAYER_AFTER"
    # Log the encoder hyper-parameters that are actually passed to main.py.
    if [ "$EHR_ENCODER" = "transformer" ]; then
        log "EHR Encoder: N_HEAD=$EHR_N_HEAD, N_LAYERS=$EHR_N_LAYERS_TRANS"
    else
        log "EHR Encoder: NUM_LAYERS=$EHR_NUM_LAYERS, BIDIRECTIONAL=$EHR_BIDIRECTIONAL"
    fi

    # Folds to run for every seed; extend this list to add folds. The total
    # is derived from it so progress reporting stays correct.
    local -a folds=(1)
    local total_runs=$(( ${#SEEDS[@]} * ${#folds[@]} ))
    local current_run=0 succeeded=0 failed=0
    local seed fold

    for seed in "${SEEDS[@]}"; do
        for fold in "${folds[@]}"; do
            current_run=$((current_run + 1))
            log "Progress: $current_run/$total_runs - Seed: $seed, Fold: $fold"

            if run_training "$seed" "$fold"; then
                succeeded=$((succeeded + 1))
            else
                failed=$((failed + 1))
                log "Training failed for seed $seed, fold $fold"
                # Continue with the remaining seeds/folds.
            fi

            log "Completed run $current_run/$total_runs"
            echo "----------------------------------------"
        done
    done

    log "All MMTM training runs completed!"
    log "Results saved to: $RESULTS_DIR"

    # Counts come from run_training's exit status for this invocation only;
    # scanning RESULTS_DIR would also pick up output.log files left by
    # earlier invocations (e.g. other seeds).
    log "=== TRAINING SUMMARY ==="
    log "Total runs: $total_runs"
    log "Successful runs: $succeeded"
    log "Failed runs: $failed"
}

#######################################
# Interrupt handler (installed for INT/TERM by the trap below): record the
# interruption in the training log, then stop the script with status 1.
#######################################
cleanup() {
    local reason="MMTM training interrupted by user"
    log "$reason"
    exit 1
}

# Route Ctrl-C and termination requests through cleanup so the interruption
# is recorded in the training log before exiting.
trap cleanup INT TERM

# Entry point: hand all command-line arguments to main.
main "$@"
