#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Create benchmark-wise comparison plots for Mixed vs Single approaches
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set style for plots.
# The seaborn style is applied first on purpose: sns.set_style() writes its own
# "font.family" entry into matplotlib's rcParams, so calling it after the
# assignments below would silently discard the explicit font list.
sns.set_style("whitegrid")
plt.rcParams['font.family'] = ['Arial', 'DejaVu Sans', 'Liberation Sans']
plt.rcParams['axes.unicode_minus'] = False

def load_benchmark_data(benchmark_name, approach):
    """
    Load the AUC and accuracy summary tables for one approach.

    Args:
        benchmark_name: Benchmark label. It is only used in the "not found"
            message; it does not influence which files are read, so every
            benchmark receives the same tables for a given approach.
        approach: "MIXED" selects the improved mixed-benchmark results; any
            other value selects the single-benchmark results.

    Returns:
        A tuple (auc_df, acc_df) of DataFrames read from the two CSV files,
        or (None, None) when at least one of the files does not exist.
    """
    mixed_paths = (
        "yourpath/result_improved_mixed_benchmark/04_metrics/auc_summary.csv",
        "yourpath/result_improved_mixed_benchmark/04_metrics/accuracy_summary.csv",
    )
    single_paths = (
        "yourpath/result_single_benchmark/04_metrics/auc_summary.csv",
        "yourpath/result_single_benchmark/04_metrics/accuracy_summary.csv",
    )
    auc_path, acc_path = mixed_paths if approach == "MIXED" else single_paths

    # Guard clause: both summaries are required, otherwise nothing is loaded.
    if not (os.path.exists(auc_path) and os.path.exists(acc_path)):
        print(f"Files not found for {benchmark_name} - {approach}")
        return None, None

    return pd.read_csv(auc_path), pd.read_csv(acc_path)

def create_benchmark_wise_auc_comparison(benchmark_name, mixed_auc_df, single_auc_df, output_dir):
    """
    Plot AUC versus training ratio for the mixed and single approaches.

    The mixed curve is read from the 'multibench_irt' column of mixed_auc_df
    and the single curve from the 'irt_1pl' column of single_auc_df. A curve
    is skipped when its frame is None or the column is absent. Each cell is
    split on ' ± ' and only the leading part is plotted. The figure is saved
    as "<benchmark_name lower-cased>_auc_mix_vs_single.png" inside output_dir.
    """
    plt.figure(figsize=(12, 8))

    # One entry per curve:
    # (frame, value column, legend label, color, line style, marker, line width)
    curves = (
        (mixed_auc_df, 'multibench_irt', "Mixed Approach (Multi-IRT)", "red", "solid", "o", 3),
        (single_auc_df, 'irt_1pl', "Single Approach (IRT-1PL)", "blue", "dashed", "s", 3),
    )
    for frame, column, legend_label, color, line_style, marker, width in curves:
        if frame is None or column not in frame.columns:
            continue
        x_values = frame['Train_Ratio'].values
        y_values = [float(cell.split(' ± ')[0]) for cell in frame[column].values]
        plt.plot(
            x_values,
            y_values,
            label=legend_label,
            color=color,
            linestyle=line_style,
            linewidth=width,
            marker=marker,
            markersize=10,
        )

    # Axis labels, title, legend and fixed AUC range
    plt.xlabel("Training Data Ratio", fontsize=16)
    plt.ylabel("AUC", fontsize=16)
    plt.title(f"AUC Comparison: {benchmark_name} Benchmark (Mixed vs Single)", fontsize=18, pad=20)
    plt.legend(loc="lower right", fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.ylim(0.7, 1.0)

    # X ticks: every training ratio present in either frame
    tick_positions = sorted({
        ratio
        for frame in (mixed_auc_df, single_auc_df)
        if frame is not None
        for ratio in frame['Train_Ratio'].values
    })
    tick_labels = [f"{r:.1f}" for r in tick_positions]
    plt.xticks(tick_positions, tick_labels, fontsize=14)
    plt.yticks(fontsize=14)

    # Write the figure to disk and release it
    plt.tight_layout()
    output_path = os.path.join(output_dir, f"{benchmark_name.lower()}_auc_mix_vs_single.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"Saved {benchmark_name} AUC comparison plot to: {output_path}")

def create_benchmark_wise_accuracy_comparison(benchmark_name, mixed_acc_df, single_acc_df, output_dir):
    """
    Plot accuracy versus training ratio for the mixed and single approaches.

    The mixed curve is read from the 'multibench_irt' column of mixed_acc_df
    and the single curve from the 'irt_1pl' column of single_acc_df. A curve
    is skipped when its frame is None or the column is absent. Each cell is
    split on ' ± ' and only the leading part is plotted. The figure is saved
    as "<benchmark_name lower-cased>_accuracy_mix_vs_single.png" inside
    output_dir.
    """
    plt.figure(figsize=(12, 8))

    # One entry per curve:
    # (frame, value column, legend label, color, line style, marker, line width)
    curves = (
        (mixed_acc_df, 'multibench_irt', "Mixed Approach (Multi-IRT)", "red", "solid", "o", 3),
        (single_acc_df, 'irt_1pl', "Single Approach (IRT-1PL)", "blue", "dashed", "s", 3),
    )
    for frame, column, legend_label, color, line_style, marker, width in curves:
        if frame is None or column not in frame.columns:
            continue
        x_values = frame['Train_Ratio'].values
        y_values = [float(cell.split(' ± ')[0]) for cell in frame[column].values]
        plt.plot(
            x_values,
            y_values,
            label=legend_label,
            color=color,
            linestyle=line_style,
            linewidth=width,
            marker=marker,
            markersize=10,
        )

    # Axis labels, title, legend and fixed accuracy range
    plt.xlabel("Training Data Ratio", fontsize=16)
    plt.ylabel("Accuracy", fontsize=16)
    plt.title(f"Accuracy Comparison: {benchmark_name} Benchmark (Mixed vs Single)", fontsize=18, pad=20)
    plt.legend(loc="lower right", fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.ylim(0.5, 1.0)

    # X ticks: every training ratio present in either frame
    tick_positions = sorted({
        ratio
        for frame in (mixed_acc_df, single_acc_df)
        if frame is not None
        for ratio in frame['Train_Ratio'].values
    })
    tick_labels = [f"{r:.1f}" for r in tick_positions]
    plt.xticks(tick_positions, tick_labels, fontsize=14)
    plt.yticks(fontsize=14)

    # Write the figure to disk and release it
    plt.tight_layout()
    output_path = os.path.join(output_dir, f"{benchmark_name.lower()}_accuracy_mix_vs_single.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"Saved {benchmark_name} Accuracy comparison plot to: {output_path}")

def _draw_metric_panel(ax, mixed_df, single_df, metric_name, y_limits):
    """
    Draw the mixed and single curves for one metric on ax and format the axes.

    A curve is skipped when its frame is None or its value column is absent.
    X ticks are taken from the Train_Ratio values of the frames passed in, so
    each panel is labelled from its own data.
    """
    # (frame, value column, legend label, color, line style, marker)
    curves = (
        (mixed_df, 'multibench_irt', "Mixed Approach (Multi-IRT)", "red", "solid", "o"),
        (single_df, 'irt_1pl', "Single Approach (IRT-1PL)", "blue", "dashed", "s"),
    )
    for frame, column, label, color, linestyle, marker in curves:
        if frame is None or column not in frame.columns:
            continue
        ratios = frame['Train_Ratio'].values
        # Cells are split on ' ± ' and only the leading part is plotted.
        values = [float(cell.split(' ± ')[0]) for cell in frame[column].values]
        ax.plot(ratios, values,
                label=label,
                color=color,
                linestyle=linestyle,
                linewidth=3,
                marker=marker,
                markersize=10)

    ax.set_xlabel("Training Data Ratio", fontsize=14)
    ax.set_ylabel(metric_name, fontsize=14)
    ax.set_title(f"{metric_name} Comparison", fontsize=16)
    ax.legend(loc="lower right", fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(*y_limits)

    # Ticks come from this panel's own frames, not from another metric's.
    sorted_ratios = sorted({
        ratio
        for frame in (mixed_df, single_df)
        if frame is not None
        for ratio in frame['Train_Ratio'].values
    })
    ax.set_xticks(sorted_ratios)
    ax.set_xticklabels([f"{r:.1f}" for r in sorted_ratios], fontsize=12)
    ax.tick_params(axis='y', labelsize=12)


def create_combined_benchmark_comparison(benchmark_name, mixed_auc_df, single_auc_df, mixed_acc_df, single_acc_df, output_dir):
    """
    Create a combined plot showing both AUC and Accuracy for a specific benchmark

    The left panel shows AUC (y range 0.7-1.0) and the right panel shows
    accuracy (y range 0.5-1.0). The x ticks of each panel are derived from
    that panel's own frames, so the accuracy panel is labelled correctly even
    when the accuracy frames hold different training ratios than the AUC
    frames, or when the AUC frames are None.

    Args:
        benchmark_name: Label used in the figure title and the file name.
        mixed_auc_df, single_auc_df: AUC frames, or None to skip that curve.
        mixed_acc_df, single_acc_df: Accuracy frames, or None to skip that curve.
        output_dir: Directory the PNG is written to.
    """
    fig, (auc_ax, acc_ax) = plt.subplots(1, 2, figsize=(20, 8))
    try:
        fig.suptitle(f"Metrics Comparison: {benchmark_name} Benchmark (Mixed vs Single)", fontsize=20, y=1.02)

        _draw_metric_panel(auc_ax, mixed_auc_df, single_auc_df, "AUC", (0.7, 1.0))
        _draw_metric_panel(acc_ax, mixed_acc_df, single_acc_df, "Accuracy", (0.5, 1.0))

        # Save plot
        fig.tight_layout()
        output_path = os.path.join(output_dir, f"{benchmark_name.lower()}_combined_metrics_mix_vs_single.png")
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even if parsing or saving fails.
        plt.close(fig)
    print(f"Saved {benchmark_name} combined metrics comparison plot to: {output_path}")

def _metric_at_ratio(df, ratio, column):
    """
    Return the value of column on the first row of df whose Train_Ratio
    equals ratio, or "N/A" when df is None, the column is absent, or no row
    matches.
    """
    if df is None or column not in df.columns:
        return "N/A"
    matching = df.loc[df['Train_Ratio'] == ratio, column]
    if matching.empty:
        return "N/A"
    return matching.iloc[0]


def create_benchmark_comparison_table(benchmark_name, mixed_auc_df, single_auc_df, mixed_acc_df, single_acc_df, output_dir):
    """
    Create a detailed comparison table for a specific benchmark

    One row is produced per training ratio found in any of the four frames.
    Every cell is looked up independently, so a ratio that is missing from one
    frame yields "N/A" for that cell only instead of raising IndexError or
    hiding values that the other frames do have.

    Args:
        benchmark_name: Label used in the output file name and log message.
        mixed_auc_df, mixed_acc_df: Mixed-approach frames; values are read
            from their 'multibench_irt' column. May be None.
        single_auc_df, single_acc_df: Single-approach frames; values are read
            from their 'irt_1pl' column. May be None.
        output_dir: Directory the CSV is written to.

    Returns:
        The comparison DataFrame that was written to
        "<benchmark_name lower-cased>_metrics_mix_vs_single.csv".

    Raises:
        KeyError: If a non-None frame has no 'Train_Ratio' column.
    """
    # Collect all train ratios, including those present only in accuracy frames
    frames = (mixed_auc_df, single_auc_df, mixed_acc_df, single_acc_df)
    sorted_ratios = sorted({
        ratio
        for df in frames
        if df is not None
        for ratio in df['Train_Ratio'].values
    })

    # Create comparison table
    comparison_data = [
        {
            'Train_Ratio': ratio,
            'Mixed Approach (AUC)': _metric_at_ratio(mixed_auc_df, ratio, 'multibench_irt'),
            'Mixed Approach (Accuracy)': _metric_at_ratio(mixed_acc_df, ratio, 'multibench_irt'),
            'Single Approach (AUC)': _metric_at_ratio(single_auc_df, ratio, 'irt_1pl'),
            'Single Approach (Accuracy)': _metric_at_ratio(single_acc_df, ratio, 'irt_1pl'),
        }
        for ratio in sorted_ratios
    ]

    # Explicit columns keep the header stable even when there are no rows.
    columns = [
        'Train_Ratio',
        'Mixed Approach (AUC)',
        'Mixed Approach (Accuracy)',
        'Single Approach (AUC)',
        'Single Approach (Accuracy)',
    ]

    # Create and save the comparison table
    df_comparison = pd.DataFrame(comparison_data, columns=columns)
    output_file = os.path.join(output_dir, f"{benchmark_name.lower()}_metrics_mix_vs_single.csv")
    df_comparison.to_csv(output_file, index=False)
    print(f"Saved {benchmark_name} metrics comparison table to: {output_file}")

    return df_comparison

def main():
    """
    Generate every comparison artifact for each benchmark.

    For CEVAL, CSQA and MMLU this loads the mixed and single summary tables
    and then produces the AUC plot, the accuracy plot, the combined plot and
    the comparison table, all written to the output directory.
    """
    # Output directory is created up front so every writer can rely on it
    output_dir = "yourpath/comparison_results"
    os.makedirs(output_dir, exist_ok=True)

    print("Creating benchmark-wise comparison of Mixed vs Single approaches...")

    for benchmark in ("CEVAL", "CSQA", "MMLU"):
        print(f"\nProcessing {benchmark} benchmark...")

        # Load data for mixed and single approaches
        mixed_auc_df, mixed_acc_df = load_benchmark_data(benchmark, "MIXED")
        single_auc_df, single_acc_df = load_benchmark_data(benchmark, "SINGLE")

        # Each step: (what is announced, builder function, frames it receives)
        steps = (
            ("AUC comparison plot",
             create_benchmark_wise_auc_comparison,
             (mixed_auc_df, single_auc_df)),
            ("Accuracy comparison plot",
             create_benchmark_wise_accuracy_comparison,
             (mixed_acc_df, single_acc_df)),
            ("combined metrics comparison plot",
             create_combined_benchmark_comparison,
             (mixed_auc_df, single_auc_df, mixed_acc_df, single_acc_df)),
            ("metrics comparison table",
             create_benchmark_comparison_table,
             (mixed_auc_df, single_auc_df, mixed_acc_df, single_acc_df)),
        )
        for description, build, frames in steps:
            print(f"Creating {benchmark} {description}...")
            build(benchmark, *frames, output_dir)

    print(f"\nAll benchmark-wise comparison results saved to: {output_dir}")


if __name__ == "__main__":
    main()