#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Create per-benchmark comparison plots for AUC and Accuracy metrics
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set style for plots.
# seaborn's set_style() rewrites the style rcParams, including "font.family",
# so it has to run first; otherwise the explicit font override below would be
# silently replaced by seaborn's generic "sans-serif" family.
sns.set_style("whitegrid")
plt.rcParams['font.family'] = ['Arial', 'DejaVu Sans', 'Liberation Sans']
plt.rcParams['axes.unicode_minus'] = False

def extract_benchmark_data(df, benchmark_name):
    """
    Return the slice of ``df`` that belongs to ``benchmark_name``.

    Placeholder: the input is assumed to already contain a single benchmark,
    so ``benchmark_name`` is ignored and ``df`` is handed back unchanged
    (``None`` stays ``None``).
    """
    # TODO: filter on benchmark-specific columns once the merged layout exists.
    return None if df is None else df

def create_per_benchmark_auc_plot(ceval_df, csqa_df, mmlu_df, output_dir):
    """
    Create a 1x3 figure of AUC versus training-data ratio, one panel per benchmark.

    Each DataFrame is read for a ``Train_Ratio`` column and an ``irt_1pl``
    column. Cells of ``irt_1pl`` may be ``"mean ± std"`` strings or plain
    numbers; only the mean is plotted. A ``None`` DataFrame, or one without
    ``irt_1pl``, produces a labelled but empty panel. The y-axis is fixed to
    [0.7, 1.0], so values outside that range are clipped from view.

    Args:
        ceval_df: CEVAL results DataFrame, or None.
        csqa_df: CSQA results DataFrame, or None.
        mmlu_df: MMLU results DataFrame, or None.
        output_dir: Directory receiving ``per_benchmark_auc_comparison.png``.
    """
    def _mean(cell):
        # "0.85 ± 0.01" -> 0.85. pandas yields floats instead of strings when
        # no cell in the column contains "±", so accept numbers (and NaN) too.
        if isinstance(cell, str):
            return float(cell.split('±', 1)[0])
        return float(cell)

    # (name, data, (color, linestyle, marker, linewidth)) for each panel.
    panels = [
        ("CEVAL", ceval_df, ("red", "solid", "o", 3)),
        ("CSQA", csqa_df, ("blue", "dashed", "s", 3)),
        ("MMLU", mmlu_df, ("green", "dashdot", "^", 3)),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(24, 8))
    fig.suptitle("AUC Comparison Across Individual Benchmarks", fontsize=24, y=1.02)

    for ax, (name, df, (color, linestyle, marker, linewidth)) in zip(axes, panels):
        if df is not None and 'irt_1pl' in df.columns:
            ax.plot(df['Train_Ratio'].values,
                    [_mean(x) for x in df['irt_1pl'].values],
                    label=f"{name} (IRT-1PL)",
                    color=color,
                    linestyle=linestyle,
                    linewidth=linewidth,
                    marker=marker,
                    markersize=10)
            # Only add a legend when a labelled line exists; otherwise
            # matplotlib warns "No artists with labels found".
            ax.legend(loc="lower right", fontsize=14)

        ax.set_xlabel("Training Data Ratio", fontsize=16)
        ax.set_ylabel("AUC", fontsize=16)
        ax.set_title(f"{name} Benchmark", fontsize=18)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0.7, 1.0)

        # One x tick per evaluated training ratio.
        if df is not None:
            ratios = df['Train_Ratio'].values
            ax.set_xticks(ratios)
            ax.set_xticklabels([f"{r:.1f}" for r in ratios], fontsize=12)
        ax.tick_params(axis='y', labelsize=12)

    # Save plot
    plt.tight_layout()
    output_path = os.path.join(output_dir, "per_benchmark_auc_comparison.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved per-benchmark AUC comparison plot to: {output_path}")

def create_per_benchmark_accuracy_plot(ceval_df, csqa_df, mmlu_df, output_dir):
    """
    Create a 1x3 figure of Accuracy versus training-data ratio, one panel per benchmark.

    Each DataFrame is read for a ``Train_Ratio`` column and an ``irt_1pl``
    column. Cells of ``irt_1pl`` may be ``"mean ± std"`` strings or plain
    numbers; only the mean is plotted. A ``None`` DataFrame, or one without
    ``irt_1pl``, produces a labelled but empty panel. The y-axis is fixed to
    [0.5, 1.0].

    NOTE(review): whatever ``irt_1pl`` holds is labelled "Accuracy" here. If the
    caller passes AUC summary frames (as this script's ``main`` does), the panels
    show AUC values under an Accuracy label. Pass accuracy frames to get real
    accuracy.

    Args:
        ceval_df: CEVAL results DataFrame, or None.
        csqa_df: CSQA results DataFrame, or None.
        mmlu_df: MMLU results DataFrame, or None.
        output_dir: Directory receiving ``per_benchmark_accuracy_comparison.png``.
    """
    def _mean(cell):
        # "0.85 ± 0.01" -> 0.85. pandas yields floats instead of strings when
        # no cell in the column contains "±", so accept numbers (and NaN) too.
        if isinstance(cell, str):
            return float(cell.split('±', 1)[0])
        return float(cell)

    # (name, data, (color, linestyle, marker, linewidth)) for each panel.
    panels = [
        ("CEVAL", ceval_df, ("red", "solid", "o", 3)),
        ("CSQA", csqa_df, ("blue", "dashed", "s", 3)),
        ("MMLU", mmlu_df, ("green", "dashdot", "^", 3)),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(24, 8))
    fig.suptitle("Accuracy Comparison Across Individual Benchmarks", fontsize=24, y=1.02)

    for ax, (name, df, (color, linestyle, marker, linewidth)) in zip(axes, panels):
        if df is not None and 'irt_1pl' in df.columns:
            ax.plot(df['Train_Ratio'].values,
                    [_mean(x) for x in df['irt_1pl'].values],
                    label=f"{name} (IRT-1PL)",
                    color=color,
                    linestyle=linestyle,
                    linewidth=linewidth,
                    marker=marker,
                    markersize=10)
            # Only add a legend when a labelled line exists; otherwise
            # matplotlib warns "No artists with labels found".
            ax.legend(loc="lower right", fontsize=14)

        ax.set_xlabel("Training Data Ratio", fontsize=16)
        ax.set_ylabel("Accuracy", fontsize=16)
        ax.set_title(f"{name} Benchmark", fontsize=18)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0.5, 1.0)

        # One x tick per evaluated training ratio.
        if df is not None:
            ratios = df['Train_Ratio'].values
            ax.set_xticks(ratios)
            ax.set_xticklabels([f"{r:.1f}" for r in ratios], fontsize=12)
        ax.tick_params(axis='y', labelsize=12)

    # Save plot
    plt.tight_layout()
    output_path = os.path.join(output_dir, "per_benchmark_accuracy_comparison.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved per-benchmark Accuracy comparison plot to: {output_path}")

def create_combined_per_benchmark_plot(ceval_df, csqa_df, mmlu_df, output_dir):
    """
    Create a 2x3 figure: AUC (top row) and Accuracy (bottom row) per benchmark.

    Columns are CEVAL, CSQA and MMLU. Every panel plots the ``irt_1pl`` column
    against ``Train_Ratio``. Cells may be ``"mean ± std"`` strings or plain
    numbers, and only the mean is drawn. A ``None`` DataFrame, or one without
    ``irt_1pl``, leaves its panels labelled but empty. The AUC row is clipped to
    [0.7, 1.0] and the Accuracy row to [0.5, 1.0].

    NOTE(review): both rows read the same ``irt_1pl`` column, so the "Accuracy"
    row repeats the top-row values until a separate accuracy source is wired in.

    Args:
        ceval_df: CEVAL results DataFrame, or None.
        csqa_df: CSQA results DataFrame, or None.
        mmlu_df: MMLU results DataFrame, or None.
        output_dir: Directory receiving
            ``combined_per_benchmark_metrics_comparison.png``.
    """
    def _mean(cell):
        # "0.85 ± 0.01" -> 0.85. pandas yields floats instead of strings when
        # no cell in the column contains "±", so accept numbers (and NaN) too.
        if isinstance(cell, str):
            return float(cell.split('±', 1)[0])
        return float(cell)

    # (name, data, (color, linestyle, marker, linewidth)) for each column.
    benchmarks = [
        ("CEVAL", ceval_df, ("red", "solid", "o", 3)),
        ("CSQA", csqa_df, ("blue", "dashed", "s", 3)),
        ("MMLU", mmlu_df, ("green", "dashdot", "^", 3)),
    ]
    # (metric label, y-axis limits) for each row.
    metric_rows = [("AUC", (0.7, 1.0)), ("Accuracy", (0.5, 1.0))]

    fig, axes = plt.subplots(2, 3, figsize=(24, 16))
    fig.suptitle("AUC and Accuracy Comparison Across Individual Benchmarks", fontsize=24, y=0.95)

    for row_axes, (metric, ylim) in zip(axes, metric_rows):
        for ax, (name, df, (color, linestyle, marker, linewidth)) in zip(row_axes, benchmarks):
            if df is not None and 'irt_1pl' in df.columns:
                ax.plot(df['Train_Ratio'].values,
                        [_mean(x) for x in df['irt_1pl'].values],
                        label=f"{name} (IRT-1PL)",
                        color=color,
                        linestyle=linestyle,
                        linewidth=linewidth,
                        marker=marker,
                        markersize=10)
                # Only add a legend when a labelled line exists; otherwise
                # matplotlib warns "No artists with labels found".
                ax.legend(loc="lower right", fontsize=12)

            ax.set_xlabel("Training Data Ratio", fontsize=14)
            ax.set_ylabel(metric, fontsize=14)
            ax.set_title(f"{name} - {metric}", fontsize=16)
            ax.grid(True, alpha=0.3)
            ax.set_ylim(*ylim)

            # One x tick per evaluated training ratio.
            if df is not None:
                ratios = df['Train_Ratio'].values
                ax.set_xticks(ratios)
                ax.set_xticklabels([f"{r:.1f}" for r in ratios], fontsize=10)
            ax.tick_params(axis='y', labelsize=10)

    # Save plot
    plt.tight_layout()
    output_path = os.path.join(output_dir, "combined_per_benchmark_metrics_comparison.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved combined per-benchmark metrics comparison plot to: {output_path}")

def create_per_benchmark_metric_table(ceval_df, csqa_df, mmlu_df, output_dir):
    """
    Build and save a per-ratio comparison table of AUC and Accuracy cells.

    The table has one row per distinct ``Train_Ratio`` found in any non-None
    input, sorted ascending. Each benchmark contributes two columns,
    ``"<NAME> (AUC)"`` and ``"<NAME> (Accuracy)"``. They hold the raw
    ``irt_1pl`` cell of the first row with that ratio. The value is ``"N/A"``
    when the DataFrame is None, has no such ratio, or has no ``irt_1pl`` column.

    NOTE: both columns are currently filled from ``irt_1pl``; there is no
    separate accuracy source yet.

    Args:
        ceval_df: CEVAL results DataFrame, or None.
        csqa_df: CSQA results DataFrame, or None.
        mmlu_df: MMLU results DataFrame, or None.
        output_dir: Directory receiving
            ``per_benchmark_metric_comparison_detailed.csv``.

    Returns:
        pandas.DataFrame: the table that was written to disk.
    """
    benchmarks = [("CEVAL", ceval_df), ("CSQA", csqa_df), ("MMLU", mmlu_df)]

    # Union of every training ratio seen in any benchmark.
    all_ratios = set()
    for _, df in benchmarks:
        if df is not None:
            all_ratios.update(df['Train_Ratio'].values)

    comparison_data = []
    for ratio in sorted(all_ratios):
        row = {'Train_Ratio': ratio}
        for name, df in benchmarks:
            auc = acc = "N/A"
            if df is not None and 'irt_1pl' in df.columns:
                matches = df[df['Train_Ratio'] == ratio]
                if not matches.empty:
                    # Duplicate ratios: keep the first occurrence.
                    auc = acc = matches['irt_1pl'].iloc[0]
            row[f'{name} (AUC)'] = auc
            row[f'{name} (Accuracy)'] = acc
        comparison_data.append(row)

    # Create and save the comparison table
    df_comparison = pd.DataFrame(comparison_data)
    output_file = os.path.join(output_dir, "per_benchmark_metric_comparison_detailed.csv")
    df_comparison.to_csv(output_file, index=False)
    print(f"Saved detailed per-benchmark metric comparison table to: {output_file}")

    return df_comparison

def main():
    """
    Load the per-benchmark metric summaries and write every comparison artifact.

    Creates the output directory if needed, reads one summary CSV per benchmark
    (CEVAL, CSQA, MMLU), then produces the AUC plot, the Accuracy plot, the
    combined plot and the detailed CSV table, in that order.
    """
    output_dir = "yourpath/comparison_results"
    os.makedirs(output_dir, exist_ok=True)

    print("Loading per-benchmark metric data...")

    # NOTE(review): every benchmark currently points at the same AUC summary
    # file; swap in the real per-benchmark paths once they are available.
    summary_paths = {
        "ceval": "yourpath/result_single_benchmark/04_metrics/auc_summary.csv",
        "csqa": "yourpath/result_single_benchmark/04_metrics/auc_summary.csv",
        "mmlu": "yourpath/result_single_benchmark/04_metrics/auc_summary.csv",
    }
    frames = {key: pd.read_csv(path) for key, path in summary_paths.items()}
    benchmark_frames = (frames["ceval"], frames["csqa"], frames["mmlu"])

    # (progress message, artifact builder) in execution order.
    pipeline = [
        ("Creating per-benchmark AUC comparison plot...", create_per_benchmark_auc_plot),
        ("Creating per-benchmark Accuracy comparison plot...", create_per_benchmark_accuracy_plot),
        ("Creating combined per-benchmark metrics comparison plot...", create_combined_per_benchmark_plot),
        ("Creating detailed per-benchmark metric comparison table...", create_per_benchmark_metric_table),
    ]
    for message, build in pipeline:
        print(message)
        build(*benchmark_frames, output_dir)

    print(f"\nAll per-benchmark metric comparison results saved to: {output_dir}")

if __name__ == "__main__":
    main()