#!/usr/bin/env python3
"""
Analyze OOD (out-of-distribution) experiment results.

Reads a CSV file of experiment results collected on several OOD datasets
across different random seeds, groups the rows by ``ood_dataset_name``, and
computes the mean and standard deviation of every numeric metric.
"""

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd


# Columns never treated as metrics (identifier / configuration columns, judging by name).
_DEFAULT_EXCLUDE_COLUMNS = (
    'model_type', 'dataset_name', 'split', 'architecture', 'input_dims',
    'hidden_dims', 'clf_type', 'loss', 'optimizer_type', 'ood_dataset_name',
    'mix', 'mix_inter', 'mix_noise', 'use_cosine_annealing', 'use_sample_wise_kl_weight',
)
# Metrics echoed to stdout at the end of the analysis unless the caller overrides them.
_DEFAULT_KEY_METRICS = ('id_accuracy', 'ood_max_prob_auroc', 'ood_max_alpha_auroc', 'ood_cpu_auroc')
# Explicit column layouts so the output frames keep their schema even when empty.
_SUMMARY_COLUMNS = ['ood_dataset', 'metric', 'mean', 'std', 'count']
_DETAILED_COLUMNS = ['ood_dataset', 'metric', 'value', 'count']


def _find_numeric_columns(df, exclude_columns):
    """Return {column: numeric Series} for non-excluded columns that fully parse as numbers."""
    numeric = {}
    for col in df.columns:
        if col in exclude_columns:
            continue
        try:
            numeric[col] = pd.to_numeric(df[col], errors='raise')
        except (ValueError, TypeError):
            # Column contains non-numeric values, so it is not treated as a metric.
            continue
    return numeric


def _compute_group_stats(grouped, numeric_columns):
    """Return (results, summary_rows) holding mean/std/count per OOD dataset and metric."""
    results = {}
    summary_rows = []
    for ood_dataset, group in grouped:
        print(f"\nProcessing OOD dataset: {ood_dataset}")
        stats = {}
        for col in numeric_columns:
            values = group[col].dropna()
            if values.empty:
                # Every value missing for this dataset: omit the metric instead of reporting NaN.
                continue
            mean_val = values.mean()
            std_val = values.std()  # sample std (ddof=1): NaN when only one value is present
            stats[f'{col}_mean'] = mean_val
            stats[f'{col}_std'] = std_val
            summary_rows.append({
                'ood_dataset': ood_dataset,
                'metric': col,
                'mean': mean_val,
                'std': std_val,
                'count': len(values),
            })
        results[ood_dataset] = {'count': len(group), 'stats': stats}
    return results, summary_rows


def _pivot_summary(summary_df, value_column):
    """Pivot summary rows into an (ood_dataset x metric) table of ``value_column``."""
    if summary_df.empty:
        # Build the empty table explicitly rather than relying on pivot() of an empty frame.
        return pd.DataFrame(
            index=pd.Index([], name='ood_dataset'),
            columns=pd.Index([], name='metric'),
            dtype=float,
        )
    return summary_df.pivot(index='ood_dataset', columns='metric', values=value_column)


def _print_key_metrics(summary_df, key_metrics):
    """Print 'mean ± std' per OOD dataset for each key metric present in ``summary_df``."""
    print("\n=== Key Metrics Summary ===")
    for metric in key_metrics:
        metric_rows = summary_df[summary_df['metric'] == metric]
        if metric_rows.empty:
            continue
        print(f"\n{metric}:")
        for _, row in metric_rows.iterrows():
            print(f"  {row['ood_dataset']}: {row['mean']:.4f} ± {row['std']:.4f}")


def analyze_ood_results(csv_file_path, output_dir=None, *, extra_exclude_columns=None, key_metrics=None):
    """
    Analyze an OOD experiment results CSV file.

    Rows are grouped by ``ood_dataset_name``. For every numeric metric column,
    the function computes the mean, the sample standard deviation (ddof=1, so
    NaN when a group has one non-missing value) and the non-missing count. It
    writes four CSV files to ``output_dir``: ``<stem>_detailed_stats.csv``,
    ``<stem>_summary_stats.csv``, ``<stem>_mean_by_ood.csv`` and
    ``<stem>_std_by_ood.csv``. Rows whose ``ood_dataset_name`` is missing are
    ignored, and a warning is printed.

    Args:
        csv_file_path (str | Path): CSV file path.
        output_dir (str | Path | None): Output directory. If None, the directory
            containing the CSV file is used. It is created if missing.
        extra_exclude_columns (Iterable[str] | None): Additional column names to
            skip as non-metrics, on top of the built-in exclusion list (e.g. a
            random-seed column that should not be averaged).
        key_metrics (Iterable[str] | None): Metrics echoed to stdout at the end.
            Defaults to id_accuracy and the ood_*_auroc metrics in
            ``_DEFAULT_KEY_METRICS``.

    Returns:
        dict: ``summary_df`` (one row per dataset/metric with mean, std, count),
        ``detailed_df`` (one row per dataset and ``<metric>_mean``/``_std``
        value), ``mean_pivot`` and ``std_pivot`` (dataset x metric tables), and
        ``results`` (``{ood_dataset: {'count': n_rows, 'stats': {...}}}``).
        All frames keep their columns even when no numeric metrics are found.

    Raises:
        ValueError: If the CSV file has no ``ood_dataset_name`` column.
    """
    print(f"Reading file: {csv_file_path}")
    df = pd.read_csv(csv_file_path)

    if 'ood_dataset_name' not in df.columns:
        raise ValueError("Column 'ood_dataset_name' not found in CSV file")

    exclude_columns = set(_DEFAULT_EXCLUDE_COLUMNS)
    if extra_exclude_columns is not None:
        exclude_columns.update(extra_exclude_columns)

    # Parse each candidate column once and keep the numeric version.
    numeric = _find_numeric_columns(df, exclude_columns)
    numeric_columns = list(numeric)
    for col, converted in numeric.items():
        df[col] = converted

    print(f"Found {len(numeric_columns)} numeric metric columns")
    print(f"Numeric metrics include: {', '.join(numeric_columns[:10])}{'...' if len(numeric_columns) > 10 else ''}")

    missing_names = int(df['ood_dataset_name'].isna().sum())
    if missing_names:
        # groupby drops NaN keys by default; surface that instead of losing rows silently.
        print(f"Warning: ignoring {missing_names} rows with no 'ood_dataset_name'")

    grouped = df.groupby('ood_dataset_name')

    print(f"\nFound {len(grouped)} different OOD datasets:")
    for name, group in grouped:
        print(f"  - {name}: {len(group)} experimental results")

    results, summary_rows = _compute_group_stats(grouped, numeric_columns)
    summary_df = pd.DataFrame(summary_rows, columns=_SUMMARY_COLUMNS)
    if summary_df.empty:
        print("Warning: no numeric metric values found; output tables will be empty")

    output_dir = Path(csv_file_path).parent if output_dir is None else Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Long format: one row per (dataset, '<metric>_mean' / '<metric>_std'); 'count' is the group's row count.
    detailed_rows = []
    for ood_dataset, data in results.items():
        for metric, value in data['stats'].items():
            detailed_rows.append({
                'ood_dataset': ood_dataset,
                'metric': metric,
                'value': value,
                'count': data['count'],
            })
    detailed_df = pd.DataFrame(detailed_rows, columns=_DETAILED_COLUMNS)

    base_name = Path(csv_file_path).stem

    detailed_output_path = output_dir / f"{base_name}_detailed_stats.csv"
    detailed_df.to_csv(detailed_output_path, index=False)
    print(f"\nDetailed statistical results saved to: {detailed_output_path}")

    summary_output_path = output_dir / f"{base_name}_summary_stats.csv"
    summary_df.to_csv(summary_output_path, index=False)
    print(f"Summary statistical results saved to: {summary_output_path}")

    # Wide format: one table of means and one of stds, datasets as rows and metrics as columns.
    mean_pivot = _pivot_summary(summary_df, 'mean')
    std_pivot = _pivot_summary(summary_df, 'std')

    mean_pivot_path = output_dir / f"{base_name}_mean_by_ood.csv"
    std_pivot_path = output_dir / f"{base_name}_std_by_ood.csv"
    mean_pivot.to_csv(mean_pivot_path)
    std_pivot.to_csv(std_pivot_path)

    print(f"Mean table grouped by OOD dataset saved to: {mean_pivot_path}")
    print(f"Standard deviation table grouped by OOD dataset saved to: {std_pivot_path}")

    _print_key_metrics(summary_df, _DEFAULT_KEY_METRICS if key_metrics is None else key_metrics)

    return {
        'summary_df': summary_df,
        'detailed_df': detailed_df,
        'mean_pivot': mean_pivot,
        'std_pivot': std_pivot,
        'results': results,
    }


def main():
    """Command-line entry point: parse arguments and run the OOD results analysis.

    Errors are written to stderr. The process exits with status 1 when the
    input path is not an existing file or the analysis raises. This lets shell
    scripts and CI detect the failure instead of seeing a successful exit.
    """
    parser = argparse.ArgumentParser(description='Analyze OOD experiment results CSV file')
    parser.add_argument('csv_file', help='CSV file path')
    parser.add_argument('--output-dir', '-o', help='Output directory (default: CSV file directory)')

    args = parser.parse_args()

    # isfile (not exists) so a directory argument is reported here, not deep inside read_csv.
    if not os.path.isfile(args.csv_file):
        print(f"Error: File {args.csv_file} does not exist or is not a regular file", file=sys.stderr)
        sys.exit(1)

    try:
        analyze_ood_results(args.csv_file, args.output_dir)
    except Exception as e:  # top-level CLI boundary: report the failure and exit non-zero
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print("\nAnalysis completed!")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when this module is imported.
    main()