#!/bin/bash

echo "Starting DrFuse Model Testing..."

# ---------------------------------------------------------------------------
# Experiment parameters. These are part of the training run name used to
# locate checkpoints, so they must match the training values exactly.
# ---------------------------------------------------------------------------
BATCH_SIZE=16
LR=0.0001
PATIENCE=10
EPOCHS=50
FOLD=1                     # Fixed fold 1
GPU=0                      # GPU device number

# Testing configuration
USE_DEMOGRAPHICS=false
# CROSS_EVAL selects the evaluation mode:
#   ""                - normal test (same data configuration as training)
#   "matched_to_full" - test models trained on matched data against full data
#   "full_to_matched" - test models trained on full data against matched data
CROSS_EVAL="matched_to_full"   # Change this as needed

# DrFuse-specific parameters (must match training)
EHR_ENCODER=transformer    # Options: 'lstm', 'transformer'
CXR_ENCODER=resnet50       # Options: 'resnet50', 'vit_b_16'
HIDDEN_SIZE=256            # Hidden dimension
EHR_DROPOUT=0.2            # EHR encoder dropout
EHR_N_HEAD=4               # Number of attention heads for transformer
EHR_N_LAYERS_FEAT=1        # Number of feature extraction layers
EHR_N_LAYERS_SHARED=1      # Number of shared representation layers
EHR_N_LAYERS_DISTINCT=1    # Number of distinct representation layers
# Single definition: a second assignment further up used to be silently
# overridden by this one.
PRETRAINED=true            # Use pretrained CXR encoder
DISENTANGLE_LOSS=jsd       # Options: 'mse', 'jsd', 'adc', 'triplet'
ATTN_FUSION=mid            # Options: 'early', 'mid', 'late'
LOGIT_AVERAGE=true         # Use logit averaging for shared features

# Loss weighting parameters (must match training)
LAMBDA_PRED_SHARED=1.9227323284552051          # Shared prediction loss weight
LAMBDA_PRED_EHR=1.145183509066745              # EHR prediction loss weight
LAMBDA_PRED_CXR=1.0464601774513893             # CXR prediction loss weight
LAMBDA_DISENTANGLE_SHARED=1.854051142929651    # Shared disentanglement loss weight
LAMBDA_DISENTANGLE_EHR=1.4572712717542777      # EHR disentanglement loss weight
LAMBDA_DISENTANGLE_CXR=0.6598161299236125      # CXR disentanglement loss weight
LAMBDA_ATTN_AUX=1.6906223588695217             # Attention auxiliary loss weight

# Seeds configuration (must match training seeds)
SEEDS=(42 123 1234)

# Tasks to test. Recognized values: "phenotype", "mortality", "los".
TASKS=("phenotype")

#######################################
# Locate a checkpoint (*.ckpt) file inside a checkpoint directory.
# Arguments: $1 - directory to search (recursively)
# Outputs:   path of the byte-wise (LC_ALL=C) first *.ckpt file on stdout,
#            or an empty line if the directory is missing or holds none
#######################################
find_checkpoint() {
    local checkpoint_dir="$1"
    local checkpoint_file=""

    # A missing directory is an expected case (run not trained); skip find so
    # it does not print "No such file or directory" — the caller reports it.
    if [ -d "$checkpoint_dir" ]; then
        # find's traversal order is filesystem-dependent; sorting makes the
        # choice reproducible when several checkpoints exist.
        # NOTE(review): if a run can hold several checkpoints (e.g. top-k plus
        # last.ckpt), confirm the lexicographically first one is the intended one.
        checkpoint_file=$(find "$checkpoint_dir" -name "*.ckpt" -type f | LC_ALL=C sort | head -n 1)
    fi
    echo "$checkpoint_file"
}

#######################################
# Map an evaluation mode to the data configuration the model was trained on
# ("matched" or "full"), used to locate the training checkpoint.
# Globals:   MATCHED (read; consulted only for normal testing, treated as
#            "false" when unset — this script does not set it itself)
# Arguments: $1 - evaluation mode: "", "matched_to_full" or "full_to_matched"
# Outputs:   "matched" or "full" on stdout; a warning on stderr when $1 is a
#            non-empty value that is not a recognized mode
#######################################
get_training_data_config() {
    local cross_eval="${1:-}"
    case "$cross_eval" in
        matched_to_full)
            echo "matched"
            ;;
        full_to_matched)
            echo "full"
            ;;
        *)
            # A typo in CROSS_EVAL would otherwise silently fall through to
            # the normal-test branch below.
            if [ -n "$cross_eval" ]; then
                echo "Warning: unrecognized cross-evaluation mode '$cross_eval'; using MATCHED to pick the training data config" >&2
            fi
            # Normal test: training data config follows MATCHED (presumably
            # supplied via the environment).
            if [ "${MATCHED:-false}" = "true" ]; then
                echo "matched"
            else
                echo "full"
            fi
            ;;
    esac
}

# Root directory for this run's results: ../experiments/drfuse-<CROSS_EVAL>.
# NOTE(review): for a normal test (CROSS_EVAL empty) this expands to
# "../experiments/drfuse-" with a trailing dash — confirm that is intended.
ROOT_DIR="../experiments/drfuse-$CROSS_EVAL"

printf '%s\n' "📁 All results will be saved under: $ROOT_DIR"

# Test every (task, seed) combination.
# For each seed: locate the checkpoint of the matching training run, then run
# ../main.py in test mode with the hyper-parameters configured above.
# - a seed whose checkpoint is missing is skipped (reported on stderr);
# - a failing test run, or an unknown task name, aborts with exit status 1.

# Depends only on CROSS_EVAL (and MATCHED), so compute it once.
TRAINING_DATA_CONFIG=$(get_training_data_config "$CROSS_EVAL")

# Training run names embed the pretrained flag as "True"/"False".
# NOTE(review): the "False" spelling is assumed — confirm against training.
if [ "$PRETRAINED" = "true" ]; then
    PRETRAINED_TAG="True"
else
    PRETRAINED_TAG="False"
fi

for TASK in "${TASKS[@]}"
do
    if [ -n "$CROSS_EVAL" ]; then
        echo "Testing DrFuse model for task: $TASK with cross evaluation: $CROSS_EVAL"
    else
        echo "Testing DrFuse model for task: $TASK (normal test)"
    fi

    # Task-specific model dimensions (must match training).
    case "$TASK" in
        phenotype) NUM_CLASSES=25; INPUT_DIM=498 ;;
        mortality) NUM_CLASSES=1;  INPUT_DIM=49 ;;
        los)       NUM_CLASSES=7;  INPUT_DIM=49 ;;
        *)
            # Continuing would reuse empty or stale NUM_CLASSES/INPUT_DIM.
            echo "❌ Error: Unknown task '$TASK' (expected phenotype, mortality or los)" >&2
            exit 1
            ;;
    esac

    COMPLETED_SEEDS=0
    for SEED in "${SEEDS[@]}"
    do
        echo "Testing with seed $SEED for task $TASK on fold $FOLD..."

        # Checkpoint directory of the matching training run.
        LIGHTNING_LOGS_DIR="../experiments-m-m/drfuse/$TASK/lightning_logs"
        RUN_NAME="DRFUSE-model_drfuse-task_${TASK}-fold_${FOLD}-batch_size_${BATCH_SIZE}-lr_${LR}-patience_${PATIENCE}-epochs_${EPOCHS}-seed_${SEED}-pretrained_${PRETRAINED_TAG}-data_config_${TRAINING_DATA_CONFIG}"
        CHECKPOINT_DIR="$LIGHTNING_LOGS_DIR/$RUN_NAME/checkpoints"

        CHECKPOINT_FILE=$(find_checkpoint "$CHECKPOINT_DIR")

        if [ ! -f "$CHECKPOINT_FILE" ]; then
            {
                echo "❌ Error: Checkpoint not found in $CHECKPOINT_DIR"
                echo "Available directories:"
                # List the directory that was actually searched.
                ls -la "$LIGHTNING_LOGS_DIR/" | grep "seed_${SEED}" || echo "No directories found for seed $SEED"
            } >&2
            continue
        fi

        echo "✅ Found checkpoint: $CHECKPOINT_FILE"

        # Per-task output directory: $ROOT_DIR/<task>
        CUSTOM_LOG_DIR="$ROOT_DIR/$TASK"

        # Build the command as an array (no eval) so paths containing spaces
        # or shell metacharacters reach main.py intact.
        CMD=(python ../main.py
            --model drfuse
            --mode test
            --task "$TASK"
            --fold "$FOLD"
            --batch_size "$BATCH_SIZE"
            --lr "$LR"
            --patience "$PATIENCE"
            --epochs "$EPOCHS"
            --seed "$SEED"
            --hidden_size "$HIDDEN_SIZE"
            --input_dim "$INPUT_DIM"
            --num_classes "$NUM_CLASSES"
            --ehr_encoder "$EHR_ENCODER"
            --cxr_encoder "$CXR_ENCODER"
            --ehr_dropout "$EHR_DROPOUT"
            --ehr_n_head "$EHR_N_HEAD"
            --ehr_n_layers_feat "$EHR_N_LAYERS_FEAT"
            --ehr_n_layers_shared "$EHR_N_LAYERS_SHARED"
            --ehr_n_layers_distinct "$EHR_N_LAYERS_DISTINCT"
            --disentangle_loss "$DISENTANGLE_LOSS"
            --attn_fusion "$ATTN_FUSION"
            --lambda_pred_shared "$LAMBDA_PRED_SHARED"
            --lambda_pred_ehr "$LAMBDA_PRED_EHR"
            --lambda_pred_cxr "$LAMBDA_PRED_CXR"
            --lambda_disentangle_shared "$LAMBDA_DISENTANGLE_SHARED"
            --lambda_disentangle_ehr "$LAMBDA_DISENTANGLE_EHR"
            --lambda_disentangle_cxr "$LAMBDA_DISENTANGLE_CXR"
            --lambda_attn_aux "$LAMBDA_ATTN_AUX"
            --gpu "$GPU"
            --checkpoint_path "$CHECKPOINT_FILE"
            --log_dir "$CUSTOM_LOG_DIR")

        # Optional boolean flags.
        if [ "$PRETRAINED" = "true" ]; then
            CMD+=(--pretrained)
        fi

        if [ "$LOGIT_AVERAGE" = "true" ]; then
            CMD+=(--logit_average)
        fi

        if [ "$USE_DEMOGRAPHICS" = "true" ]; then
            CMD+=(--use_demographics)
        fi

        # Cross evaluation selects its own data; otherwise a model trained on
        # matched data is tested on matched data.
        if [ -n "$CROSS_EVAL" ]; then
            CMD+=(--cross_eval "$CROSS_EVAL")
        elif [ "$TRAINING_DATA_CONFIG" = "matched" ]; then
            CMD+=(--matched)
        fi

        # %q prints a shell-quoted, copy-pasteable command line.
        echo "Running command: $(printf '%q ' "${CMD[@]}")"
        echo "Output will be saved to: $CUSTOM_LOG_DIR"

        if "${CMD[@]}"; then
            echo "Seed $SEED for task $TASK testing completed successfully!"
            echo "Results saved to: $CUSTOM_LOG_DIR"
            COMPLETED_SEEDS=$((COMPLETED_SEEDS + 1))
        else
            echo "Error: Seed $SEED for task $TASK testing failed!" >&2
            exit 1
        fi

        echo "-----------------------------------"
    done

    # Report the real count: seeds with a missing checkpoint were skipped.
    if [ -n "$CROSS_EVAL" ]; then
        echo "${COMPLETED_SEEDS}/${#SEEDS[@]} seeds cross evaluation completed for task $TASK on fold $FOLD!"
        echo "Cross evaluation type: $CROSS_EVAL"
    else
        echo "${COMPLETED_SEEDS}/${#SEEDS[@]} seeds testing completed for task $TASK on fold $FOLD!"
    fi
    echo "Task $TASK results saved to: $ROOT_DIR/$TASK/"
    echo "=================================="
done

# Final summary: where results were written and the expected layout.
printf '%s\n' \
    "All testing tasks completed successfully!" \
    "All results can be found in: $ROOT_DIR/" \
    "Directory structure:" \
    "   $ROOT_DIR/"

# NOTE(review): the lightning_logs/ level assumes ../main.py creates it under
# --log_dir; this script does not create it itself — confirm.
for TASK in "${TASKS[@]}"; do
    printf '   ├── %s/\n   │   └── lightning_logs/\n' "$TASK"
done

if [ -z "$CROSS_EVAL" ]; then
    printf '%s\n' "Normal testing results saved"
else
    printf '%s\n' "Cross evaluation results saved with cross_${CROSS_EVAL} data configuration"
fi
