#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Compare single benchmark results with mixed benchmark results using detailed train ratios
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Global plot appearance: preferred font families, no Unicode minus glyph in
# tick labels, and seaborn's "whitegrid" theme.
# NOTE(review): seaborn style presets may carry their own font.family entry,
# in which case set_style() below would replace the font list set here —
# confirm, and reorder if the explicit font list is meant to win.
plt.rcParams.update({
    'font.family': ['Arial', 'DejaVu Sans', 'Liberation Sans'],
    'axes.unicode_minus': False,
})
sns.set_style("whitegrid")

def load_mse_data(benchmark_name):
    """
    Read the MSE summary CSV for the requested benchmark.

    "MIXED" selects the mixed-benchmark summary; any other name selects the
    single-benchmark summary. Returns the parsed DataFrame, or None (after
    printing a notice) when the file does not exist.

    NOTE(review): every non-"MIXED" name resolves to the same file, so
    callers asking for CEVAL / CSQA / MMLU all receive identical data —
    confirm whether a per-benchmark path was intended.
    """
    file_path = (
        "yourpath/result_mixed_benchmark/04_metrics/mse_summary.csv"
        if benchmark_name == "MIXED"
        else "yourpath/result_single_benchmark/04_metrics/mse_summary.csv"
    )

    # Guard clause: report the missing file and bail out early.
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None

    return pd.read_csv(file_path)

# One entry per plotted curve, in the same order as the DataFrame arguments
# of the public plotting functions:
# (legend label, metric column, color, linestyle, marker)
_MSE_SERIES_SPECS = (
    ("Mixed Benchmark (Multi-IRT)", "Multi-benchmark IRT", "red", "solid", "o"),
    ("CEVAL (IRT-1PL)", "IRT-1PL", "blue", "dashed", "s"),
    ("CSQA (IRT-1PL)", "IRT-1PL", "green", "dashdot", "^"),
    ("MMLU (IRT-1PL)", "IRT-1PL", "purple", "dotted", "d"),
)


def _extract_mse_series(df, column):
    """
    Return (train_ratios, mse_means) for `column`, or None if `df` is None
    or does not contain that column.
    """
    if df is None or column not in df.columns:
        return None
    ratios = df['Train_Ratio'].values
    # Cells are expected to look like "mean ± std"; keep only the mean.
    # str() lets cells that pandas already parsed as numbers pass through.
    means = [float(str(cell).split('±')[0]) for cell in df[column].values]
    return ratios, means


def _plot_mse_comparison(frames, output_dir, *, title, y_max, filename, description):
    """
    Draw one MSE-vs-train-ratio line per available benchmark and save it.

    frames      -- (mixed_df, ceval_df, csqa_df, mmlu_df); entries may be None
    output_dir  -- directory the PNG is written into
    title       -- figure title
    y_max       -- upper limit of the y axis (lower limit is 0)
    filename    -- PNG file name inside output_dir
    description -- word used in the "Saved ... plot" log line
    """
    fig = plt.figure(figsize=(14, 10))
    try:
        # Only ratios of series that were actually plotted become x ticks.
        # The original read per-series variables that were never assigned
        # when a frame lacked its metric column, which raised at this point.
        tick_ratios = set()
        for frame, spec in zip(frames, _MSE_SERIES_SPECS):
            label, column, color, linestyle, marker = spec
            series = _extract_mse_series(frame, column)
            if series is None:
                continue
            ratios, means = series
            plt.plot(ratios, means,
                     label=label,
                     color=color,
                     linestyle=linestyle,
                     linewidth=3,
                     marker=marker,
                     markersize=12)
            tick_ratios.update(ratios)

        # Formatting
        plt.xlabel("Training Data Ratio", fontsize=18)
        plt.ylabel("MSE", fontsize=18)
        plt.title(title, fontsize=22, pad=25)
        plt.legend(loc="upper right", fontsize=16)
        plt.grid(True, alpha=0.3)
        plt.ylim(0, y_max)

        # Set x-axis ticks
        ordered_ratios = sorted(tick_ratios)
        plt.xticks(ordered_ratios, [f"{r:.1f}" for r in ordered_ratios], fontsize=16)
        plt.yticks(fontsize=16)

        # Save plot
        plt.tight_layout()
        output_path = os.path.join(output_dir, filename)
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
    print(f"Saved {description} MSE comparison plot to: {output_path}")


def create_detailed_comparison_plot(mixed_df, ceval_df, csqa_df, mmlu_df, output_dir):
    """
    Create detailed comparison plot of MSE across different benchmarks.

    Plots the 'Multi-benchmark IRT' column of `mixed_df` and the 'IRT-1PL'
    column of each single-benchmark frame against 'Train_Ratio', with the
    y axis limited to [0, 0.5]. A frame that is None or lacks its metric
    column is skipped. The figure is written to
    "single_vs_mixed_mse_comparison_detailed.png" inside `output_dir`.
    Returns None.
    """
    _plot_mse_comparison(
        (mixed_df, ceval_df, csqa_df, mmlu_df),
        output_dir,
        title="MSE Comparison: Mixed Benchmark vs Single Benchmarks (10 Train Ratios)",
        y_max=0.5,
        filename="single_vs_mixed_mse_comparison_detailed.png",
        description="detailed",
    )


def create_zoomed_comparison_plot(mixed_df, ceval_df, csqa_df, mmlu_df, output_dir):
    """
    Create a zoomed comparison plot focusing on the lower MSE range.

    Same curves and skipping rules as the detailed plot, but with the
    y axis limited to [0, 0.3]. The figure is written to
    "single_vs_mixed_mse_comparison_zoomed.png" inside `output_dir`.
    Returns None.
    """
    _plot_mse_comparison(
        (mixed_df, ceval_df, csqa_df, mmlu_df),
        output_dir,
        title="MSE Comparison: Mixed Benchmark vs Single Benchmarks (Zoomed View)",
        y_max=0.3,  # Zoomed view
        filename="single_vs_mixed_mse_comparison_zoomed.png",
        description="zoomed",
    )

def create_detailed_comparison_table(mixed_df, ceval_df, csqa_df, mmlu_df, output_dir):
    """
    Build a side-by-side table of benchmark results per train ratio.

    The table has one row for every 'Train_Ratio' value found in any of the
    non-None frames (sorted ascending). Each benchmark column holds the value
    from that frame's metric column for the first row matching the ratio, or
    "N/A" when the frame is None, has no such ratio, or lacks the metric
    column. The table is written as CSV into `output_dir` and returned.
    """
    # (output column, source frame, metric column in the source frame)
    sources = [
        ('Mixed Benchmark', mixed_df, 'Multi-benchmark IRT'),
        ('CEVAL (IRT-1PL)', ceval_df, 'IRT-1PL'),
        ('CSQA (IRT-1PL)', csqa_df, 'IRT-1PL'),
        ('MMLU (IRT-1PL)', mmlu_df, 'IRT-1PL'),
    ]

    # Union of the train ratios across every frame that was loaded.
    ratio_pool = set()
    for _, frame, _ in sources:
        if frame is not None:
            ratio_pool.update(frame['Train_Ratio'].values)

    records = []
    for ratio in sorted(ratio_pool):
        record = {'Train_Ratio': ratio}
        for out_col, frame, metric_col in sources:
            cell = "N/A"
            if frame is not None and ratio in frame['Train_Ratio'].values:
                first_match = frame[frame['Train_Ratio'] == ratio].iloc[0]
                if metric_col in first_match:
                    cell = first_match[metric_col]
            record[out_col] = cell
        records.append(record)

    # Persist the table and hand it back to the caller.
    df_comparison = pd.DataFrame(records)
    output_file = os.path.join(output_dir, "single_vs_mixed_detailed_comparison_10points.csv")
    df_comparison.to_csv(output_file, index=False)
    print(f"Saved detailed comparison table to: {output_file}")

    return df_comparison

def create_performance_summary(df_comparison, output_dir):
    """
    Reshape the comparison table into one row per (approach, train ratio).

    For every row of `df_comparison` and every approach column, a record
    {'Approach', 'Train_Ratio', 'MSE'} is emitted unless the cell equals
    "N/A". The result is written as CSV into `output_dir` and returned.
    """
    approach_columns = (
        'Mixed Benchmark',
        'CEVAL (IRT-1PL)',
        'CSQA (IRT-1PL)',
        'MMLU (IRT-1PL)',
    )

    entries = []
    for _, record in df_comparison.iterrows():
        train_ratio = record['Train_Ratio']
        # Keep only the approaches that have a value at this ratio.
        entries.extend(
            {'Approach': name, 'Train_Ratio': train_ratio, 'MSE': record[name]}
            for name in approach_columns
            if record[name] != "N/A"
        )

    df_summary = pd.DataFrame(entries)
    output_file = os.path.join(output_dir, "single_vs_mixed_performance_summary_10points.csv")
    df_summary.to_csv(output_file, index=False)
    print(f"Saved performance summary to: {output_file}")

    return df_summary

def main():
    """
    Run the full comparison: load the MSE summaries, render both plots,
    then write the comparison table and the performance summary.
    """
    # All artifacts land in this directory; create it if needed.
    output_dir = "yourpath/comparison_results"
    os.makedirs(output_dir, exist_ok=True)

    print("Loading MSE data for detailed comparison...")

    # Frames are loaded in the order the downstream functions expect them.
    mixed_df, ceval_df, csqa_df, mmlu_df = (
        load_mse_data(name) for name in ("MIXED", "CEVAL", "CSQA", "MMLU")
    )
    frames = (mixed_df, ceval_df, csqa_df, mmlu_df)

    print("Creating detailed MSE comparison plot...")
    create_detailed_comparison_plot(*frames, output_dir)

    print("Creating zoomed MSE comparison plot...")
    create_zoomed_comparison_plot(*frames, output_dir)

    print("Creating detailed comparison table...")
    df_comparison = create_detailed_comparison_table(*frames, output_dir)

    print("Creating performance summary...")
    create_performance_summary(df_comparison, output_dir)

    print(f"\nAll detailed comparison results saved to: {output_dir}")

# Run the comparison pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()