import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Font setup for Chinese text in figures: the sans-serif family list puts
# SimHei first, followed by 'Arial Unicode MS' and 'DejaVu Sans'. Turning
# 'axes.unicode_minus' off makes negative numbers use an ASCII hyphen instead
# of the Unicode minus sign (presumably because that glyph is missing from
# the CJK fonts above -- confirm on the target machine).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

def load_metrics_data(file_path):
    """
    Read the prediction metrics CSV into memory.

    Parameters:
    file_path: Path of the CSV file to read (passed straight to pandas.read_csv)

    Returns:
    DataFrame with the file contents, parsed with pandas' default settings
    """
    return pd.read_csv(file_path)

def generate_summary_statistics(df):
    """
    Summarize AUC and accuracy per prediction method.

    For each distinct 'prediction_method' the mean, standard deviation and
    non-null count of 'auc' and 'accuracy' are computed, plus a
    normal-approximation 95% confidence interval for each mean:
    mean +/- 1.96 * std / sqrt(count).

    Each metric's interval uses that metric's own count, so the accuracy
    interval stays correct even when 'auc' and 'accuracy' have a different
    number of missing values within a method.

    Parameters:
    df: DataFrame containing 'prediction_method', 'auc' and 'accuracy' columns

    Returns:
    DataFrame indexed by prediction method, sorted by 'avg_auc' (descending),
    with columns avg_auc, auc_std, count, avg_accuracy, accuracy_std,
    accuracy_count, auc_ci_lower, auc_ci_upper, accuracy_ci_lower,
    accuracy_ci_upper. 'count' is the number of non-null AUC values.
    """
    # Two-sided 95% critical value of the standard normal distribution
    z_95 = 1.96

    summary = df.groupby('prediction_method')[['auc', 'accuracy']].agg(['mean', 'std', 'count'])
    # Flatten the (metric, statistic) column MultiIndex into 'metric_statistic'
    summary.columns = ['_'.join(col).strip() for col in summary.columns.values]

    # Only these three names differ from the flattened ones; 'auc_std',
    # 'accuracy_std' and 'accuracy_count' already have their final names.
    summary = summary.rename(columns={
        'auc_mean': 'avg_auc',
        'accuracy_mean': 'avg_accuracy',
        'auc_count': 'count',
    })

    # Pair each metric with the column holding ITS sample size.
    for metric, count_col in (('auc', 'count'), ('accuracy', 'accuracy_count')):
        margin = z_95 * summary[f'{metric}_std'] / np.sqrt(summary[count_col])
        summary[f'{metric}_ci_lower'] = summary[f'avg_{metric}'] - margin
        summary[f'{metric}_ci_upper'] = summary[f'avg_{metric}'] + margin

    return summary.sort_values('avg_auc', ascending=False)

def generate_train_ratio_analysis(df):
    """
    Average the metrics for every (training ratio, prediction method) pair.

    Parameters:
    df: DataFrame containing 'train_ratio', 'prediction_method', 'auc' and
        'accuracy' columns

    Returns:
    DataFrame with one row per (train_ratio, prediction_method) combination
    and the mean 'auc' and 'accuracy' of that combination; the grouping keys
    are returned as ordinary columns
    """
    group_keys = ['train_ratio', 'prediction_method']
    metric_columns = ['auc', 'accuracy']

    per_group_mean = df.groupby(group_keys)[metric_columns].mean()
    # Move the grouping keys out of the index and back into columns
    return per_group_mean.reset_index()

# (DataFrame column, display label) for the two metrics every chart shows
_METRIC_LABELS = (('auc', 'AUC'), ('accuracy', 'Accuracy'))


def _save_show_close(fig, output_dir, filename):
    """Tighten the layout of *fig*, save it as a 300 dpi PNG, show it, then close it."""
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, filename), dpi=300, bbox_inches='tight')
    plt.show()
    # pyplot keeps a reference to every figure until it is closed; close it
    # explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)


def _plot_method_comparison(df, output_dir):
    """Bar charts of the mean AUC / accuracy per method, ordered by mean AUC."""
    avg_metrics = (
        df.groupby('prediction_method')[['auc', 'accuracy']]
        .mean()
        .sort_values('auc', ascending=False)
    )
    positions = range(len(avg_metrics))

    fig, axes = plt.subplots(1, 2, figsize=(18, 6))
    for ax, (column, label), color in zip(axes, _METRIC_LABELS, ('skyblue', 'lightcoral')):
        ax.bar(positions, avg_metrics[column], color=color)
        ax.set_xticks(positions)
        ax.set_xticklabels(avg_metrics.index, rotation=45, ha='right')
        ax.set_ylabel(label)
        ax.set_title(f'Average {label} for Different Prediction Methods')
        ax.grid(axis='y', alpha=0.3)

    _save_show_close(fig, output_dir, 'prediction_methods_comparison.png')


def _plot_ratio_trends(ratio_analysis, output_dir):
    """Line charts of AUC / accuracy against training ratio, one line per method."""
    methods = ratio_analysis['prediction_method'].unique()
    # One tab10 colour per method, shared by both panels
    colors = plt.cm.tab10(np.linspace(0, 1, len(methods)))

    fig, axes = plt.subplots(2, 1, figsize=(12, 10))
    for ax, (column, label), marker in zip(axes, _METRIC_LABELS, ('o', 's')):
        for method, color in zip(methods, colors):
            method_data = ratio_analysis[ratio_analysis['prediction_method'] == method]
            ax.plot(method_data['train_ratio'], method_data[column],
                    marker=marker, linewidth=2, markersize=6, label=method, color=color)
        ax.set_xlabel('Training Ratio')
        ax.set_ylabel(label)
        ax.set_title(f'{label} Trend for Different Prediction Methods by Training Ratio')
        # Anchor the legend outside the axes, to the right of the plot area
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, alpha=0.3)

    _save_show_close(fig, output_dir, 'metrics_trend_by_train_ratio.png')


def _plot_ratio_heatmaps(ratio_analysis, output_dir):
    """Annotated method-by-training-ratio heatmaps of mean AUC and accuracy."""
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    for ax, (column, label), cmap in zip(axes, _METRIC_LABELS, ('Blues', 'Reds')):
        table = ratio_analysis.pivot(index='prediction_method', columns='train_ratio', values=column)
        sns.heatmap(table, annot=True, fmt='.3f', cmap=cmap, ax=ax)
        ax.set_title(f'{label} Heatmap')
        ax.set_xlabel('Training Ratio')
        ax.set_ylabel('Prediction Method')

    _save_show_close(fig, output_dir, 'metrics_heatmap.png')


def visualize_metrics(df, output_dir):
    """
    Generate metrics visualization charts

    Three figures are written to output_dir (created if missing) and shown:
    - prediction_methods_comparison.png: mean AUC / accuracy per method
    - metrics_trend_by_train_ratio.png: mean AUC / accuracy vs. training ratio
    - metrics_heatmap.png: method x training-ratio heatmaps of both metrics
    Each figure is closed after it has been saved and shown.

    Parameters:
    df: DataFrame containing 'prediction_method', 'train_ratio', 'auc' and
        'accuracy' columns
    output_dir: Output directory path
    """
    os.makedirs(output_dir, exist_ok=True)

    # 1. AUC and Accuracy comparison for all prediction methods
    _plot_method_comparison(df, output_dir)

    # 2./3. Both remaining charts use the per-(ratio, method) means
    ratio_analysis = generate_train_ratio_analysis(df)
    _plot_ratio_trends(ratio_analysis, output_dir)
    _plot_ratio_heatmaps(ratio_analysis, output_dir)

def save_detailed_results(df, summary, ratio_analysis, output_dir):
    """
    Write the raw metrics and both analysis tables to CSV files.

    Files created inside output_dir (the directory is created if missing):
    - detailed_prediction_metrics.csv: df, without its index
    - prediction_metrics_summary.csv: summary, including its index
    - metrics_by_train_ratio.csv: ratio_analysis, without its index

    Parameters:
    df: Original metrics data
    summary: Summary statistics
    ratio_analysis: Results of analysis by training ratio
    output_dir: Output directory path
    """
    os.makedirs(output_dir, exist_ok=True)

    # (table, file name, write the index column?)
    exports = (
        (df, 'detailed_prediction_metrics.csv', False),
        (summary, 'prediction_metrics_summary.csv', True),
        (ratio_analysis, 'metrics_by_train_ratio.csv', False),
    )
    for table, filename, keep_index in exports:
        table.to_csv(os.path.join(output_dir, filename), index=keep_index)

    print(f"Detailed results saved to directory: {output_dir}")

def main(input_file="data/prediction_metrics.csv",
         output_dir="results/prediction_metrics_analysis"):
    """
    Run the whole analysis: load, clean, summarize, plot, save and report.

    Rows with a missing 'auc' or 'accuracy' value are dropped before any
    statistic is computed. Progress messages, the per-method summary and the
    AUC-by-training-ratio table are printed to standard output.

    Parameters:
    input_file: Path of the metrics CSV to analyze
    output_dir: Directory that receives the charts and CSV results
    """
    print("Loading prediction metrics data...")
    df = load_metrics_data(input_file)

    # Remove null values
    df = df.dropna(subset=['auc', 'accuracy'])

    print("Generating summary statistics...")
    summary = generate_summary_statistics(df)

    print("Generating training ratio analysis...")
    ratio_analysis = generate_train_ratio_analysis(df)

    print("Generating visualization charts...")
    visualize_metrics(df, output_dir)

    print("Saving detailed results...")
    save_detailed_results(df, summary, ratio_analysis, output_dir)

    # Print summary statistics
    print("\nPrediction method performance ranking (by AUC):")
    print(summary[['avg_auc', 'auc_std', 'avg_accuracy', 'accuracy_std']])

    # One row per method, one column per training ratio
    print("\nPerformance changes by training ratio:")
    print(ratio_analysis.pivot(index='prediction_method', columns='train_ratio', values='auc'))

# Run the analysis only when this file is executed as a script, not on import
if __name__ == "__main__":
    main()