 #!/bin/bash

echo "Starting 5-seed training for DrFuse model on fold 1..."

# Define experiment parameters
BATCH_SIZE=16
LR=0.0001
PATIENCE=10
EPOCHS=50
FOLD=1                     # Fixed fold 1
GPU=3                      # GPU device number
MATCHED=false
USE_DEMOGRAPHICS=false
CROSS_EVAL=""              # Set to "matched_to_full" or "full_to_matched" if needed

# DrFuse specific parameters
EHR_ENCODER=transformer    # Options: 'lstm', 'transformer'
CXR_ENCODER=resnet50       # Options: 'resnet50', 'vit_b_16'
HIDDEN_SIZE=256            # Hidden dimension       
EHR_DROPOUT=0.2            # EHR encoder dropout
EHR_N_HEAD=4               # Number of attention heads for transformer
EHR_N_LAYERS_FEAT=1        # Number of feature extraction layers
EHR_N_LAYERS_SHARED=1      # Number of shared representation layers
EHR_N_LAYERS_DISTINCT=1    # Number of distinct representation layers
PRETRAINED=true            # Use pretrained CXR encoder
DISENTANGLE_LOSS=jsd       # Options: 'mse', 'jsd', 'adc', 'triplet'
ATTN_FUSION=mid          # Options: 'early', 'mid', 'late'
LOGIT_AVERAGE=true        # Use logit averaging for shared features

# Loss weighting parameters
LAMBDA_PRED_SHARED=0.05589422583241736     # Shared prediction loss weight
LAMBDA_PRED_EHR=1.2271897893716792        # EHR prediction loss weight
LAMBDA_PRED_CXR=0.02406194738723764         # CXR prediction loss weight 
LAMBDA_DISENTANGLE_SHARED=0.011549744023618514  # Shared disentanglement loss weight
LAMBDA_DISENTANGLE_EHR=1.9845010029895234     # EHR disentanglement loss weight
LAMBDA_DISENTANGLE_CXR=1.238788204159156     # CXR disentanglement loss weight
LAMBDA_ATTN_AUX=1.0543015739141945        # Attention auxiliary loss weight 

# Seeds configuration (including 42)
SEEDS=(42 123 1234)

# Tasks configuration
TASKS=("los")

for TASK in "${TASKS[@]}"
do
    echo "Training DrFuse model for task: $TASK"
    
    # Set task-specific parameters
    if [ "$TASK" = "phenotype" ]; then
        NUM_CLASSES=25
        INPUT_DIM=49
    elif [ "$TASK" = "mortality" ]; then
        NUM_CLASSES=1
        INPUT_DIM=49
    elif [ "$TASK" = "los" ]; then
        NUM_CLASSES=7
        INPUT_DIM=49
    fi
    
    for SEED in "${SEEDS[@]}"
    do
        echo "Training with seed $SEED for task $TASK on fold $FOLD..."
        
        # Build base command
        CMD="python ../main.py \
            --model drfuse \
            --mode train \
            --task $TASK \
            --fold $FOLD \
            --batch_size $BATCH_SIZE \
            --lr $LR \
            --patience $PATIENCE \
            --epochs $EPOCHS \
            --seed $SEED \
            --ehr_encoder $EHR_ENCODER \
            --cxr_encoder $CXR_ENCODER \
            --hidden_size $HIDDEN_SIZE \
            --ehr_dropout $EHR_DROPOUT \
            --ehr_n_head $EHR_N_HEAD \
            --ehr_n_layers_feat $EHR_N_LAYERS_FEAT \
            --ehr_n_layers_shared $EHR_N_LAYERS_SHARED \
            --ehr_n_layers_distinct $EHR_N_LAYERS_DISTINCT \
            --input_dim $INPUT_DIM \
            --num_classes $NUM_CLASSES \
            --disentangle_loss $DISENTANGLE_LOSS \
            --attn_fusion $ATTN_FUSION \
            --lambda_pred_shared $LAMBDA_PRED_SHARED \
            --lambda_pred_ehr $LAMBDA_PRED_EHR \
            --lambda_pred_cxr $LAMBDA_PRED_CXR \
            --lambda_disentangle_shared $LAMBDA_DISENTANGLE_SHARED \
            --lambda_disentangle_ehr $LAMBDA_DISENTANGLE_EHR \
            --lambda_disentangle_cxr $LAMBDA_DISENTANGLE_CXR \
            --lambda_attn_aux $LAMBDA_ATTN_AUX \
            --gpu $GPU"
        
        # Add DrFuse specific parameters
        if [ "$PRETRAINED" = "true" ]; then
            CMD="$CMD --pretrained"
        fi
        
        if [ "$LOGIT_AVERAGE" = "true" ]; then
            CMD="$CMD --logit_average"
        fi
        
        # Add conditional parameters
        if [ "$MATCHED" = "true" ]; then
            CMD="$CMD --matched"
        fi
        
        if [ "$USE_DEMOGRAPHICS" = "true" ]; then
            CMD="$CMD --use_demographics"
        fi
        
        if [ -n "$CROSS_EVAL" ]; then
            CMD="$CMD --cross_eval $CROSS_EVAL"
        fi
        
        echo "Running command: $CMD"
        eval $CMD
        
        if [ $? -eq 0 ]; then
            echo "Seed $SEED for task $TASK completed successfully!"
        else
            echo "Error: Seed $SEED for task $TASK failed!"
            exit 1
        fi
    done
    
    echo "All 5 seeds completed for task $TASK on fold $FOLD!"
    echo "Collecting and aggregating statistics..."
    
    # Collect statistics for multi-seed experiment
    echo "Attempting statistics collection for multi-seed experiment..."
    python ../collect_seed_statistics.py \
        --experiment_dir ../experiments \
        --model drfuse \
        --task $TASK \
        --fold $FOLD \
        --seeds ${SEEDS[@]} \
        --batch_size $BATCH_SIZE \
        --lr $LR \
        --patience $PATIENCE \
        --epochs $EPOCHS \
        --ehr_encoder $EHR_ENCODER \
        --cxr_encoder $CXR_ENCODER \
        --hidden_size $HIDDEN_SIZE \
        --ehr_dropout $EHR_DROPOUT \
        --ehr_n_head $EHR_N_HEAD \
        --ehr_n_layers_feat $EHR_N_LAYERS_FEAT \
        --ehr_n_layers_shared $EHR_N_LAYERS_SHARED \
        --ehr_n_layers_distinct $EHR_N_LAYERS_DISTINCT \
        --num_classes $NUM_CLASSES \
        --disentangle_loss $DISENTANGLE_LOSS \
        --attn_fusion $ATTN_FUSION \
        --output experiments/drfuse/$TASK/drfuse_${TASK}_fold${FOLD}_5seeds_statistics.yaml
    
    if [ $? -eq 0 ]; then
        echo "Statistics collection completed successfully for task $TASK!"
    else
        echo "Warning: Statistics collection failed for task $TASK. Using alternative method..."
        
        # Fallback: Use the direct experiment directory statistics collection
        if [ -d "experiments/drfuse/$TASK/lightning_logs" ]; then
            echo "Using direct experiment directory statistics collection..."
            
            # Create a simple statistics aggregation for different seeds
            echo "Creating seed-based statistics summary..."
            
            # Calculate basic statistics across seeds
            python -c "
import os
import yaml
import pandas as pd
import numpy as np
from pathlib import Path

# Define paths and parameters
exp_dir = Path('experiments/drfuse/$TASK/lightning_logs')
seeds = [${SEEDS[@]}]
fold = $FOLD

# Collect metrics from each seed
all_metrics = []
for seed in seeds:
    # Look for experiment directories with this seed
    seed_dirs = list(exp_dir.glob(f'*seed_{seed}*'))
    if not seed_dirs:
        # Fallback: look for any directory and check version files
        seed_dirs = list(exp_dir.glob('*'))
    
    for seed_dir in seed_dirs:
        metrics_file = seed_dir / 'metrics.csv'
        if metrics_file.exists():
            try:
                df = pd.read_csv(metrics_file)
                if 'test_epoch' in df.columns:
                    test_metrics = df.dropna(subset=['test_epoch'])
                    if not test_metrics.empty:
                        last_test = test_metrics.iloc[-1]
                        all_metrics.append({
                            'seed': seed,
                            'metrics': last_test.to_dict()
                        })
                        break
            except Exception as e:
                print(f'Error reading {metrics_file}: {e}')

# Calculate statistics
if all_metrics:
    # Extract metric names (excluding non-numeric columns)
    metric_names = set()
    for m in all_metrics:
        for key in m['metrics'].keys():
            if key not in ['epoch', 'test_epoch', 'step'] and isinstance(m['metrics'][key], (int, float)):
                metric_names.add(key)
    
    # Calculate mean and std for each metric
    statistics = {}
    for metric in metric_names:
        values = []
        for m in all_metrics:
            if metric in m['metrics'] and not pd.isna(m['metrics'][metric]):
                values.append(m['metrics'][metric])
        
        if values:
            statistics[metric] = {
                'mean': float(np.mean(values)),
                'std': float(np.std(values)),
                'min': float(np.min(values)),
                'max': float(np.max(values)),
                'count': len(values)
            }
    
    # Save results
    output_file = 'experiments/drfuse/$TASK/drfuse_${TASK}_fold${FOLD}_5seeds_statistics.yaml'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    
    with open(output_file, 'w') as f:
        yaml.dump({
            'experiment': 'drfuse_${TASK}_fold${FOLD}_5seeds',
            'fold': $FOLD,
            'seeds': seeds,
            'statistics': statistics
        }, f, default_flow_style=False)
    
    print(f'Statistics saved to {output_file}')
    print(f'Processed {len(all_metrics)} seed experiments')
else:
    print('No valid metrics found across seeds')
"
        fi
    fi
    
    echo "Task $TASK training and evaluation completed for fold $FOLD with 5 seeds!"
    echo "=================================="
done

echo "All tasks completed successfully!"
echo "Results can be found in experiments/drfuse/[task_name]/"
echo "Each seed experiment creates a separate checkpoint and log directory"
echo "Aggregated statistics are saved with '_5seeds_statistics.yaml' suffix"