#!/usr/bin/env python3
"""
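Collect experiment statistics (mean±std of test-set metrics per model) from an
experiments directory and write them as CSV files into that directory: an
overall table, a LOS-task table, and per-attribute / intersectional group tables.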
"""

import os
import re
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import yaml

def extract_metrics_from_yaml(yaml_file: str) -> Dict[str, float]:
    """Extract overall performance metrics from a test_set_results.yaml file.

    The file is scanned as text for lines such as ``overall/ACC: 0.81``; the
    key may optionally be quoted. When a metric appears several times, the
    last occurrence wins (treated as the final evaluation result).

    Args:
        yaml_file: Path to the results file.

    Returns:
        Mapping of metric name (e.g. ``'F1_macro'``, ``'Kappa'``) to its value.
        Metrics that are absent are omitted. If the file cannot be read, an
        empty dict is returned after printing an error.
    """
    # Signed decimal with optional exponent: "0.81", "-0.05" (Kappa can be
    # negative), "1.0e-05", ".5". The previous `[0-9.]+` dropped negative
    # values, truncated exponents and could match a lone ".".
    number = r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)'
    metric_names = ['F1_macro', 'F1_weighted', 'Precision_macro',
                    'Precision_weighted', 'Recall_macro', 'Recall_weighted',
                    'Specificity_macro', 'Specificity_weighted', 'ACC', 'Kappa']
    metrics: Dict[str, float] = {}

    try:
        with open(yaml_file, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
        print(f"Error reading YAML file {yaml_file}: {e}")
        return metrics

    for metric_name in metric_names:
        pattern = rf"overall/{re.escape(metric_name)}['\"]?:\s*{number}"
        matches = re.findall(pattern, content)
        if matches:
            # Last occurrence = final evaluation result; every string the
            # pattern accepts is a valid float literal.
            metrics[metric_name] = float(matches[-1])

    return metrics

def extract_task_from_path(log_file: str,
                           tasks: Tuple[str, ...] = ('mortality', 'phenotype')) -> str:
    """Return the task name encoded as a directory component of *log_file*.

    Paths are expected to look like ``experiments/<model>/<task>/...``; the
    first path component that exactly equals one of *tasks* is returned.

    Args:
        log_file: Path to a results/log file.
        tasks: Directory names recognised as task names. Defaults to the
            original ``('mortality', 'phenotype')``.

    Returns:
        The matching task name, or ``"unknown"`` when no component matches
        or *log_file* is not a usable path.
    """
    try:
        # Path.parts splits on the platform's separator(s) (both "/" and "\"
        # on Windows), unlike a plain str.split(os.sep).
        parts = Path(log_file).parts
    except TypeError:
        return "unknown"
    return next((part for part in parts if part in tasks), "unknown")

def find_experiment_yaml_files(experiment_dir: str) -> List[Tuple[str, str]]:
    """Recursively locate every test_set_results.yaml below *experiment_dir*.

    Returns:
        A list of ``(yaml_path, task)`` tuples in ``os.walk`` order, where
        *task* is derived from the path by ``extract_task_from_path``. If
        walking fails, the error is printed and whatever was collected so far
        is returned.
    """
    target_name = "test_set_results.yaml"
    found: List[Tuple[str, str]] = []

    try:
        for dirpath, _subdirs, filenames in os.walk(experiment_dir):
            for name in filenames:
                if name != target_name:
                    continue
                full_path = os.path.join(dirpath, name)
                task_name = extract_task_from_path(full_path)
                found.append((full_path, task_name))
                print(f"    Found YAML: {full_path} -> task: {task_name}")
    except Exception as exc:
        print(f"Error walking directory {experiment_dir}: {exc}")

    return found

def collect_model_statistics(experiment_dir: str) -> Dict[str, Dict]:
    """Collect per-(model, task) metric statistics from *experiment_dir*.

    Every immediate subdirectory of *experiment_dir* is treated as a model.
    All test_set_results.yaml files below it are grouped by task, and the
    mean and standard deviation of each metric are computed over the runs
    that reported it.

    Returns:
        Dict keyed by ``"<model>_<task>"``; each value holds ``'model'``,
        ``'task'``, ``'total_experiments'`` (runs with at least one parsed
        metric) and ``'metrics'``: metric name -> ``{'mean', 'std', 'count'}``.
        ``mean``/``std`` are ``None`` and ``count`` is 0 when no run reported
        that metric. Returns an empty dict if the directory cannot be listed.
    """
    metric_names = ['F1_macro', 'F1_weighted', 'Precision_macro',
                    'Precision_weighted', 'Recall_macro', 'Recall_weighted',
                    'Specificity_macro', 'Specificity_weighted', 'ACC', 'Kappa']
    all_results: Dict[str, Dict] = {}

    # os.listdir order is arbitrary; sort so processing order and the
    # insertion order of the result dict are reproducible across runs.
    try:
        model_dirs = sorted(
            d for d in os.listdir(experiment_dir)
            if os.path.isdir(os.path.join(experiment_dir, d))
        )
    except OSError as e:
        print(f"Error reading experiments directory: {e}")
        return all_results

    print(f"Found model directories: {model_dirs}")

    for model_dir in model_dirs:
        model_path = os.path.join(experiment_dir, model_dir)
        print(f"\nProcessing model: {model_dir}")
        print(f"  Model path: {model_path}")

        yaml_files_with_tasks = find_experiment_yaml_files(model_path)
        print(f"  Found {len(yaml_files_with_tasks)} YAML files")

        if not yaml_files_with_tasks:
            print(f"  No YAML files found in {model_path}")
            continue

        # Group YAML files by task (dicts keep first-seen task order).
        task_yamls: Dict[str, List[str]] = {}
        for yaml_file, task in yaml_files_with_tasks:
            task_yamls.setdefault(task, []).append(yaml_file)

        for task, yaml_files in task_yamls.items():
            print(f"    Task {task}: {len(yaml_files)} experiments")

            # Keep only runs from which at least one metric could be parsed.
            all_metrics = [m for m in map(extract_metrics_from_yaml, yaml_files) if m]
            if not all_metrics:
                print("      No valid metrics found")
                continue

            metric_stats = {}
            for metric_name in metric_names:
                values = [m[metric_name] for m in all_metrics
                          if m.get(metric_name) is not None]
                if values:
                    metric_stats[metric_name] = {
                        'mean': np.mean(values),
                        # Population std (np.std default, ddof=0).
                        'std': np.std(values),
                        'count': len(values),
                    }
                else:
                    metric_stats[metric_name] = {'mean': None, 'std': None, 'count': 0}

            all_results[f"{model_dir}_{task}"] = {
                'model': model_dir,
                'task': task,
                'metrics': metric_stats,
                'total_experiments': len(all_metrics),
            }
            print(f"      Processed {len(all_metrics)} experiments")

    return all_results

def format_metric_output(mean: Optional[float], std: Optional[float]) -> str:
    """Format a metric as ``"mean±std"`` with four decimals each.

    Args:
        mean: Metric mean, or ``None`` when no run reported the metric.
        std: Metric standard deviation, or ``None`` likewise.

    Returns:
        e.g. ``"0.8123±0.0100"``, or ``"N/A"`` if either value is ``None``.
    """
    if mean is None or std is None:
        return "N/A"
    return f"{mean:.4f}±{std:.4f}"

def create_results_table(results: Dict[str, Dict]) -> pd.DataFrame:
    """Build the display table with one row per (model, task) result.

    Columns are ``Model``, one ``mean±std`` string per metric (``"N/A"`` when
    missing) and ``Experiments``. There is no Task column, so a model with
    results for several tasks appears in several rows.

    Rows follow the fixed ``model_order`` list below. ``<model>-fairness``
    variants share the base model's rank, and unlisted models go last. The
    sort is stable, so ties keep the iteration order of *results*.

    Args:
        results: Output of ``collect_model_statistics``.

    Returns:
        The sorted DataFrame. It is empty, but with the same columns, when
        *results* is empty.
    """
    metric_names = ['F1_macro', 'F1_weighted', 'Precision_macro',
                    'Precision_weighted', 'Recall_macro', 'Recall_weighted',
                    'Specificity_macro', 'Specificity_weighted', 'ACC', 'Kappa']
    columns = ['Model'] + metric_names + ['Experiments']

    table_data = []
    for result_data in results.values():
        metrics = result_data['metrics']
        row = {'Model': result_data['model']}
        for metric_name in metric_names:
            stats = metrics.get(metric_name)
            if stats is None:
                row[metric_name] = "N/A"
            else:
                row[metric_name] = format_metric_output(stats['mean'], stats['std'])
        row['Experiments'] = result_data['total_experiments']
        table_data.append(row)

    # Explicit columns keep the schema (and the 'Model' column used below)
    # even when there are no results.
    df = pd.DataFrame(table_data, columns=columns)

    model_order = ['resnet', 'lstm', 'transformer', 'latefusion', 'medfuse', 'mmtm',
                   'daft', 'utde', 'shaspec', 'flexmoe', 'drfuse', 'smil', 'healnet',
                   'm3care', 'umse']

    # Suffixed names like "<model>-fairness" share the base model's rank.
    model_order_map = {}
    for rank, model in enumerate(model_order):
        model_order_map[model] = rank
        model_order_map[f"{model}-fairness"] = rank

    # mergesort is stable, so tied rows keep a deterministic order.
    df['model_order'] = df['Model'].map(model_order_map).fillna(999)
    df = df.sort_values('model_order', kind='mergesort').drop('model_order', axis=1)

    return df

def save_task_specific_csvs(df: pd.DataFrame, output_dir: str):
    """Sort *df* by the canonical model order, save it as the LOS CSV and print it.

    Despite the name, a single file ``experiment_statistics_los.csv`` is
    written to *output_dir*; the table has no Task column to split on. The
    caller's *df* is not modified.

    Args:
        df: Results table with a ``Model`` column (see ``create_results_table``).
        output_dir: Directory to write the CSV into.
    """
    model_order = ['resnet', 'lstm', 'transformer', 'latefusion', 'medfuse', 'mmtm',
                   'daft', 'utde', 'shaspec', 'flexmoe', 'drfuse', 'smil', 'healnet',
                   'm3care', 'umse']

    # Suffixed names like "<model>-fairness" share the base model's rank.
    model_order_map = {}
    for rank, model in enumerate(model_order):
        model_order_map[model] = rank
        model_order_map[f"{model}-fairness"] = rank

    # assign() returns a new frame, so the caller's DataFrame does not gain a
    # helper column. mergesort is stable, so ties keep their incoming order.
    sorted_df = (
        df.assign(model_order=df['Model'].map(model_order_map).fillna(999))
        .sort_values('model_order', kind='mergesort')
        .drop(columns='model_order')
    )

    output_file = os.path.join(output_dir, "experiment_statistics_los.csv")
    sorted_df.to_csv(output_file, index=False)
    print(f"Saved LOS statistics to: {output_file}")

    print(f"\nLOS TASK RESULTS:")
    print("-" * 60)
    print(sorted_df.to_string(index=False))

def build_attribute_group_tables(summary, attributes=None, models=None):
    """Build one per-group accuracy table for each sensitive attribute.

    *summary* maps a model name to a dict whose ``group::<attr>::<group>``
    keys hold that group's value. For each attribute, a DataFrame is built
    with a ``Model`` column, one ``<group> ACC`` column per group seen for
    that attribute in any model (sorted), and one row per model. Missing
    values are ``"N/A"``. An attribute with no matching keys still yields a
    Model-only table.

    Args:
        summary: Mapping model name -> {key: value}.
        attributes: Attribute names to tabulate. Defaults to a module-level
            ``ATTRIBUTES`` if one is defined, otherwise none.
        models: Row order. Defaults to a module-level ``MODELS_ORDERED`` if
            defined, otherwise the keys of *summary* in their existing order.

    Returns:
        Dict attribute -> DataFrame.
    """
    # NOTE(review): ATTRIBUTES / MODELS_ORDERED are not defined in this file;
    # referencing the bare names raised NameError. Look them up defensively.
    if attributes is None:
        attributes = globals().get('ATTRIBUTES', ())
    if models is None:
        models = globals().get('MODELS_ORDERED', list(summary))

    tables = {}
    for attr in attributes:
        prefix = f'group::{attr}::'
        groups = sorted({
            key[len(prefix):]
            for model_values in summary.values()
            for key in model_values
            if key.startswith(prefix)
        })
        cols = ['Model'] + [f'{g} ACC' for g in groups]
        rows = []
        for model in models:
            model_values = summary.get(model, {})
            row = {'Model': model}
            for g in groups:
                row[f'{g} ACC'] = model_values.get(f'{prefix}{g}', "N/A")
            rows.append(row)
        tables[attr] = pd.DataFrame(rows, columns=cols)
    return tables

def build_intersectional_group_tables(summary, intersectionals=None, models=None):
    """Build one per-group accuracy table for each intersectional attribute set.

    *summary* maps a model name to a dict whose ``inter::<inter>::<group>``
    keys hold that group's value. For each intersection that has at least
    one such key, a DataFrame is built with a ``Model`` column, one
    ``<group> ACC`` column per group (sorted), and one row per model. Missing
    values are ``"N/A"``. Intersections with no keys are skipped.

    Args:
        summary: Mapping model name -> {key: value}.
        intersectionals: Intersection names to tabulate. Defaults to a
            module-level ``INTERSECTIONALS`` if defined, otherwise none.
        models: Row order. Defaults to a module-level ``MODELS_ORDERED`` if
            defined, otherwise the keys of *summary* in their existing order.

    Returns:
        Dict intersection name -> DataFrame.
    """
    # NOTE(review): INTERSECTIONALS / MODELS_ORDERED are not defined in this
    # file; referencing the bare names raised NameError. Look them up defensively.
    if intersectionals is None:
        intersectionals = globals().get('INTERSECTIONALS', ())
    if models is None:
        models = globals().get('MODELS_ORDERED', list(summary))

    tables = {}
    for inter in intersectionals:
        prefix = f'inter::{inter}::'
        groups = sorted({
            key[len(prefix):]
            for model_values in summary.values()
            for key in model_values
            if key.startswith(prefix)
        })
        if not groups:
            continue
        cols = ['Model'] + [f'{g} ACC' for g in groups]
        rows = []
        for model in models:
            model_values = summary.get(model, {})
            row = {'Model': model}
            for g in groups:
                row[f'{g} ACC'] = model_values.get(f'{prefix}{g}', "N/A")
            rows.append(row)
        tables[inter] = pd.DataFrame(rows, columns=cols)
    return tables

def main(experiment_dir: str = "../experiments_fairness"):
    """Collect, display and save experiment statistics.

    Args:
        experiment_dir: Root directory containing one subdirectory per model.
            The default is relative to the current working directory, so it
            resolves as intended when the script is run from the scripts3
            directory. All CSV outputs are written into this directory.
    """
    if not os.path.isdir(experiment_dir):
        print(f"Error: {experiment_dir} directory not found!")
        print("Please run this script from the scripts3 directory")
        return

    print("Collecting experiment statistics...")
    print("=" * 80)

    results = collect_model_statistics(experiment_dir)

    if not results:
        print("No results found!")
        return

    df = create_results_table(results)

    # Display overall results
    print("\n" + "=" * 80)
    print("OVERALL EXPERIMENT STATISTICS SUMMARY")
    print("=" * 80)
    print(df.to_string(index=False))

    # Save overall results to CSV in experiment_dir
    overall_output_file = os.path.join(experiment_dir, "experiment_statistics_overall.csv")
    df.to_csv(overall_output_file, index=False)
    print(f"\nOverall results saved to: {overall_output_file}")

    # Save the (model-order sorted) LOS CSV in experiment_dir
    print("\n" + "=" * 80)
    print("SAVING TASK-SPECIFIC CSV FILES")
    print("=" * 80)
    save_task_specific_csvs(df, experiment_dir)

    # Per-(model, task) summary: mean±std for every metric that has a value
    metric_names = ['F1_macro', 'F1_weighted', 'Precision_macro',
                    'Precision_weighted', 'Recall_macro', 'Recall_weighted',
                    'Specificity_macro', 'Specificity_weighted', 'ACC', 'Kappa']
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    for result_data in results.values():
        print(f"{result_data['model']} - {result_data['task']}: {result_data['total_experiments']} experiments")
        metrics = result_data['metrics']
        for metric_name in metric_names:
            stats = metrics.get(metric_name)
            if stats is not None and stats['mean'] is not None:
                print(f"  {metric_name}: {stats['mean']:.4f}±{stats['std']:.4f}")

    # NOTE(review): the group tables look for 'group::<attr>::<g>' /
    # 'inter::<inter>::<g>' keys per model, but collect_model_statistics
    # returns entries keyed "<model>_<task>" holding only model/task/metrics/
    # total_experiments. These tables are therefore likely empty or all "N/A";
    # confirm the intended data source.
    attr_tables = build_attribute_group_tables(results)
    for attr, adf in attr_tables.items():
        out_csv = os.path.join(experiment_dir, f"los_attr_{attr}_groups_acc_ordered.csv")
        adf.to_csv(out_csv, index=False)
        print(f"Saved attribute groups CSV: {out_csv}")

    inter_tables = build_intersectional_group_tables(results)
    for inter, idf in inter_tables.items():
        out_csv = os.path.join(experiment_dir, f"los_intersectional_{inter}_groups_acc_ordered.csv")
        idf.to_csv(out_csv, index=False)
        print(f"Saved intersectional groups CSV: {out_csv}")

if __name__ == "__main__":
    # Run the collection only when executed as a script, not when imported.
    main()
