#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Calculate extended metrics (AUC, Accuracy, etc.) for single benchmark results and generate plots
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
from scipy.stats import pearsonr

# Global plot styling, applied once when this module is loaded.
# Font families are listed in order of preference (Arial first, then the others).
plt.rcParams['font.family'] = ['Arial', 'DejaVu Sans', 'Liberation Sans']
# Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
# Use seaborn's "whitegrid" theme for every figure created by this script.
sns.set_style("whitegrid")

def calculate_metrics_for_prediction(df, pred_column, threshold=0.5):
    """
    Score one prediction column of ``df`` against its ``true_value`` column.

    Args:
        df: DataFrame containing a ``true_value`` column and ``pred_column``.
        pred_column: Name of the column holding continuous prediction scores.
        threshold: Scores greater than or equal to this value count as class 1.

    Returns:
        Dict with the keys ``AUC``, ``Accuracy``, ``Precision``, ``Recall``,
        ``F1`` and ``MSE``. ``AUC`` is NaN whenever ``roc_auc_score`` raises
        ``ValueError``.
    """
    labels = df['true_value'].values
    scores = df[pred_column].values
    # Hard 0/1 decisions used by the classification metrics below.
    hard_preds = (scores >= threshold).astype(int)

    metrics = {}

    # AUC is computed from the raw scores, not the thresholded decisions.
    try:
        metrics['AUC'] = roc_auc_score(labels, scores)
    except ValueError:
        # Per the original author: happens when only one class is present.
        metrics['AUC'] = np.nan

    metrics['Accuracy'] = accuracy_score(labels, hard_preds)

    # zero_division=0 reports 0 instead of warning on an empty denominator.
    threshold_scorers = (
        ('Precision', precision_score),
        ('Recall', recall_score),
        ('F1', f1_score),
    )
    for metric_name, scorer in threshold_scorers:
        metrics[metric_name] = scorer(labels, hard_preds, zero_division=0)

    # Mean squared error between the labels and the raw scores.
    metrics['MSE'] = np.mean((labels - scores) ** 2)

    return metrics

def process_prediction_file(file_path):
    """
    Compute metrics for every prediction method stored in one CSV file.

    Every column whose name ends with ``_pred`` is treated as the output of
    one prediction method; the method name is the column name with that
    trailing suffix removed.

    Args:
        file_path: Path to a CSV file readable by ``pandas.read_csv``.

    Returns:
        Dict mapping method name to the metrics dict returned by
        ``calculate_metrics_for_prediction``. Empty when the file has no
        ``*_pred`` columns.
    """
    suffix = '_pred'
    df = pd.read_csv(file_path)

    # Slice off only the trailing suffix. str.replace would also delete
    # '_pred' occurring earlier in the name (e.g. 'model_predictor_pred').
    return {
        col[:-len(suffix)]: calculate_metrics_for_prediction(df, col)
        for col in df.columns
        if col.endswith(suffix)
    }

def aggregate_results_by_train_ratio(base_dir):
    """
    Aggregate per-file metrics into mean/std per train ratio, method and metric.

    Reads every ``predictions_ratio_*.csv`` file in
    ``<base_dir>/02_sample_predictions``. The train ratio is the first
    ``_``-separated token of the filename after the ``predictions_ratio_``
    prefix; all files sharing a ratio are pooled.

    NaN metric values (``calculate_metrics_for_prediction`` reports NaN for
    an AUC it could not compute) are left out of the mean/std and a warning
    is printed, so one degenerate file does not turn the whole aggregate
    into NaN. If every value is NaN the aggregate stays NaN.

    Args:
        base_dir: Directory containing the ``02_sample_predictions`` folder.

    Returns:
        Nested dict ``{train_ratio: {method: {metric: {'mean', 'std'}}}}``.

    Raises:
        FileNotFoundError: If the predictions directory does not exist.
        ValueError: If a matching filename has a non-numeric ratio token.
    """
    predictions_dir = os.path.join(base_dir, "02_sample_predictions")
    prefix, suffix = "predictions_ratio_", ".csv"

    # Sorted so files are processed in a deterministic order.
    pred_files = sorted(
        f for f in os.listdir(predictions_dir)
        if f.startswith(prefix) and f.endswith(suffix)
    )

    # {train_ratio: {method: {metric: [value, ...]}}}
    metrics_summary = {}
    for file in pred_files:
        # Only the ratio token is needed; any remaining filename tokens
        # are not used by the aggregation.
        ratio_token = file.replace(prefix, "").replace(suffix, "").split("_")[0]
        train_ratio = float(ratio_token)

        file_metrics = process_prediction_file(os.path.join(predictions_dir, file))

        ratio_bucket = metrics_summary.setdefault(train_ratio, {})
        for method, method_metrics in file_metrics.items():
            method_bucket = ratio_bucket.setdefault(method, {})
            for metric_name, value in method_metrics.items():
                method_bucket.setdefault(metric_name, []).append(value)

    # Collapse each list of values into its mean and standard deviation.
    final_summary = {}
    for train_ratio, methods in metrics_summary.items():
        final_summary[train_ratio] = {}
        for method, metrics in methods.items():
            final_summary[train_ratio][method] = {}
            for metric_name, values in metrics.items():
                arr = np.asarray(values, dtype=float)
                valid = arr[~np.isnan(arr)]
                n_dropped = arr.size - valid.size
                if n_dropped:
                    print(
                        f"Warning: ignoring {n_dropped} NaN value(s) for "
                        f"{method}/{metric_name} at train ratio {train_ratio}"
                    )
                if valid.size:
                    stats = {'mean': np.mean(valid), 'std': np.std(valid)}
                else:
                    # Nothing usable: keep NaN so the gap stays visible.
                    stats = {'mean': np.nan, 'std': np.nan}
                final_summary[train_ratio][method][metric_name] = stats

    return final_summary

def save_metrics_summary(summary, output_dir):
    """
    Write one ``<metric>_summary.csv`` per metric into ``<output_dir>/04_metrics``.

    Each CSV has a ``Train_Ratio`` column followed by one column per method
    (sorted by name) and one row per train ratio (ascending). Cells hold
    ``"mean ± std"`` with six decimals, or ``"N/A"`` when the method/metric
    combination is missing for that ratio.

    Args:
        summary: Nested dict ``{train_ratio: {method: {metric: {'mean', 'std'}}}}``.
        output_dir: Directory in which the ``04_metrics`` folder is created.
    """
    metrics_dir = os.path.join(output_dir, "04_metrics")
    os.makedirs(metrics_dir, exist_ok=True)

    # Union of method and metric names across all train ratios.
    method_names = sorted(
        {name for per_ratio in summary.values() for name in per_ratio}
    )
    metric_names = sorted({
        name
        for per_ratio in summary.values()
        for per_method in per_ratio.values()
        for name in per_method
    })
    ratios = sorted(summary)

    def _format_cell(ratio, method, metric):
        """Return the 'mean ± std' text for one cell, or 'N/A' if absent."""
        per_ratio = summary[ratio]
        if method not in per_ratio or metric not in per_ratio[method]:
            return "N/A"
        stats = per_ratio[method][metric]
        return f"{stats['mean']:.6f} ± {stats['std']:.6f}"

    for metric in metric_names:
        table = pd.DataFrame([
            {
                'Train_Ratio': ratio,
                **{m: _format_cell(ratio, m, metric) for m in method_names},
            }
            for ratio in ratios
        ])
        output_file = os.path.join(metrics_dir, f"{metric.lower()}_summary.csv")
        table.to_csv(output_file, index=False)
        print(f"Saved {metric} summary to: {output_file}")

def create_metric_plots(summary, output_dir):
    """
    Save one ``<metric>_vs_train_ratio.png`` per metric into ``<output_dir>/04_metrics``.

    Each figure draws one error-bar line per method (mean with std as the
    error) against the train ratio. Methods without data for a metric are
    left out of that figure.

    Args:
        summary: Nested dict ``{train_ratio: {method: {metric: {'mean', 'std'}}}}``.
        output_dir: Directory in which the ``04_metrics`` folder is created.
    """
    metrics_dir = os.path.join(output_dir, "04_metrics")
    os.makedirs(metrics_dir, exist_ok=True)

    # Union of method and metric names across all train ratios.
    methods = sorted({name for per_ratio in summary.values() for name in per_ratio})
    metric_names = sorted({
        name
        for per_ratio in summary.values()
        for per_method in per_ratio.values()
        for name in per_method
    })
    ratios = sorted(summary)

    # (color, linestyle) per known method; other methods use the fallback.
    fallback_style = ("black", "solid")
    style_by_method = {
        "global_mean": ("blue", "solid"),
        "model_mean": ("orange", "dashed"),
        "question_mean": ("green", "dashdot"),
        "irt_1pl": ("red", "solid"),
        "irt_2pl": ("purple", "dashed")
    }

    for metric in metric_names:
        plt.figure(figsize=(10, 6))

        for method in methods:
            # (ratio, stats) pairs for the ratios that have this method/metric.
            points = [
                (ratio, summary[ratio][method][metric])
                for ratio in ratios
                if method in summary[ratio] and metric in summary[ratio][method]
            ]
            if not points:
                continue

            xs = [ratio for ratio, _ in points]
            ys = [stats['mean'] for _, stats in points]
            errs = [stats['std'] for _, stats in points]
            color, linestyle = style_by_method.get(method, fallback_style)

            plt.errorbar(xs, ys, yerr=errs,
                         label=method.replace("_", " ").title(),
                         color=color, linestyle=linestyle,
                         linewidth=2, marker="o", markersize=6, capsize=3)

        # Axis labels, title, legend and ticks.
        heading = metric.upper()
        plt.xlabel("Training Data Ratio", fontsize=12)
        plt.ylabel(heading, fontsize=12)
        plt.title(f"{heading} vs. Training Data Ratio", fontsize=14)
        plt.legend(loc="best", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.xticks(ratios, [f"{r:.1f}" for r in ratios], fontsize=10)
        plt.yticks(fontsize=10)

        # Write the figure to disk and release it.
        plt.tight_layout()
        output_path = os.path.join(metrics_dir, f"{metric.lower()}_vs_train_ratio.png")
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"Saved {metric} plot to: {output_path}")

def create_combined_metrics_plot(summary, output_dir):
    """
    Save ``combined_metrics.png`` with AUC, Accuracy and MSE side by side.

    One subplot per metric; each draws one error-bar line per method (mean
    with std as the error) against the train ratio. A single legend is
    attached to the right of the last subplot.

    Args:
        summary: Nested dict ``{train_ratio: {method: {metric: {'mean', 'std'}}}}``.
        output_dir: Directory in which the ``04_metrics`` folder is created.
    """
    metrics_dir = os.path.join(output_dir, "04_metrics")
    os.makedirs(metrics_dir, exist_ok=True)

    panel_metrics = ['AUC', 'Accuracy', 'MSE']
    methods = sorted({name for per_ratio in summary.values() for name in per_ratio})
    ratios = sorted(summary)
    tick_labels = [f"{r:.1f}" for r in ratios]

    # One panel per metric, laid out in a single row.
    fig, axes = plt.subplots(1, len(panel_metrics), figsize=(15, 5))

    # One tab10 color per method, sampled evenly across the colormap.
    palette = plt.cm.tab10(np.linspace(0, 1, len(methods)))
    color_of = dict(zip(methods, palette))

    for ax, metric in zip(axes, panel_metrics):
        for method in methods:
            # (ratio, stats) pairs for the ratios that have this method/metric.
            points = [
                (ratio, summary[ratio][method][metric])
                for ratio in ratios
                if method in summary[ratio] and metric in summary[ratio][method]
            ]
            if not points:
                continue

            ax.errorbar([ratio for ratio, _ in points],
                        [stats['mean'] for _, stats in points],
                        yerr=[stats['std'] for _, stats in points],
                        label=method.replace("_", " ").title(),
                        color=color_of[method],
                        linewidth=2, marker="o", markersize=6, capsize=3)

        # Per-panel labels, grid and ticks.
        heading = metric.upper()
        ax.set_xlabel("Training Data Ratio", fontsize=10)
        ax.set_ylabel(heading, fontsize=10)
        ax.set_title(f"{heading}", fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(ratios)
        ax.set_xticklabels(tick_labels, fontsize=8)
        ax.tick_params(axis='y', labelsize=8)

    # Legend sits outside the right edge of the last panel.
    axes[-1].legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize=8)

    # Write the figure to disk and release it.
    plt.tight_layout()
    output_path = os.path.join(metrics_dir, "combined_metrics.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"Saved combined metrics plot to: {output_path}")

def create_performance_comparison_table(summary, output_dir):
    """
    Write ``performance_comparison.csv`` for the highest train ratio.

    The table has one row per method present at the largest train ratio in
    ``summary`` (sorted by method name) with the columns ``Method``,
    ``AUC``, ``Accuracy`` and ``MSE``. Cells hold ``"mean ± std"`` with four
    decimals, or ``"N/A"`` when a metric is missing for that method.

    Nothing is written when ``summary`` is empty or when no method has
    results at the highest train ratio.

    Args:
        summary: Nested dict ``{train_ratio: {method: {metric: {'mean', 'std'}}}}``.
        output_dir: Directory in which the ``04_metrics`` folder is created.
    """
    metrics_dir = os.path.join(output_dir, "04_metrics")
    os.makedirs(metrics_dir, exist_ok=True)

    # max() raises ValueError on an empty mapping, so bail out early when
    # no results were aggregated at all.
    if not summary:
        print("No results available; skipping performance comparison table.")
        return

    max_train_ratio = max(summary)
    methods_at_max = summary[max_train_ratio]

    comparison_data = []
    for method in sorted(methods_at_max):
        method_stats = methods_at_max[method]
        row = {'Method': method.replace("_", " ").title()}
        for metric in ('AUC', 'Accuracy', 'MSE'):
            if metric in method_stats:
                mean_val = method_stats[metric]['mean']
                std_val = method_stats[metric]['std']
                row[metric] = f"{mean_val:.4f} ± {std_val:.4f}"
            else:
                row[metric] = "N/A"
        comparison_data.append(row)

    if comparison_data:
        df_comparison = pd.DataFrame(comparison_data)
        output_file = os.path.join(metrics_dir, "performance_comparison.csv")
        df_comparison.to_csv(output_file, index=False)
        print(f"Saved performance comparison to: {output_file}")

def main(base_dir=None, output_dir=None):
    """
    Run the full pipeline: aggregate metrics, save CSVs, draw plots.

    Args:
        base_dir: Results directory containing ``02_sample_predictions``.
            Defaults to the original author's local results path, so
            calling ``main()`` with no arguments behaves as before.
        output_dir: Directory that receives the ``04_metrics`` folder.
            Defaults to ``base_dir``.
    """
    # Machine-specific default kept for backward compatibility; pass
    # base_dir explicitly to run against a different location.
    if base_dir is None:
        base_dir = "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/IRT/mix_benchmark/result_single_benchmark"
    if output_dir is None:
        output_dir = base_dir

    print("Processing single benchmark results...")
    print(f"Base directory: {base_dir}")

    # Aggregate results
    print("Aggregating results...")
    summary = aggregate_results_by_train_ratio(base_dir)

    # Save metrics summary
    print("Saving metrics summary...")
    save_metrics_summary(summary, output_dir)

    # Create plots
    print("Creating metric plots...")
    create_metric_plots(summary, output_dir)

    # Create combined plot
    print("Creating combined metrics plot...")
    create_combined_metrics_plot(summary, output_dir)

    # Create performance comparison table
    print("Creating performance comparison table...")
    create_performance_comparison_table(summary, output_dir)

    print(f"\nAll extended metrics and plots saved to: {output_dir}/04_metrics")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()