#!/bin/bash

printf '%s\n' "Starting MMTM Model Testing with Fairness Evaluation..."

# ---------------------------------------------------------------------------
# Experiment parameters -- keep these in sync with the training run
# ---------------------------------------------------------------------------
BATCH_SIZE=16
LR=0.0001
PATIENCE=10
EPOCHS=50
FOLD=1            # evaluation is pinned to fold 1
GPU=0             # GPU device number
PRETRAINED=true

# ---------------------------------------------------------------------------
# Test-time switches
# ---------------------------------------------------------------------------
USE_DEMOGRAPHICS=true   # MUST be true for fairness evaluation
CROSS_EVAL=""           # empty = normal test; set a cross-evaluation mode as needed

# ---------------------------------------------------------------------------
# Fairness evaluation
# ---------------------------------------------------------------------------
COMPUTE_FAIRNESS=true                                          # compute fairness metrics
FAIRNESS_ATTRIBUTES="race gender age admission_type has_cxr"   # sensitive attributes to evaluate
FAIRNESS_AGE_BINS="0 20 40 60 80"                              # age bins: 0-20, 20-40, 40-60, 60-80, 80+
FAIRNESS_INTERSECTIONAL=true                                   # also run intersectional analysis

# ---------------------------------------------------------------------------
# Prediction dumps
# ---------------------------------------------------------------------------
SAVE_PREDICTIONS=true   # store test predictions and labels in the experiment directory

# Echo the effective fairness / prediction settings so they end up in the log.
printf '%s\n' \
    "Fairness Evaluation Configuration:" \
    "   - Compute Fairness: $COMPUTE_FAIRNESS" \
    "   - Sensitive Attributes: $FAIRNESS_ATTRIBUTES" \
    "   - Age Bins: $FAIRNESS_AGE_BINS" \
    "   - Intersectional Analysis: $FAIRNESS_INTERSECTIONAL" \
    "   - Demographics Enabled: $USE_DEMOGRAPHICS"

printf '%s\n' \
    "Predictions Saving Configuration:" \
    "   - Save Predictions: $SAVE_PREDICTIONS" \
    "   - Save Directory: Experiment directory (automatic)"

# ---------------------------------------------------------------------------
# MMTM architecture parameters -- should match the training run
# ---------------------------------------------------------------------------
EHR_ENCODER=transformer   # one of: 'lstm', 'transformer'
CXR_ENCODER=resnet50      # one of: 'resnet50', 'vit_b_16'
DIM=256                   # hidden dimension
EHR_NUM_LAYERS=1          # number of LSTM layers
EHR_BIDIRECTIONAL=false   # keep false for MMTM compatibility
EHR_N_HEAD=4              # attention heads for the transformer encoder
EHR_N_LAYERS=1            # number of transformer layers
MMTM_RATIO=4              # MMTM compression ratio
LAYER_AFTER=-1            # layer at which MMTM fusion is applied (-1 = all layers)
DROPOUT=0.2               # dropout rate
ALIGN=0.0                 # alignment loss weight

# Seeds to evaluate (should match the training seeds)
SEEDS=(42 123 1234)

# Tasks to evaluate
TASKS=("los")
# Locate the checkpoint to evaluate inside an MMTM experiment directory.
#
# Arguments:
#   $1 - experiment directory to search (find descends into sub-directories)
# Outputs:
#   stdout: the selected checkpoint path, or an empty line when none is found
#   stderr: progress and diagnostic messages. They are kept off stdout so that
#           callers using $(find_checkpoint ...) capture only the path.
# Returns:
#   0 in every case; callers detect failure by testing the printed path.
#
# Selection order:
#   1. the best_model*.ckpt whose file name carries the highest
#      "prauc_<float>" value (only values strictly greater than 0 qualify),
#   2. otherwise the lexicographically last best_model*.ckpt,
#   3. otherwise the first *.ckpt reported by find.
find_checkpoint() {
    local base_dir="$1"
    local -a checkpoint_files=()
    local file prauc
    local best_prauc=0
    local best_checkpoint=""
    local fallback_checkpoint checkpoint_file

    echo "Searching for checkpoint in: $base_dir" >&2

    # Bail out early when the directory is missing.
    if [ ! -d "$base_dir" ]; then
        echo "Error: Directory does not exist: $base_dir" >&2
        echo ""
        return
    fi

    # MMTM checkpoints are stored directly in the experiment directory.
    echo "Searching directly in base directory" >&2

    # Collect candidates NUL-delimited so paths containing whitespace survive.
    while IFS= read -r -d '' file; do
        checkpoint_files+=("$file")
    done < <(find "$base_dir" -name "best_model*.ckpt" -type f -print0 2>/dev/null)

    if [ ${#checkpoint_files[@]} -gt 0 ]; then
        echo "Found ${#checkpoint_files[@]} checkpoint files" >&2

        # Pick the file with the highest PRAUC encoded in its name
        # (format: best_model_epoch_XX_prauc_X.XXXX.ckpt).
        for file in "${checkpoint_files[@]}"; do
            if [[ "$file" =~ prauc_([0-9]+\.[0-9]+)\.ckpt ]]; then
                prauc="${BASH_REMATCH[1]}"
                echo "Found checkpoint with PRAUC: $prauc - $file" >&2

                # Bash has no floating-point arithmetic; let awk compare.
                # Values are passed with -v rather than spliced into the program.
                if awk -v candidate="$prauc" -v best="$best_prauc" \
                    'BEGIN { exit !((candidate + 0) > (best + 0)) }'; then
                    best_prauc="$prauc"
                    best_checkpoint="$file"
                fi
            fi
        done

        if [ -n "$best_checkpoint" ]; then
            echo "Selected best checkpoint with PRAUC $best_prauc: $best_checkpoint" >&2
            echo "$best_checkpoint"
            return
        fi
    fi

    # No usable PRAUC in the names: fall back to a plain lexical sort and take
    # the last entry.
    fallback_checkpoint=$(find "$base_dir" -name "best_model*.ckpt" -type f 2>/dev/null | sort | tail -n 1)
    if [ -n "$fallback_checkpoint" ] && [ -f "$fallback_checkpoint" ]; then
        echo "Found checkpoint (fallback sort): $fallback_checkpoint" >&2
        echo "$fallback_checkpoint"
        return
    fi

    # No best_model file at all: accept any .ckpt file.
    checkpoint_file=$(find "$base_dir" -name "*.ckpt" -type f 2>/dev/null | head -n 1)
    if [ -n "$checkpoint_file" ] && [ -f "$checkpoint_file" ]; then
        echo "Found checkpoint in base dir: $checkpoint_file" >&2
        echo "$checkpoint_file"
        return
    fi

    # Nothing found: list the directory to help debugging.
    {
        echo "Directory contents:"
        ls -la "$base_dir" 2>/dev/null || echo "Directory does not exist"
    } >&2

    # Signal "not found" with an empty result.
    echo ""
}

# Report which training data configuration ("full" or "matched") the
# checkpoint for a given run was trained with.
#
# Arguments:
#   $1 - cross-evaluation mode ("matched_to_full", "full_to_matched", or other)
#   $2 - task name
#   $3 - seed
# Globals read:
#   FOLD, BATCH_SIZE, LR, PATIENCE, EPOCHS, DROPOUT, ALIGN
# Outputs:
#   Prints "full" or "matched" to stdout.
get_training_data_config() {
    local cross_eval="$1"
    local task="$2"
    local seed="$3"

    # Cross-evaluation modes name their training configuration explicitly.
    case "$cross_eval" in
        matched_to_full)
            echo "matched"
            return
            ;;
        full_to_matched)
            echo "full"
            return
            ;;
    esac

    # Normal testing: probe the experiment directories. The two candidates
    # differ only in the trailing data_config value.
    local logs_root="../experiments/mmtm/lightning_logs"
    local run_prefix="${logs_root}/MMTM-model_mmtm-task_${task}-fold_${FOLD}-batch_size_${BATCH_SIZE}-lr_${LR}-patience_${PATIENCE}-epochs_${EPOCHS}-dropout_${DROPOUT}-seed_${seed}-align_${ALIGN}-pretrained_True-data_config_"

    # "full" takes precedence and is also the default when neither directory
    # exists; "matched" is reported only when it is the sole one present.
    if [ ! -d "${run_prefix}full" ] && [ -d "${run_prefix}matched" ]; then
        echo "matched"
    else
        echo "full"
    fi
}

# Print a human-readable digest of the fairness metrics found in a results file.
#
# Arguments:
#   $1 - path to the results file; lines are matched with grep and split on
#        ':' (metric vs. value) and '/' (attribute / group inside the metric)
#   $2 - seed, used only in the banner
#   $3 - task name; "los" reports accuracy (labelled ACC), any other task PRAUC
# Outputs:
#   Writes the summary to stdout, or a warning when the file does not exist.
#
# All working variables are function-local, so calling this function does not
# create or overwrite globals in the calling script.
display_fairness_results() {
    local results_file="$1"
    local seed="$2"
    local task="$3"
    local metric_name metric_display
    local line metric value attr group

    if [ ! -f "$results_file" ]; then
        echo "Warning: Results file not found: $results_file"
        return
    fi

    echo ""
    echo "========== FAIRNESS RESULTS SUMMARY (Seed $seed, Task $task) =========="

    # Choose the metric to report based on the task type.
    if [ "$task" = "los" ]; then
        metric_name="accuracy"
        metric_display="ACC"
    else
        metric_name="PRAUC"
        metric_display="PRAUC"
    fi

    # Gap metrics. read -r keeps backslashes in the line intact; the default
    # IFS still trims surrounding whitespace.
    echo "${metric_display} Gap Analysis:"
    grep -E "fairness.*${metric_name}.*gap:" "$results_file" | while read -r line; do
        metric=$(cut -d':' -f1 <<< "$line" | sed 's/^[[:space:]]*//')
        value=$(cut -d':' -f2 <<< "$line" | sed 's/^[[:space:]]*//')
        attr=$(cut -d'/' -f2 <<< "$metric")
        printf "   %-20s: %8s\n" "$attr ${metric_display} gap" "$value"
    done

    # Worst-case performance per attribute.
    echo ""
    echo "Worst-Case Performance:"
    grep -E "fairness.*${metric_name}.*worst_case:" "$results_file" | while read -r line; do
        metric=$(cut -d':' -f1 <<< "$line" | sed 's/^[[:space:]]*//')
        value=$(cut -d':' -f2 <<< "$line" | sed 's/^[[:space:]]*//')
        attr=$(cut -d'/' -f2 <<< "$metric")
        printf "   %-20s: %8s\n" "$attr worst ${metric_display}" "$value"
    done

    # Group-wise performance, limited to the first 10 matching lines.
    echo ""
    echo "Group-wise ${metric_display} Performance:"
    grep -E "fairness.*/[^/]*/${metric_name}:" "$results_file" | head -10 | while read -r line; do
        metric=$(cut -d':' -f1 <<< "$line" | sed 's/^[[:space:]]*//')
        value=$(cut -d':' -f2 <<< "$line" | sed 's/^[[:space:]]*//')
        attr=$(cut -d'/' -f2 <<< "$metric")
        group=$(cut -d'/' -f3 <<< "$metric")
        printf "   %-15s %-15s: %8s\n" "$attr" "$group" "$value"
    done

    echo "=================================================================="
}

# ---------------------------------------------------------------------------
# Main driver: for every task and seed, resolve the checkpoint, run the test
# command with fairness options, then summarise the fairness metrics.
# ---------------------------------------------------------------------------

# Root output directory; cross-evaluation runs get their own tree.
if [ -n "$CROSS_EVAL" ]; then
    ROOT_DIR="../experiments_fairness/mmtm-${CROSS_EVAL}-fairness"
else
    ROOT_DIR="../experiments_fairness/mmtm-fairness"
fi

echo "All results will be saved under: $ROOT_DIR"

# Directory that holds the training runs (one sub-directory per run).
LOGS_ROOT="../experiments/mmtm/lightning_logs"

# Split the space-separated fairness lists once so each entry becomes its own
# command-line argument.
FAIRNESS_ATTRIBUTE_ARGS=()
FAIRNESS_AGE_BIN_ARGS=()
read -r -a FAIRNESS_ATTRIBUTE_ARGS <<< "$FAIRNESS_ATTRIBUTES"
read -r -a FAIRNESS_AGE_BIN_ARGS <<< "$FAIRNESS_AGE_BINS"

for TASK in "${TASKS[@]}"; do
    echo ""
    if [ -n "$CROSS_EVAL" ]; then
        echo "Testing MMTM model for task: $TASK with cross evaluation: $CROSS_EVAL and fairness evaluation"
    else
        echo "Testing MMTM model for task: $TASK with fairness evaluation (normal test)"
    fi

    # Task-specific model parameters.
    case "$TASK" in
        phenotype)
            NUM_CLASSES=25
            INPUT_DIM=49
            ;;
        mortality)
            NUM_CLASSES=1
            INPUT_DIM=49
            ;;
        los)
            NUM_CLASSES=10
            INPUT_DIM=49
            ;;
        *)
            # Without this guard NUM_CLASSES / INPUT_DIM would be unset or
            # stale and the test command would be malformed.
            echo "Error: Unknown task '$TASK' (expected phenotype, mortality or los)"
            exit 1
            ;;
    esac

    # Metric written to the per-seed summary: accuracy for los, PRAUC for the
    # other tasks (the same choice display_fairness_results makes).
    if [ "$TASK" = "los" ]; then
        SUMMARY_METRIC_NAME="accuracy"
        SUMMARY_METRIC_DISPLAY="ACC"
    else
        SUMMARY_METRIC_NAME="PRAUC"
        SUMMARY_METRIC_DISPLAY="PRAUC"
    fi

    COMPLETED_SEEDS=0

    for SEED in "${SEEDS[@]}"; do
        echo ""
        echo "Testing with seed $SEED for task $TASK on fold $FOLD..."

        # Determine the training data config used to locate the checkpoint.
        TRAINING_DATA_CONFIG=$(get_training_data_config "$CROSS_EVAL" "$TASK" "$SEED")
        echo "Using training data config: $TRAINING_DATA_CONFIG"

        # MMTM checkpoints live directly in the experiment directory (there is
        # no checkpoints/ sub-directory).
        CHECKPOINT_DIR="${LOGS_ROOT}/MMTM-model_mmtm-task_${TASK}-fold_${FOLD}-batch_size_${BATCH_SIZE}-lr_${LR}-patience_${PATIENCE}-epochs_${EPOCHS}-dropout_${DROPOUT}-seed_${SEED}-align_${ALIGN}-pretrained_True-data_config_${TRAINING_DATA_CONFIG}"

        # Manually pinned best checkpoints. They all live in the data_config_full
        # run trained with the literal hyper-parameters below, so only the file
        # name differs per task/seed.
        PINNED_DIR="${LOGS_ROOT}/MMTM-model_mmtm-task_${TASK}-fold_1-batch_size_16-lr_0.0001-patience_10-epochs_50-dropout_0.2-seed_${SEED}-align_0.0-pretrained_True-data_config_full"
        PINNED_CKPT_NAME=""
        case "$TASK-$SEED" in
            "los-42" | "los-123" | "los-1234")
                PINNED_CKPT_NAME="best_model_epoch_00_prauc_0.0000.ckpt"
                ;;
            "mortality-1234")
                PINNED_CKPT_NAME="best_model_epoch_13_prauc_0.2844.ckpt"
                ;;
            "mortality-123")
                PINNED_CKPT_NAME="best_model_epoch_27_prauc_0.3624.ckpt"
                ;;
            "mortality-42")
                PINNED_CKPT_NAME="best_model_epoch_07_prauc_0.3187.ckpt"
                ;;
            "phenotype-1234")
                PINNED_CKPT_NAME="best_model_epoch_21_prauc_0.4823.ckpt"
                ;;
            "phenotype-123")
                PINNED_CKPT_NAME="best_model_epoch_15_prauc_0.4865.ckpt"
                ;;
            "phenotype-42")
                PINNED_CKPT_NAME="best_model_epoch_16_prauc_0.4829.ckpt"
                ;;
        esac

        if [ -n "$PINNED_CKPT_NAME" ]; then
            CHECKPOINT_FILE="${PINNED_DIR}/${PINNED_CKPT_NAME}"
        else
            # No pinned checkpoint: search the run directory. find_checkpoint
            # may print progress lines before the path; the path (or an empty
            # line) is always its last line of output.
            CHECKPOINT_FILE=$(find_checkpoint "$CHECKPOINT_DIR" | tail -n 1)
        fi

        if [ ! -f "$CHECKPOINT_FILE" ]; then
            echo "Error: Checkpoint not found in $CHECKPOINT_DIR"
            echo "Available directories for seed $SEED:"
            FOUND_SEED_DIR=false
            for seed_dir in "$LOGS_ROOT"/*"seed_${SEED}-"*; do
                [ -d "$seed_dir" ] || continue
                FOUND_SEED_DIR=true
                echo "   $seed_dir"
            done
            if [ "$FOUND_SEED_DIR" = "false" ]; then
                echo "No directories found for seed $SEED"
            fi

            # Look for another run of the SAME task and seed. The delimiters
            # around the task and seed keep seed 123 from matching seed 1234
            # and keep checkpoints of other tasks out.
            echo "Searching for any available checkpoint for seed $SEED..."
            for alt_dir in "$LOGS_ROOT"/*"-task_${TASK}-"*"-seed_${SEED}-"*; do
                [ -d "$alt_dir" ] || continue
                alt_checkpoint=$(find_checkpoint "$alt_dir" | tail -n 1)
                if [ -f "$alt_checkpoint" ]; then
                    echo "Found alternative checkpoint: $alt_checkpoint"
                    CHECKPOINT_FILE="$alt_checkpoint"
                    if [[ "$alt_dir" == *"data_config_matched"* ]]; then
                        TRAINING_DATA_CONFIG="matched"
                    elif [[ "$alt_dir" == *"data_config_full"* ]]; then
                        TRAINING_DATA_CONFIG="full"
                    fi
                    echo "Updated training data config to: $TRAINING_DATA_CONFIG"
                    break
                fi
            done

            if [ ! -f "$CHECKPOINT_FILE" ]; then
                echo "No valid checkpoint found for seed $SEED, skipping..."
                continue
            fi
        fi

        echo "Found checkpoint: $CHECKPOINT_FILE"
        echo "Using data config: $TRAINING_DATA_CONFIG"

        # Per-task output directory.
        CUSTOM_LOG_DIR="$ROOT_DIR/$TASK"

        # Build the command as an array so every argument stays one word
        # without needing eval.
        CMD=(
            python ../main.py
            --model mmtm
            --mode test
            --task "$TASK"
            --fold "$FOLD"
            --batch_size "$BATCH_SIZE"
            --lr "$LR"
            --patience "$PATIENCE"
            --epochs "$EPOCHS"
            --seed "$SEED"
            --input_dim "$INPUT_DIM"
            --num_classes "$NUM_CLASSES"
            --ehr_encoder "$EHR_ENCODER"
            --cxr_encoder "$CXR_ENCODER"
            --dim "$DIM"
            --ehr_num_layers "$EHR_NUM_LAYERS"
            --ehr_n_head "$EHR_N_HEAD"
            --ehr_n_layers "$EHR_N_LAYERS"
            --mmtm_ratio "$MMTM_RATIO"
            --layer_after "$LAYER_AFTER"
            --dropout "$DROPOUT"
            --align "$ALIGN"
            --gpu "$GPU"
            --checkpoint_path "$CHECKPOINT_FILE"
            --log_dir "$CUSTOM_LOG_DIR"
        )

        # Optional boolean flags.
        if [ "$PRETRAINED" = "true" ]; then
            CMD+=(--pretrained)
        fi

        if [ "$EHR_BIDIRECTIONAL" = "true" ]; then
            CMD+=(--ehr_bidirectional)
        fi

        if [ "$USE_DEMOGRAPHICS" = "true" ]; then
            CMD+=(--use_demographics)
        fi

        # Fairness evaluation options.
        if [ "$COMPUTE_FAIRNESS" = "true" ]; then
            CMD+=(--compute_fairness)

            if [ "${#FAIRNESS_ATTRIBUTE_ARGS[@]}" -gt 0 ]; then
                CMD+=(--fairness_attributes "${FAIRNESS_ATTRIBUTE_ARGS[@]}")
            fi

            if [ "${#FAIRNESS_AGE_BIN_ARGS[@]}" -gt 0 ]; then
                CMD+=(--fairness_age_bins "${FAIRNESS_AGE_BIN_ARGS[@]}")
            fi

            if [ "$FAIRNESS_INTERSECTIONAL" = "true" ]; then
                CMD+=(--fairness_intersectional)
            fi
        fi

        # Prediction dumps.
        if [ "$SAVE_PREDICTIONS" = "true" ]; then
            CMD+=(--save_predictions)
        fi

        # Cross evaluation, or the matched flag for normal tests. The full
        # configuration needs no flag because it is the default.
        if [ -n "$CROSS_EVAL" ]; then
            CMD+=(--cross_eval "$CROSS_EVAL")
        elif [ "$TRAINING_DATA_CONFIG" = "matched" ]; then
            CMD+=(--matched)
        fi

        echo "Running command:"
        printf '%q ' "${CMD[@]}"
        echo ""
        echo ""
        echo "Output will be saved to: $CUSTOM_LOG_DIR"
        echo "Fairness metrics will be computed for: $FAIRNESS_ATTRIBUTES"

        # Execute the command; abort the whole script on failure.
        if "${CMD[@]}"; then
            COMPLETED_SEEDS=$((COMPLETED_SEEDS + 1))
            echo "Seed $SEED for task $TASK testing with fairness evaluation completed successfully!"
            echo "Results saved to: $CUSTOM_LOG_DIR"

            # Pick the most recently modified results file via a glob rather
            # than by parsing ls output.
            RESULTS_PATTERN="$CUSTOM_LOG_DIR/lightning_logs/*/test_set_results.yaml"
            LATEST_RESULTS=""
            for candidate in "$CUSTOM_LOG_DIR"/lightning_logs/*/test_set_results.yaml; do
                [ -f "$candidate" ] || continue
                if [ -z "$LATEST_RESULTS" ] || [ "$candidate" -nt "$LATEST_RESULTS" ]; then
                    LATEST_RESULTS="$candidate"
                fi
            done

            if [ -f "$LATEST_RESULTS" ]; then
                echo "Fairness metrics generated successfully!"

                # Display detailed fairness results.
                display_fairness_results "$LATEST_RESULTS" "$SEED" "$TASK"

                # Save a summary for this seed, using the task's own metric.
                SUMMARY_FILE="$CUSTOM_LOG_DIR/fairness_summary_seed_${SEED}.txt"
                {
                    echo "Fairness Summary for MMTM - Task: $TASK, Seed: $SEED"
                    echo "Generated on: $(date)"
                    echo "======================================="
                    echo ""
                    echo "${SUMMARY_METRIC_DISPLAY} Gaps:"
                    grep -E "fairness.*${SUMMARY_METRIC_NAME}.*gap:" "$LATEST_RESULTS"
                    echo ""
                    echo "Worst-case Performance:"
                    grep -E "fairness.*${SUMMARY_METRIC_NAME}.*worst_case:" "$LATEST_RESULTS"
                } > "$SUMMARY_FILE"

                echo "Fairness summary saved to: $SUMMARY_FILE"
            else
                echo "Warning: Results file not found at expected location"
                echo "Expected pattern: $RESULTS_PATTERN"
            fi
        else
            echo "Error: Seed $SEED for task $TASK testing failed!"
            exit 1
        fi

        echo ""
        echo "-----------------------------------"
    done

    # Task completion summary; reports how many seeds actually ran, since
    # seeds without a checkpoint are skipped.
    echo ""
    if [ -n "$CROSS_EVAL" ]; then
        echo "Cross evaluation with fairness analysis completed for ${COMPLETED_SEEDS} of ${#SEEDS[@]} seeds for task $TASK on fold $FOLD!"
        echo "Cross evaluation type: $CROSS_EVAL"
    else
        echo "Testing with fairness analysis completed for ${COMPLETED_SEEDS} of ${#SEEDS[@]} seeds for task $TASK on fold $FOLD!"
    fi
    echo "Task $TASK results saved to: $ROOT_DIR/$TASK/"
    echo "Fairness evaluation results included in test_set_results.yaml files"
    echo "Individual fairness summaries saved as fairness_summary_seed_*.txt"
    if [ "$SAVE_PREDICTIONS" = "true" ]; then
        echo "Test predictions and labels saved as test_predictions_fold${FOLD}_seed*.npz/csv"
    fi
    echo "=================================="
done
