#!/bin/bash

# Collect per-metric statistics (mean, std, min, max, count) across the seed runs
# found in one experiment directory.
# Usage: ./scripts2/collect_exp_statistics.sh <experiment_path> [output_file]
# Examples:
#   ./scripts2/collect_exp_statistics.sh experiments_original_setting/drfuse/phenotype/lightning_logs
#   ./scripts2/collect_exp_statistics.sh workspace/benchmark-multimodal-clinical-learners/experiments/latefusion_uniloss/mortality/lightning_logs
#   ./scripts2/collect_exp_statistics.sh experiments_no_weight/shaspec/mortality/lightning_logs/lightning_logs

# Color codes for output.
# The values keep a literal backslash ("\033"), so they only turn into real
# escape sequences when printed through something that expands escapes
# (echo -e or printf %b). They are constants, hence readonly.
readonly RED='\033[0;31m'
readonly GREEN='\033[0;32m'
readonly YELLOW='\033[1;33m'
readonly BLUE='\033[0;34m'
readonly NC='\033[0m' # No Color (reset)

# Print an informational message to stdout, prefixed with a blue "[INFO]" tag.
# Globals:   BLUE, NC (read)
# Arguments: $1 - message text, printed verbatim
print_info() {
    # Only the color variables go through %b (escape expansion); the message
    # uses %s so backslashes in paths or user input are not interpreted.
    printf '%b[INFO]%b %s\n' "$BLUE" "$NC" "$1"
}

# Print a success message to stdout, prefixed with a green "[SUCCESS]" tag.
# Globals:   GREEN, NC (read)
# Arguments: $1 - message text, printed verbatim
print_success() {
    # Only the color variables go through %b (escape expansion); the message
    # uses %s so backslashes in it are not interpreted.
    printf '%b[SUCCESS]%b %s\n' "$GREEN" "$NC" "$1"
}

# Print a warning to stderr, prefixed with a yellow "[WARNING]" tag.
# Globals:   YELLOW, NC (read)
# Arguments: $1 - message text, printed verbatim
print_warning() {
    # Diagnostics go to stderr so they never mix with captured stdout.
    # Only the color variables go through %b; the message uses %s so
    # backslashes in it are not interpreted.
    printf '%b[WARNING]%b %s\n' "$YELLOW" "$NC" "$1" >&2
}

# Print an error message to stderr, prefixed with a red "[ERROR]" tag.
# Globals:   RED, NC (read)
# Arguments: $1 - message text, printed verbatim
print_error() {
    # Errors go to stderr so they never mix with captured stdout.
    # Only the color variables go through %b; the message uses %s so
    # backslashes in it (e.g. in a path) are not interpreted.
    printf '%b[ERROR]%b %s\n' "$RED" "$NC" "$1" >&2
}

# Print the command-line help (synopsis, arguments, examples) to stdout.
# The program name shown in the text is taken from $0.
show_usage() {
    local self=$0
    # One argument per output line; printf repeats the format for each of them.
    printf '%s\n' \
        "Usage: ${self} [EXPERIMENT_PATH] [OUTPUT_FILE]" \
        '' \
        'Collect statistics from a specific experiment directory containing seed-based results.' \
        '' \
        'Arguments:' \
        '    EXPERIMENT_PATH    Path to experiment directory (e.g., experiments/drfuse/mortality/lightning_logs/exp1)' \
        '    OUTPUT_FILE        Output file for results (optional, will be auto-generated if not provided)' \
        '' \
        'Examples:' \
        '    # Basic usage' \
        "    ${self} experiments/drfuse/mortality/lightning_logs/exp1" \
        '    ' \
        '    # With custom output file' \
        "    ${self} experiments/drfuse/mortality/lightning_logs/exp1 my_results.yaml" \
        ''
}

# Argument handling: show the help text when called with no arguments or with
# -h/--help, otherwise take the experiment directory and the output file.
if (( $# == 0 )) || [[ "$1" == "-h" || "$1" == "--help" ]]; then
    show_usage
    exit 0
fi

EXPERIMENT_PATH="$1"
# A missing or empty second argument falls back to a file inside the
# experiment directory.
OUTPUT_FILE="${2:-$EXPERIMENT_PATH/seed_statistics.yaml}"

# The experiment path must be an existing directory
[[ -d "$EXPERIMENT_PATH" ]] || {
    print_error "Experiment directory '$EXPERIMENT_PATH' does not exist"
    exit 1
}

print_info "Experiment Path: $EXPERIMENT_PATH"
print_info "Output File: $OUTPUT_FILE"

# Find all seed directories.
# Only direct children of EXPERIMENT_PATH are considered, because that is the
# only place the embedded Python collector looks for them; a recursive search
# would report directories that are never processed. Reading find's output
# line by line (instead of word-splitting a command substitution) keeps
# directory names containing spaces or glob characters intact.
SEED_DIRS=()
while IFS= read -r seed_dir; do
    SEED_DIRS+=("$seed_dir")
done < <(find "$EXPERIMENT_PATH" -mindepth 1 -maxdepth 1 -type d -name '*seed_*' | sort)

if [[ ${#SEED_DIRS[@]} -eq 0 ]]; then
    print_error "No seed directories found in $EXPERIMENT_PATH"
    exit 1
fi

print_info "Found ${#SEED_DIRS[@]} seed directories"

# Create a temporary Python script to collect statistics
TEMP_SCRIPT=$(mktemp) || {
    print_error "Failed to create a temporary file"
    exit 1
}
# Remove the temporary script whenever the shell exits, so it is not left
# behind if the script stops before reaching its explicit cleanup.
trap 'rm -f -- "$TEMP_SCRIPT"' EXIT
cat > "$TEMP_SCRIPT" << 'EOF'
import os
import yaml
import pandas as pd
import numpy as np
import sys
from pathlib import Path

def load_results_from_yaml(result_file):
    """Parse a YAML results file and return its content.

    Args:
        result_file: path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file, or None when the
        file cannot be opened or parsed (the error is printed, not raised).
    """
    parsed = None
    try:
        with open(result_file, 'r') as stream:
            parsed = yaml.safe_load(stream)
    except Exception as e:
        # Best effort: report the problem and let the caller skip this file.
        print(f"Error loading {result_file}: {e}")
    return parsed

def aggregate_metrics(all_results):
    """Compute per-metric summary statistics across seeds.

    Args:
        all_results: mapping of seed -> {metric name: value}.

    Returns:
        None when ``all_results`` is empty. Otherwise a dict with the keys
        'mean', 'std', 'min', 'max' and 'count', each a pandas Series indexed
        by metric name. 'std' is pandas' default sample standard deviation,
        so it is NaN when only one seed contributes a value.
    """
    if not all_results:
        return None

    # Outer keys (seeds) become columns; transpose so that each row is a seed
    # and each column a metric.
    df = pd.DataFrame(all_results).T

    # Result files may carry non-numeric entries (e.g. a free-text note).
    # Aggregating those raises in recent pandas versions, so coerce every
    # column to numbers and drop the columns that hold no number at all.
    numeric = df.apply(pd.to_numeric, errors='coerce')
    non_numeric = [
        column for column in df.columns
        # Had values originally, but none of them is numeric.
        if df[column].notna().any() and not numeric[column].notna().any()
    ]
    if non_numeric:
        print(f"Warning: ignoring non-numeric entries: {non_numeric}")
        numeric = numeric.drop(columns=non_numeric)

    return {
        'mean': numeric.mean(),
        'std': numeric.std(),
        'min': numeric.min(),
        'max': numeric.max(),
        'count': numeric.count(),
    }

def format_results_table(stats):
    """Render aggregated statistics as a grouped text table.

    Args:
        stats: dict of pandas Series keyed by 'mean', 'std' and 'count'
            (indexed by metric name), or None.

    Returns:
        A multi-line string with one "name  mean±std (n=count)" row per
        metric, or "No valid results found" when ``stats`` is None.
    """
    if stats is None:
        return "No valid results found"

    # Standard groups are created up front so they keep this display order;
    # any other prefix gets its own group, appended in first-seen order.
    grouped = {
        name: []
        for name in ('Overall', 'ACC', 'AUROC', 'F1', 'PRAUC',
                     'Precision', 'Recall', 'Specificity')
    }
    for metric in stats['mean'].index:
        if '/' in metric:
            prefix = metric.split('/')[0]
            # Lower-case "overall/..." metrics are listed under "Overall".
            bucket = 'Overall' if prefix == 'overall' else prefix
        else:
            bucket = 'Other'
        grouped.setdefault(bucket, []).append(metric)

    banner = "=" * 80
    lines = [banner, "EXPERIMENT RESULTS SUMMARY (SEED-BASED)", banner]

    for group, members in grouped.items():
        if not members:
            continue

        lines.append(f"\n{group} METRICS:")
        lines.append("-" * 50)

        for metric in sorted(members):
            # Show only the part after the last '/'.
            label = metric.rsplit('/', 1)[-1]
            cell = f"{stats['mean'][metric]:.4f}±{stats['std'][metric]:.4f}"
            lines.append(f"{label:<45} {cell} (n={stats['count'][metric]})")

    return "\n".join(lines)

def save_detailed_results(stats, output_file, seeds_info):
    """Write the per-metric statistics to a YAML file.

    Args:
        stats: dict of pandas Series keyed by 'mean', 'std', 'min', 'max'
            and 'count' (indexed by metric name), or None.
        output_file: destination path; missing parent directories are created.
        seeds_info: list of seeds, stored under the 'seeds_used' key.

    Returns:
        None. Nothing is written when ``stats`` is None.
    """
    if stats is None:
        return

    # Ensure the output directory exists (a bare file name has no directory part)
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    detailed_results = {
        'experiment_type': 'seed_based',
        'seeds_used': seeds_info,
        'metrics': {}
    }

    for metric in stats['mean'].index:
        mean_val = stats['mean'][metric]
        std_val = stats['std'][metric]

        # Convert to plain Python numbers so the YAML holds simple scalars.
        detailed_results['metrics'][metric] = {
            'mean': float(mean_val),
            'std': float(std_val),
            'formatted': f"{mean_val:.4f}±{std_val:.4f}",
            'min': float(stats['min'][metric]),
            'max': float(stats['max'][metric]),
            'count': int(stats['count'][metric])
        }

    # allow_unicode keeps the "±" in the 'formatted' strings readable instead
    # of being escaped; the file is therefore written explicitly as UTF-8 so
    # the result does not depend on the locale's default encoding.
    with open(output_file, 'w', encoding='utf-8') as f:
        yaml.dump(detailed_results, f, default_flow_style=False,
                  sort_keys=True, allow_unicode=True)

    print(f"Detailed results saved to: {output_file}")

def save_formatted_summary(stats, output_file):
    """Write a compact "name: mean±std" text summary next to the results file.

    The summary path is derived from ``output_file``: a trailing ``.yaml`` or
    ``.yml`` extension is replaced by ``_formatted.txt``; for any other name
    the suffix is appended. The summary therefore never lands on
    ``output_file`` itself.

    Args:
        stats: dict of pandas Series keyed by 'mean' and 'std' (indexed by
            metric name), or None.
        output_file: path of the main results file.

    Returns:
        None. Nothing is written when ``stats`` is None.
    """
    if stats is None:
        return

    root, ext = os.path.splitext(output_file)
    stem = root if ext.lower() in ('.yaml', '.yml') else output_file
    summary_file = stem + '_formatted.txt'

    # Standard groups are created up front so they keep this display order;
    # any other prefix gets its own group, appended in first-seen order.
    metric_groups = {
        name: []
        for name in ('Overall', 'ACC', 'AUROC', 'F1', 'PRAUC',
                     'Precision', 'Recall', 'Specificity')
    }
    for metric in stats['mean'].index:
        if '/' in metric:
            prefix = metric.split('/')[0]
            # Lower-case "overall/..." metrics are listed under "Overall".
            group_name = 'Overall' if prefix == 'overall' else prefix
        else:
            group_name = 'Other'
        metric_groups.setdefault(group_name, []).append(metric)

    # Explicit UTF-8: the text contains "±", which not every locale's default
    # encoding can represent.
    with open(summary_file, 'w', encoding='utf-8') as f:
        f.write("EXPERIMENT RESULTS SUMMARY (mean±std format)\n")
        f.write("=" * 60 + "\n\n")

        for group, metrics in metric_groups.items():
            if not metrics:
                continue

            f.write(f"{group} METRICS:\n")
            f.write("-" * 30 + "\n")

            for metric in sorted(metrics):
                mean_val = stats['mean'][metric]
                std_val = stats['std'][metric]
                # Show only the part after the last '/'.
                display_name = metric.rsplit('/', 1)[-1]
                f.write(f"{display_name}: {mean_val:.4f}±{std_val:.4f}\n")

            f.write("\n")

    print(f"Formatted summary saved to: {summary_file}")

# Main execution
# Expects two command-line arguments: the experiment directory and the path
# of the YAML file to write. Exits with status 1 when no seed directory or no
# usable result file is found.
if __name__ == "__main__":
    import re

    experiment_path = sys.argv[1]
    output_file = sys.argv[2]

    # Seed directories are the direct children whose name contains "seed_"
    seed_dirs = sorted(
        os.path.join(experiment_path, item)
        for item in os.listdir(experiment_path)
        if "seed_" in item and os.path.isdir(os.path.join(experiment_path, item))
    )

    if not seed_dirs:
        print("No seed directories found!")
        sys.exit(1)

    print(f"Found {len(seed_dirs)} seed directories")

    # Load results from all seeds, keyed by seed number
    seed_pattern = re.compile(r'seed_(\d+)')
    all_results = {}

    for seed_dir in seed_dirs:
        # Extract seed number from directory name
        seed_match = seed_pattern.search(os.path.basename(seed_dir))
        if not seed_match:
            print(f"Warning: Could not extract seed number from {seed_dir}")
            continue
        seed_num = int(seed_match.group(1))

        # Look for test_set_results.yaml
        result_file = os.path.join(seed_dir, "test_set_results.yaml")
        results = None

        if os.path.exists(result_file):
            print(f"Loading seed {seed_num}: {result_file}")
            results = load_results_from_yaml(result_file)
        else:
            print(f"Warning: No test_set_results.yaml found in {seed_dir}")

        # Only a non-empty mapping of metric -> value can be aggregated.
        if not isinstance(results, dict) or not results:
            print(f"Warning: No valid results found for seed {seed_num}")
            continue

        if seed_num in all_results:
            # Two directories carry the same seed number; the later one wins.
            print(f"Warning: Duplicate seed {seed_num} in {seed_dir}; replacing earlier results")
        all_results[seed_num] = results

    if not all_results:
        print("No valid results could be loaded!")
        sys.exit(1)

    print(f"Successfully loaded results from {len(all_results)} seeds")

    # Calculate aggregate statistics
    stats = aggregate_metrics(all_results)

    # Display results
    print("\n" + format_results_table(stats))

    # Save detailed results; report only the seeds that actually contributed
    save_detailed_results(stats, output_file, sorted(all_results))

    # Save formatted summary
    save_formatted_summary(stats, output_file)

    # Also save a summary CSV for easy analysis
    if stats is not None:
        summary_df = pd.DataFrame({
            'metric': stats['mean'].index,
            'mean': stats['mean'].values,
            'std': stats['std'].values,
            'formatted': [f"{mean:.4f}±{std:.4f}" for mean, std in zip(stats['mean'].values, stats['std'].values)],
            'min': stats['min'].values,
            'max': stats['max'].values,
            'count': stats['count'].values
        })

        # Swap a trailing .yaml/.yml for .csv; for any other name append
        # .csv, so the CSV can never overwrite the YAML results file.
        root, ext = os.path.splitext(output_file)
        csv_file = (root if ext.lower() in ('.yaml', '.yml') else output_file) + '.csv'
        summary_df.to_csv(csv_file, index=False)
        print(f"Summary CSV saved to: {csv_file}")

EOF

# Run the generated Python collector and report the outcome
print_info "Starting statistics collection..."
python3 "$TEMP_SCRIPT" "$EXPERIMENT_PATH" "$OUTPUT_FILE"
collector_status=$?

if (( collector_status != 0 )); then
    print_error "Statistics collection failed!"
    rm -f "$TEMP_SCRIPT"
    exit 1
fi

print_success "Statistics collection completed successfully!"

# The temporary script is no longer needed
rm -f "$TEMP_SCRIPT"