#!/bin/bash

printf '%s\n' "Starting HealNet Model Testing with Fairness Evaluation..."

# ---------------------------------------------------------------------------
# Experiment parameters. These feed into the checkpoint directory name, so
# they have to be the same values that were used for training.
# ---------------------------------------------------------------------------
BATCH_SIZE=4        # must match the training batch size
LR=0.0001
PATIENCE=10
EPOCHS=50
FOLD=1              # evaluation is pinned to fold 1
GPU=0               # GPU device number
PRETRAINED=true
DROPOUT=0.2         # general dropout rate

# ---------------------------------------------------------------------------
# Test-time switches.
# ---------------------------------------------------------------------------
USE_DEMOGRAPHICS=true   # MUST stay true, fairness evaluation needs demographics
CROSS_EVAL=""           # empty = normal test; change as needed
                        # (e.g. "matched_to_full" or "full_to_matched")

# ---------------------------------------------------------------------------
# Fairness evaluation.
# ---------------------------------------------------------------------------
COMPUTE_FAIRNESS=true                                  # turn fairness metrics on
FAIRNESS_ATTRIBUTES="race gender age admission_type"   # sensitive attributes
FAIRNESS_AGE_BINS="0 20 40 60 80"                      # bins: 0-20, 20-40, 40-60, 60-80, 80+
FAIRNESS_INTERSECTIONAL=true                           # intersectional analysis on

# ---------------------------------------------------------------------------
# Prediction dumping.
# ---------------------------------------------------------------------------
SAVE_PREDICTIONS=true   # store test predictions and labels in the experiment directory

# ---------------------------------------------------------------------------
# HealNet architecture (should match training).
# ---------------------------------------------------------------------------
N_MODALITIES=2               # number of modalities (EHR and CXR)
DEPTH=3                      # number of fusion layers
LATENT_CHANNELS=256          # number of latent tokens
LATENT_DIM=256               # dimension of latent tokens
CROSS_HEADS=4                # cross-attention heads
LATENT_HEADS=4               # self-attention heads
CROSS_DIM_HEAD=64            # dimension per cross-attention head
LATENT_DIM_HEAD=64           # dimension per self-attention head
SELF_PER_CROSS_ATTN=1        # self-attention layers per cross-attention
WEIGHT_TIE_LAYERS=true       # share weights across layers
SNN=true                     # use self-normalizing networks
FOURIER_ENCODE_DATA=true     # use Fourier positional encoding
NUM_FREQ_BANDS=2             # number of frequency bands
MAX_FREQ=10.0                # maximum encoding frequency
FINAL_CLASSIFIER_HEAD=true   # append a final classification head
ATTN_DROPOUT=0.2             # dropout in attention layers
FF_DROPOUT=0.2               # dropout in feed-forward layers

# ---------------------------------------------------------------------------
# What to iterate over. Seeds should match the training seeds.
# ---------------------------------------------------------------------------
SEEDS=(42 123 1234)
TASKS=("los")

# ---------------------------------------------------------------------------
# Echo the effective configuration so it ends up in the run log.
# ---------------------------------------------------------------------------
printf '%s\n' \
    "Fairness Evaluation Configuration:" \
    "   - Compute Fairness: $COMPUTE_FAIRNESS" \
    "   - Sensitive Attributes: $FAIRNESS_ATTRIBUTES" \
    "   - Age Bins: $FAIRNESS_AGE_BINS" \
    "   - Intersectional Analysis: $FAIRNESS_INTERSECTIONAL" \
    "   - Demographics Enabled: $USE_DEMOGRAPHICS"

printf '%s\n' \
    "Predictions Saving Configuration:" \
    "   - Save Predictions: $SAVE_PREDICTIONS" \
    "   - Save Directory: Experiment directory (automatic)"

# Locate a checkpoint file below a directory.
#
# Arguments:
#   $1 - directory (searched recursively) expected to contain *.ckpt files
# Outputs:
#   Prints the path of the selected checkpoint to stdout, or an empty line
#   when the path does not exist or holds no *.ckpt file.
# Notes:
#   When several checkpoints exist, the first one in byte-wise (LC_ALL=C)
#   path order is chosen so that repeated runs always test the same file;
#   plain `find | head` returns whatever the filesystem lists first.
find_checkpoint() {
    local checkpoint_dir="$1"
    local checkpoint_file=""

    # Callers probe directories that may not exist; treat that as "no
    # checkpoint" instead of letting find print an error for every probe.
    if [ -e "$checkpoint_dir" ]; then
        # Assignment is separate from `local` so a failure is not masked.
        checkpoint_file=$(find "$checkpoint_dir" -name "*.ckpt" -type f | LC_ALL=C sort | head -n 1)
    fi

    echo "$checkpoint_file"
}

# Decide which training data configuration ("full" or "matched") the
# checkpoint path should point at.
#
# Arguments:
#   $1 - cross-evaluation mode ("matched_to_full", "full_to_matched", or
#        anything else for a normal test)
#   $2 - task name
#   $3 - seed
# Globals (read): FOLD BATCH_SIZE LR PATIENCE EPOCHS DROPOUT
# Outputs:
#   Prints "full" or "matched" to stdout.
get_training_data_config() {
    local mode="$1"
    local task_name="$2"
    local run_seed="$3"

    # Cross-evaluation names the training set explicitly: the part before
    # "_to_" is the configuration the model was trained on.
    case "$mode" in
        matched_to_full)
            echo "matched"
            return
            ;;
        full_to_matched)
            echo "full"
            return
            ;;
    esac

    # Normal test: probe the run directories on disk. Both candidates share
    # everything except the trailing data_config value.
    local logs_root="../experiments/healnet/$task_name/lightning_logs"
    local run_stem="${logs_root}/HEALNET-model_healnet-task_${task_name}-fold_${FOLD}-batch_size_${BATCH_SIZE}-lr_${LR}-patience_${PATIENCE}-epochs_${EPOCHS}-dropout_${DROPOUT}-seed_${run_seed}-pretrained_True-data_config_"

    # "full" wins whenever it exists and is also the fallback when neither
    # run directory exists; "matched" is reported only when it is the sole
    # one present.
    if [ ! -d "${run_stem}full" ] && [ -d "${run_stem}matched" ]; then
        echo "matched"
    else
        echo "full"
    fi
}

# Print a human-readable fairness summary extracted from a results file.
#
# Arguments:
#   $1 - path of the results file to scan (lines of the form
#        "fairness/<attr>/...: <value>" are picked up by the greps below)
#   $2 - seed, used only in the banner
#   $3 - task; "los" selects the accuracy metric, every other task PRAUC
# Outputs:
#   Writes the summary to stdout, or a warning when the file is missing.
# Notes:
#   All working variables are local, so calling this function does not
#   create or overwrite METRIC_NAME / METRIC_DISPLAY in the caller's scope.
display_fairness_results() {
    local results_file="$1"
    local seed="$2"
    local task="$3"
    local metric_name metric_display
    local line metric value attr group

    if [ ! -f "$results_file" ]; then
        echo "Warning: Results file not found: $results_file"
        return 0
    fi

    echo ""
    echo "========== FAIRNESS RESULTS SUMMARY (Seed $seed, Task $task) =========="

    # Choose which metric's fairness entries to report for this task.
    if [ "$task" = "los" ]; then
        metric_name="accuracy"
        metric_display="ACC"
    else
        metric_name="PRAUC"
        metric_display="PRAUC"
    fi

    # Gap per sensitive attribute. For every matching line: key = text before
    # the first ':', value = second ':'-separated field, attribute = second
    # '/'-separated component of the key. `read -r` keeps backslashes intact.
    echo "${metric_display} Gap Analysis:"
    grep -E "fairness.*${metric_name}.*gap:" "$results_file" | while read -r line; do
        metric=$(printf '%s\n' "$line" | cut -d':' -f1 | sed 's/^[[:space:]]*//')
        value=$(printf '%s\n' "$line" | cut -d':' -f2 | sed 's/^[[:space:]]*//')
        attr=$(printf '%s\n' "$metric" | cut -d'/' -f2)
        printf "   %-20s: %8s\n" "$attr ${metric_display} gap" "$value"
    done

    # Worst-case value per sensitive attribute.
    echo ""
    echo "Worst-Case Performance:"
    grep -E "fairness.*${metric_name}.*worst_case:" "$results_file" | while read -r line; do
        metric=$(printf '%s\n' "$line" | cut -d':' -f1 | sed 's/^[[:space:]]*//')
        value=$(printf '%s\n' "$line" | cut -d':' -f2 | sed 's/^[[:space:]]*//')
        attr=$(printf '%s\n' "$metric" | cut -d'/' -f2)
        printf "   %-20s: %8s\n" "$attr worst ${metric_display}" "$value"
    done

    # Per-group values; only the first 10 matches are shown to keep the
    # console output short.
    echo ""
    echo "Group-wise ${metric_display} Performance:"
    grep -E "fairness.*/[^/]*/${metric_name}:" "$results_file" | head -10 | while read -r line; do
        metric=$(printf '%s\n' "$line" | cut -d':' -f1 | sed 's/^[[:space:]]*//')
        value=$(printf '%s\n' "$line" | cut -d':' -f2 | sed 's/^[[:space:]]*//')
        attr=$(printf '%s\n' "$metric" | cut -d'/' -f2)
        group=$(printf '%s\n' "$metric" | cut -d'/' -f3)
        printf "   %-15s %-15s: %8s\n" "$attr" "$group" "$value"
    done

    echo "=================================================================="
}

# Root of all outputs. A non-empty CROSS_EVAL is spliced into the directory
# name ("healnet-<CROSS_EVAL>-fairness"); otherwise it is "healnet-fairness".
ROOT_DIR="../experiments_fairness/healnet-${CROSS_EVAL:+${CROSS_EVAL}-}fairness"

printf '%s\n' "All results will be saved under: $ROOT_DIR"

# Main driver. For every task in TASKS and every seed in SEEDS:
#   1. resolve the training checkpoint (with a fallback search),
#   2. run ../main.py in test mode with the HealNet and fairness options,
#   3. display the fairness metrics and write a per-seed summary file.
# The script exits with status 1 as soon as one test run fails.
for TASK in "${TASKS[@]}"
do
    echo ""
    if [ -n "$CROSS_EVAL" ]; then
        echo "Testing HealNet model for task: $TASK with cross evaluation: $CROSS_EVAL and fairness evaluation"
    else
        echo "Testing HealNet model for task: $TASK with fairness evaluation (normal test)"
    fi

    # Set task-specific parameters
    case "$TASK" in
        phenotype)
            NUM_CLASSES=25
            INPUT_DIM=49
            ;;
        mortality)
            NUM_CLASSES=1
            INPUT_DIM=49
            ;;
        los)
            NUM_CLASSES=10
            INPUT_DIM=49
            ;;
        *)
            # Without this guard NUM_CLASSES/INPUT_DIM would be empty, or
            # silently carried over from the previous task in TASKS.
            echo "Error: Unknown task '$TASK' (expected phenotype, mortality or los)"
            exit 1
            ;;
    esac

    for SEED in "${SEEDS[@]}"
    do
        echo ""
        echo "Testing with seed $SEED for task $TASK on fold $FOLD..."

        # Determine the training data config for finding checkpoint
        TRAINING_DATA_CONFIG=$(get_training_data_config "$CROSS_EVAL" "$TASK" "$SEED")
        echo "Using training data config: $TRAINING_DATA_CONFIG"

        # Build checkpoint directory path (the name encodes the training hyper-parameters)
        CHECKPOINT_DIR="../experiments/healnet/$TASK/lightning_logs/HEALNET-model_healnet-task_${TASK}-fold_${FOLD}-batch_size_${BATCH_SIZE}-lr_${LR}-patience_${PATIENCE}-epochs_${EPOCHS}-dropout_${DROPOUT}-seed_${SEED}-pretrained_True-data_config_${TRAINING_DATA_CONFIG}/checkpoints"

        # Find checkpoint file
        CHECKPOINT_FILE=$(find_checkpoint "$CHECKPOINT_DIR")

        if [ ! -f "$CHECKPOINT_FILE" ]; then
            echo "Error: Checkpoint not found in $CHECKPOINT_DIR"
            echo "Available directories for seed $SEED:"
            ls -la "../experiments/healnet/$TASK/lightning_logs/" | grep "seed_${SEED}" || echo "No directories found for seed $SEED"

            # Try to find any available checkpoint for this seed
            echo "Searching for any available checkpoint for seed $SEED..."
            # NOTE(review): the fallback searches ../experiments_los-m-m while
            # the primary lookup and the listing above use ../experiments;
            # confirm this alternate root is intended.
            # The glob is iterated directly (instead of parsing `ls`) so that
            # directory names containing spaces survive.
            for alt_dir in "../experiments_los-m-m/healnet/$TASK/lightning_logs/"*"seed_${SEED}"*; do
                # An unmatched glob stays as the literal pattern; skip it.
                [ -e "$alt_dir" ] || continue
                alt_checkpoint_dir="$alt_dir/checkpoints"
                alt_checkpoint=$(find_checkpoint "$alt_checkpoint_dir")
                if [ -f "$alt_checkpoint" ]; then
                    echo "Found alternative checkpoint: $alt_checkpoint"
                    CHECKPOINT_FILE="$alt_checkpoint"
                    # Keep the data config in sync with the run actually used.
                    if [[ "$alt_dir" == *"data_config_matched"* ]]; then
                        TRAINING_DATA_CONFIG="matched"
                    elif [[ "$alt_dir" == *"data_config_full"* ]]; then
                        TRAINING_DATA_CONFIG="full"
                    fi
                    echo "Updated training data config to: $TRAINING_DATA_CONFIG"
                    break
                fi
            done

            if [ ! -f "$CHECKPOINT_FILE" ]; then
                echo "No valid checkpoint found for seed $SEED, skipping..."
                continue
            fi
        fi

        echo "Found checkpoint: $CHECKPOINT_FILE"
        echo "Using data config: $TRAINING_DATA_CONFIG"

        # Create custom output directory
        CUSTOM_LOG_DIR="$ROOT_DIR/$TASK"

        # Build the command as an array: each element is passed to python as
        # exactly one argument, so paths with spaces or shell metacharacters
        # are safe and no `eval` is needed.
        CMD=(
            python ../main.py
            --model healnet
            --mode test
            --task "$TASK"
            --fold "$FOLD"
            --batch_size "$BATCH_SIZE"
            --lr "$LR"
            --patience "$PATIENCE"
            --epochs "$EPOCHS"
            --seed "$SEED"
            --input_dim "$INPUT_DIM"
            --num_classes "$NUM_CLASSES"
            --n_modalities "$N_MODALITIES"
            --depth "$DEPTH"
            --latent_channels "$LATENT_CHANNELS"
            --latent_dim "$LATENT_DIM"
            --cross_heads "$CROSS_HEADS"
            --latent_heads "$LATENT_HEADS"
            --cross_dim_head "$CROSS_DIM_HEAD"
            --latent_dim_head "$LATENT_DIM_HEAD"
            --self_per_cross_attn "$SELF_PER_CROSS_ATTN"
            --num_freq_bands "$NUM_FREQ_BANDS"
            --max_freq "$MAX_FREQ"
            --attn_dropout "$ATTN_DROPOUT"
            --ff_dropout "$FF_DROPOUT"
            --dropout "$DROPOUT"
            --gpu "$GPU"
            --checkpoint_path "$CHECKPOINT_FILE"
            --log_dir "$CUSTOM_LOG_DIR"
        )

        # Add conditional parameters
        if [ "$PRETRAINED" = "true" ]; then
            CMD+=(--pretrained)
        fi

        if [ "$WEIGHT_TIE_LAYERS" = "true" ]; then
            CMD+=(--weight_tie_layers)
        fi

        if [ "$SNN" = "true" ]; then
            CMD+=(--snn)
        fi

        if [ "$FOURIER_ENCODE_DATA" = "true" ]; then
            CMD+=(--fourier_encode_data)
        fi

        if [ "$FINAL_CLASSIFIER_HEAD" = "true" ]; then
            CMD+=(--final_classifier_head)
        fi

        if [ "$USE_DEMOGRAPHICS" = "true" ]; then
            CMD+=(--use_demographics)
        fi

        # Add fairness evaluation parameters
        if [ "$COMPUTE_FAIRNESS" = "true" ]; then
            CMD+=(--compute_fairness)

            # The attribute and age-bin settings are space-separated lists;
            # split them so every item becomes its own argument.
            if [ -n "$FAIRNESS_ATTRIBUTES" ]; then
                read -r -a FAIRNESS_ATTRIBUTE_LIST <<< "$FAIRNESS_ATTRIBUTES"
                CMD+=(--fairness_attributes "${FAIRNESS_ATTRIBUTE_LIST[@]}")
            fi

            if [ -n "$FAIRNESS_AGE_BINS" ]; then
                read -r -a FAIRNESS_AGE_BIN_LIST <<< "$FAIRNESS_AGE_BINS"
                CMD+=(--fairness_age_bins "${FAIRNESS_AGE_BIN_LIST[@]}")
            fi

            if [ "$FAIRNESS_INTERSECTIONAL" = "true" ]; then
                CMD+=(--fairness_intersectional)
            fi
        fi

        # Add predictions saving parameters
        if [ "$SAVE_PREDICTIONS" = "true" ]; then
            CMD+=(--save_predictions)
        fi

        # Add cross evaluation parameter if specified; otherwise tell main.py
        # when the checkpoint comes from a "matched" training run.
        if [ -n "$CROSS_EVAL" ]; then
            CMD+=(--cross_eval "$CROSS_EVAL")
        elif [ "$TRAINING_DATA_CONFIG" = "matched" ]; then
            CMD+=(--matched)
        fi

        echo "Running command:"
        # %q prints each argument shell-quoted, so the line can be pasted
        # back into a terminal as-is.
        printf '%q ' "${CMD[@]}"
        printf '\n'
        echo ""
        echo "Output will be saved to: $CUSTOM_LOG_DIR"
        echo "Fairness metrics will be computed for: $FAIRNESS_ATTRIBUTES"

        # Execute the command
        if "${CMD[@]}"; then
            echo "Seed $SEED for task $TASK testing with fairness evaluation completed successfully!"
            echo "Results saved to: $CUSTOM_LOG_DIR"

            # Find the most recently modified results file by comparing
            # modification times of the glob matches (no `ls` parsing).
            RESULTS_PATTERN="$CUSTOM_LOG_DIR/lightning_logs/*/test_set_results.yaml"
            LATEST_RESULTS=""
            for results_candidate in "$CUSTOM_LOG_DIR"/lightning_logs/*/test_set_results.yaml; do
                [ -f "$results_candidate" ] || continue
                if [ -z "$LATEST_RESULTS" ] || [ "$results_candidate" -nt "$LATEST_RESULTS" ]; then
                    LATEST_RESULTS="$results_candidate"
                fi
            done

            if [ -f "$LATEST_RESULTS" ]; then
                echo "Fairness metrics generated successfully!"

                # Display detailed fairness results
                display_fairness_results "$LATEST_RESULTS" "$SEED" "$TASK"

                # The summary must look for the same metric that
                # display_fairness_results reports: "los" results are keyed
                # by accuracy, every other task by PRAUC. Grepping for PRAUC
                # unconditionally left the "los" summary empty.
                if [ "$TASK" = "los" ]; then
                    SUMMARY_METRIC="accuracy"
                    SUMMARY_LABEL="ACC"
                else
                    SUMMARY_METRIC="PRAUC"
                    SUMMARY_LABEL="PRAUC"
                fi

                # Save a summary for this seed
                SUMMARY_FILE="$CUSTOM_LOG_DIR/fairness_summary_seed_${SEED}.txt"
                {
                    echo "Fairness Summary for HealNet - Task: $TASK, Seed: $SEED"
                    echo "Generated on: $(date)"
                    echo "======================================="

                    # Extract key metrics to summary
                    echo ""
                    echo "${SUMMARY_LABEL} Gaps:"
                    grep -E "fairness.*${SUMMARY_METRIC}.*gap:" "$LATEST_RESULTS"

                    echo ""
                    echo "Worst-case Performance:"
                    grep -E "fairness.*${SUMMARY_METRIC}.*worst_case:" "$LATEST_RESULTS"
                } > "$SUMMARY_FILE"

                echo "Fairness summary saved to: $SUMMARY_FILE"

            else
                echo "Warning: Results file not found at expected location"
                echo "Expected pattern: $RESULTS_PATTERN"
            fi
        else
            echo "Error: Seed $SEED for task $TASK testing failed!"
            exit 1
        fi

        echo ""
        echo "-----------------------------------"
    done

    # Task completion summary; the seed count follows SEEDS so the message
    # stays correct when the seed list changes.
    echo ""
    if [ -n "$CROSS_EVAL" ]; then
        echo "All ${#SEEDS[@]} seeds cross evaluation with fairness analysis completed for task $TASK on fold $FOLD!"
        echo "Cross evaluation type: $CROSS_EVAL"
    else
        echo "All ${#SEEDS[@]} seeds testing with fairness analysis completed for task $TASK on fold $FOLD!"
    fi
    echo "Task $TASK results saved to: $ROOT_DIR/$TASK/"
    echo "Fairness evaluation results included in test_set_results.yaml files"
    echo "Individual fairness summaries saved as fairness_summary_seed_*.txt"
    if [ "$SAVE_PREDICTIONS" = "true" ]; then
        echo "Test predictions and labels saved as test_predictions_fold${FOLD}_seed*.npz/csv"
    fi
    echo "=================================="
done
