#!/bin/bash

printf '%s\n' "Starting M3Care Model Testing with Fairness Evaluation..."

# ---- Experiment parameters (must match the training run) ----
BATCH_SIZE=16
LR=0.0001
PATIENCE=10
EPOCHS=50
FOLD=1          # fixed fold; also part of the checkpoint run-directory name
GPU=0           # GPU device number, passed as --gpu
PRETRAINED=true

# ---- Test-time configuration ----
# Demographics MUST be enabled for fairness evaluation.
USE_DEMOGRAPHICS=true
# Cross-evaluation mode: "" for a normal test, or "matched_to_full" /
# "full_to_matched" to evaluate a checkpoint trained on the other data config.
CROSS_EVAL=""

# ---- Fairness evaluation ----
COMPUTE_FAIRNESS=true
# Space-separated sensitive attributes; each word is passed as its own argument.
FAIRNESS_ATTRIBUTES="race gender age admission_type has_cxr"
# Age bin edges, giving bins 0-20, 20-40, 40-60, 60-80 and 80+.
FAIRNESS_AGE_BINS="0 20 40 60 80"
FAIRNESS_INTERSECTIONAL=true    # also run intersectional fairness analysis

# ---- Predictions saving ----
# Save test predictions and labels into the experiment directory.
SAVE_PREDICTIONS=true

cat <<EOF
Fairness Evaluation Configuration:
   - Compute Fairness: $COMPUTE_FAIRNESS
   - Sensitive Attributes: $FAIRNESS_ATTRIBUTES
   - Age Bins: $FAIRNESS_AGE_BINS
   - Intersectional Analysis: $FAIRNESS_INTERSECTIONAL
   - Demographics Enabled: $USE_DEMOGRAPHICS
Predictions Saving Configuration:
   - Save Predictions: $SAVE_PREDICTIONS
   - Save Directory: Experiment directory (automatic)
EOF

# ---- M3Care architecture (must match the training run) ----
EHR_ENCODER=transformer     # one of: lstm, transformer
CXR_ENCODER=resnet50        # one of: resnet50, vit_b_16
HIDDEN_DIM=256
DROPOUT=0.2                 # general dropout rate
STAB_REG_LAMBDA=0.1         # stability regularization lambda

# LSTM EHR-encoder settings
EHR_NUM_LAYERS=1            # number of LSTM layers
EHR_BIDIRECTIONAL=true      # bidirectional LSTM

# Transformer EHR-encoder settings
EHR_N_HEAD=4                # number of attention heads
EHR_N_LAYERS=1              # number of transformer layers

# Other EHR settings
EHR_DROPOUT=0.2
MAX_LEN=500                 # maximum sequence length

# ---- Runs to evaluate ----
declare -a SEEDS=(42 123 1234)   # must match the seeds used for training
declare -a TASKS=(los)           # any of: phenotype, mortality, los

#######################################
# Locate a Lightning checkpoint (*.ckpt) under a directory.
# Arguments: $1 - directory to search (searched recursively)
# Outputs:   path of the first *.ckpt file in byte-wise sorted order, or an
#            empty line when the directory has none. Prints nothing when the
#            directory does not exist.
# Returns:   0 always; callers test the printed path with [ -f ].
#######################################
find_checkpoint() {
    local checkpoint_dir="$1"
    local checkpoint_file

    # A missing directory is an ordinary "not found" for callers; avoid
    # spamming find's error message to stderr.
    [[ -d "$checkpoint_dir" ]] || return 0

    # find's output order depends on the filesystem, so sort to make the
    # choice reproducible when several checkpoints exist.
    checkpoint_file=$(find "$checkpoint_dir" -name "*.ckpt" -type f | LC_ALL=C sort | head -n 1)
    printf '%s\n' "$checkpoint_file"
}

#######################################
# Decide which training data config ("matched" or "full") produced the
# checkpoint that should be tested.
# Globals:   FOLD, BATCH_SIZE, LR, PATIENCE, EPOCHS (read; part of run names)
# Arguments: $1 - cross-eval mode: "matched_to_full", "full_to_matched",
#                 or anything else (normal test)
#            $2 - task name
#            $3 - seed
# Outputs:   "matched" or "full" on stdout
#######################################
get_training_data_config() {
    local mode="$1" task_name="$2" run_seed="$3"
    local run_prefix config

    # Cross evaluation fixes the training side explicitly.
    case "$mode" in
        matched_to_full) echo "matched"; return ;;
        full_to_matched) echo "full"; return ;;
    esac

    # Normal test: probe the experiments tree for an existing run directory,
    # preferring "matched" over "full".
    run_prefix="../experiments/m3care/${task_name}/lightning_logs/M3CARE-model_m3care-task_${task_name}-fold_${FOLD}-batch_size_${BATCH_SIZE}-lr_${LR}-patience_${PATIENCE}-epochs_${EPOCHS}-seed_${run_seed}-pretrained_True-data_config_"
    for config in matched full; do
        if [[ -d "${run_prefix}${config}" ]]; then
            echo "$config"
            return
        fi
    done

    # Neither directory exists: default to "matched".
    echo "matched"
}

#######################################
# Extract the key of a "key: value" results line: the text before the first
# ':' with leading whitespace removed.
# Arguments: $1 - one line from the results file
# Outputs:   the key on stdout
#######################################
_fairness_key() {
    printf '%s\n' "$1" | cut -d':' -f1 | sed 's/^[[:space:]]*//'
}

#######################################
# Extract the value of a "key: value" results line: the text between the
# first and second ':' with leading whitespace removed.
# Arguments: $1 - one line from the results file
# Outputs:   the value on stdout
#######################################
_fairness_value() {
    printf '%s\n' "$1" | cut -d':' -f2 | sed 's/^[[:space:]]*//'
}

#######################################
# Print a human-readable summary of the fairness metrics in a results file.
# Keys are expected to be '/'-separated, e.g. fairness/<attr>/..._gap or
# fairness/<attr>/<group>/<metric> (inferred from the grep patterns below).
# Arguments: $1 - path to the test results file (test_set_results.yaml)
#            $2 - seed (used only in the header)
#            $3 - task name; "los" reports accuracy, any other task PRAUC
# Outputs:   summary table on stdout; warning on stderr if $1 is missing
# Returns:   0
#######################################
display_fairness_results() {
    local results_file="$1"
    local seed="$2"
    local task="$3"
    # Local so the metric choice does not leak into (or clobber) globals.
    local metric_name metric_display
    local line metric value attr group

    if [[ ! -f "$results_file" ]]; then
        echo "Warning: Results file not found: $results_file" >&2
        return 0
    fi

    echo ""
    echo "========== FAIRNESS RESULTS SUMMARY (Seed $seed, Task $task) =========="

    # Determine which metric to use based on task type.
    if [[ "$task" == "los" ]]; then
        metric_name="accuracy"
        metric_display="ACC"
    else
        metric_name="PRAUC"
        metric_display="PRAUC"
    fi

    # Gap metrics: one row per attribute (2nd '/'-field of the key).
    # read -r keeps backslashes literal; default IFS trims surrounding blanks.
    echo "${metric_display} Gap Analysis:"
    grep -E "fairness.*${metric_name}.*gap:" "$results_file" | while read -r line; do
        metric=$(_fairness_key "$line")
        value=$(_fairness_value "$line")
        attr=$(printf '%s\n' "$metric" | cut -d'/' -f2)
        printf '   %-20s: %8s\n' "$attr ${metric_display} gap" "$value"
    done

    # Worst-case performance per attribute.
    echo ""
    echo "Worst-Case Performance:"
    grep -E "fairness.*${metric_name}.*worst_case:" "$results_file" | while read -r line; do
        metric=$(_fairness_key "$line")
        value=$(_fairness_value "$line")
        attr=$(printf '%s\n' "$metric" | cut -d'/' -f2)
        printf '   %-20s: %8s\n' "$attr worst ${metric_display}" "$value"
    done

    # Group-wise performance (first 10 rows): attribute and group come from
    # the 2nd and 3rd '/'-fields of the key.
    echo ""
    echo "Group-wise ${metric_display} Performance:"
    grep -E "fairness.*/[^/]*/${metric_name}:" "$results_file" | head -10 | while read -r line; do
        metric=$(_fairness_key "$line")
        value=$(_fairness_value "$line")
        attr=$(printf '%s\n' "$metric" | cut -d'/' -f2)
        group=$(printf '%s\n' "$metric" | cut -d'/' -f3)
        printf '   %-15s %-15s: %8s\n' "$attr" "$group" "$value"
    done

    echo "=================================================================="
}


# Root directory for all fairness outputs. A cross-evaluation mode, when set,
# is embedded in the name (e.g. m3care-matched_to_full-fairness); otherwise
# the directory is m3care-fairness.
ROOT_DIR="../experiments_fairness/m3care${CROSS_EVAL:+-${CROSS_EVAL}}-fairness"

printf '%s\n' "All results will be saved under: $ROOT_DIR"

# Split the space-separated fairness settings once (loop-invariant); each word
# becomes its own CLI argument, matching how the values were word-split before.
read -r -a FAIRNESS_ATTR_ARGS <<< "$FAIRNESS_ATTRIBUTES"
read -r -a FAIRNESS_AGE_BIN_ARGS <<< "$FAIRNESS_AGE_BINS"

for TASK in "${TASKS[@]}"
do
    echo ""
    if [[ -n "$CROSS_EVAL" ]]; then
        echo "Testing M3Care model for task: $TASK with cross evaluation: $CROSS_EVAL and fairness evaluation"
    else
        echo "Testing M3Care model for task: $TASK with fairness evaluation (normal test)"
    fi

    # Task-specific model dimensions. Fail fast on an unknown task rather than
    # passing empty or stale --num_classes/--input_dim values to main.py.
    case "$TASK" in
        phenotype) NUM_CLASSES=25; INPUT_DIM=498 ;;
        mortality) NUM_CLASSES=1;  INPUT_DIM=498 ;;
        los)       NUM_CLASSES=10; INPUT_DIM=498 ;;
        *)
            echo "Error: Unknown task '$TASK' (expected phenotype, mortality or los)" >&2
            exit 1
            ;;
    esac

    # Metric written to the per-seed summary file. This is the same selection
    # display_fairness_results uses: accuracy for "los", PRAUC otherwise.
    if [[ "$TASK" == "los" ]]; then
        SUMMARY_METRIC="accuracy"
        SUMMARY_LABEL="ACC"
    else
        SUMMARY_METRIC="PRAUC"
        SUMMARY_LABEL="PRAUC"
    fi

    LOGS_BASE="../experiments/m3care/$TASK/lightning_logs"

    for SEED in "${SEEDS[@]}"
    do
        echo ""
        echo "Testing with seed $SEED for task $TASK on fold $FOLD..."

        # Matches "seed_<SEED>" only when not followed by another digit, so
        # seed 123 never picks up a seed_1234 run directory.
        SEED_RE="seed_${SEED}([^0-9]|\$)"

        # Determine the training data config for finding the checkpoint.
        TRAINING_DATA_CONFIG=$(get_training_data_config "$CROSS_EVAL" "$TASK" "$SEED")
        echo "Using training data config: $TRAINING_DATA_CONFIG"

        CHECKPOINT_DIR="$LOGS_BASE/M3CARE-model_m3care-task_${TASK}-fold_${FOLD}-batch_size_${BATCH_SIZE}-lr_${LR}-patience_${PATIENCE}-epochs_${EPOCHS}-seed_${SEED}-pretrained_True-data_config_${TRAINING_DATA_CONFIG}/checkpoints"
        CHECKPOINT_FILE=$(find_checkpoint "$CHECKPOINT_DIR")

        if [[ ! -f "$CHECKPOINT_FILE" ]]; then
            echo "Error: Checkpoint not found in $CHECKPOINT_DIR"
            echo "Available directories for seed $SEED:"
            ls -la "$LOGS_BASE/" | grep -E -- "$SEED_RE" || echo "No directories found for seed $SEED"

            # Fall back to any run directory for this seed that has a checkpoint.
            echo "Searching for any available checkpoint for seed $SEED..."
            for alt_dir in "$LOGS_BASE"/*seed_"${SEED}"*; do
                # Skip the literal pattern (no match), non-directories, and
                # other seeds that merely share this prefix.
                [[ -d "$alt_dir" && "$alt_dir" =~ $SEED_RE ]] || continue
                alt_checkpoint=$(find_checkpoint "$alt_dir/checkpoints")
                if [[ -f "$alt_checkpoint" ]]; then
                    echo "Found alternative checkpoint: $alt_checkpoint"
                    CHECKPOINT_FILE="$alt_checkpoint"
                    if [[ "$alt_dir" == *"data_config_matched"* ]]; then
                        TRAINING_DATA_CONFIG="matched"
                    elif [[ "$alt_dir" == *"data_config_full"* ]]; then
                        TRAINING_DATA_CONFIG="full"
                    fi
                    echo "Updated training data config to: $TRAINING_DATA_CONFIG"
                    break
                fi
            done

            if [[ ! -f "$CHECKPOINT_FILE" ]]; then
                echo "No valid checkpoint found for seed $SEED, skipping..."
                continue
            fi
        fi

        echo "Found checkpoint: $CHECKPOINT_FILE"
        echo "Using data config: $TRAINING_DATA_CONFIG"

        CUSTOM_LOG_DIR="$ROOT_DIR/$TASK"

        # Build the command as an array: every argument is passed intact, so
        # paths are never re-split or re-evaluated (no eval).
        CMD=(python ../main.py
            --model m3care
            --mode test
            --task "$TASK"
            --fold "$FOLD"
            --batch_size "$BATCH_SIZE"
            --lr "$LR"
            --patience "$PATIENCE"
            --epochs "$EPOCHS"
            --seed "$SEED"
            --input_dim "$INPUT_DIM"
            --num_classes "$NUM_CLASSES"
            --ehr_encoder "$EHR_ENCODER"
            --cxr_encoder "$CXR_ENCODER"
            --hidden_dim "$HIDDEN_DIM"
            --ehr_num_layers "$EHR_NUM_LAYERS"
            --ehr_n_head "$EHR_N_HEAD"
            --ehr_n_layers "$EHR_N_LAYERS"
            --ehr_dropout "$EHR_DROPOUT"
            --max_len "$MAX_LEN"
            --dropout "$DROPOUT"
            --stab_reg_lambda "$STAB_REG_LAMBDA"
            --gpu "$GPU"
            --checkpoint_path "$CHECKPOINT_FILE"
            --log_dir "$CUSTOM_LOG_DIR")

        # Conditional flags.
        if [[ "$PRETRAINED" == "true" ]]; then
            CMD+=(--pretrained)
        fi
        if [[ "$EHR_BIDIRECTIONAL" == "true" ]]; then
            CMD+=(--ehr_bidirectional)
        fi
        if [[ "$USE_DEMOGRAPHICS" == "true" ]]; then
            CMD+=(--use_demographics)
        fi

        # Fairness evaluation options.
        if [[ "$COMPUTE_FAIRNESS" == "true" ]]; then
            CMD+=(--compute_fairness)
            if (( ${#FAIRNESS_ATTR_ARGS[@]} > 0 )); then
                CMD+=(--fairness_attributes "${FAIRNESS_ATTR_ARGS[@]}")
            fi
            if (( ${#FAIRNESS_AGE_BIN_ARGS[@]} > 0 )); then
                CMD+=(--fairness_age_bins "${FAIRNESS_AGE_BIN_ARGS[@]}")
            fi
            if [[ "$FAIRNESS_INTERSECTIONAL" == "true" ]]; then
                CMD+=(--fairness_intersectional)
            fi
        fi

        if [[ "$SAVE_PREDICTIONS" == "true" ]]; then
            CMD+=(--save_predictions)
        fi

        # Cross evaluation takes precedence; otherwise flag matched-data runs.
        if [[ -n "$CROSS_EVAL" ]]; then
            CMD+=(--cross_eval "$CROSS_EVAL")
        elif [[ "$TRAINING_DATA_CONFIG" == "matched" ]]; then
            CMD+=(--matched)
        fi

        echo "Running command:"
        printf '%q ' "${CMD[@]}"   # shell-quoted, copy-pasteable
        echo ""
        echo ""
        echo "Output will be saved to: $CUSTOM_LOG_DIR"
        echo "Fairness metrics will be computed for: $FAIRNESS_ATTRIBUTES"

        if "${CMD[@]}"; then
            echo "Seed $SEED for task $TASK testing with fairness evaluation completed successfully!"
            echo "Results saved to: $CUSTOM_LOG_DIR"

            # Newest results file under this log dir (ls -t sorts by mtime).
            RESULTS_PATTERN="$CUSTOM_LOG_DIR/lightning_logs/*/test_set_results.yaml"
            LATEST_RESULTS=$(ls -t -- "$CUSTOM_LOG_DIR"/lightning_logs/*/test_set_results.yaml 2>/dev/null | head -n 1)

            if [[ -f "$LATEST_RESULTS" ]]; then
                echo "Fairness metrics generated successfully!"

                display_fairness_results "$LATEST_RESULTS" "$SEED" "$TASK"

                # Per-seed summary using this task's metric. Grepping PRAUC for
                # "los" (an accuracy task) produced empty sections.
                SUMMARY_FILE="$CUSTOM_LOG_DIR/fairness_summary_seed_${SEED}.txt"
                {
                    echo "Fairness Summary for M3Care - Task: $TASK, Seed: $SEED"
                    echo "Generated on: $(date)"
                    echo "======================================="
                    echo ""
                    echo "${SUMMARY_LABEL} Gaps:"
                    grep -E "fairness.*${SUMMARY_METRIC}.*gap:" "$LATEST_RESULTS"
                    echo ""
                    echo "Worst-case Performance:"
                    grep -E "fairness.*${SUMMARY_METRIC}.*worst_case:" "$LATEST_RESULTS"
                } > "$SUMMARY_FILE"

                echo "Fairness summary saved to: $SUMMARY_FILE"
            else
                echo "Warning: Results file not found at expected location"
                echo "Expected pattern: $RESULTS_PATTERN"
            fi
        else
            echo "Error: Seed $SEED for task $TASK testing failed!" >&2
            exit 1
        fi

        echo ""
        echo "-----------------------------------"
    done

    # Task completion summary.
    echo ""
    if [[ -n "$CROSS_EVAL" ]]; then
        echo "All ${#SEEDS[@]} seeds cross evaluation with fairness analysis completed for task $TASK on fold $FOLD!"
        echo "Cross evaluation type: $CROSS_EVAL"
    else
        echo "All ${#SEEDS[@]} seeds testing with fairness analysis completed for task $TASK on fold $FOLD!"
    fi
    echo "Task $TASK results saved to: $ROOT_DIR/$TASK/"
    echo "Fairness evaluation results included in test_set_results.yaml files"
    echo "Individual fairness summaries saved as fairness_summary_seed_*.txt"
    if [[ "$SAVE_PREDICTIONS" == "true" ]]; then
        echo "Test predictions and labels saved as test_predictions_fold${FOLD}_seed*.npz/csv"
    fi
    echo "=================================="
done
