#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Visualize benchmark comparison results with publication-quality plots for rep3 data
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Configure matplotlib for publication-quality plots
plt.rcParams.update({
    "text.usetex": False,  # Disable LaTeX for compatibility
    # Preferred fonts first. The trailing generic "sans-serif" family gives
    # matplotlib a font it can always resolve (its bundled default), so
    # machines without Arial/Helvetica don't emit
    # "findfont: Font family ... not found" warnings.
    "font.family": ["Arial", "Helvetica", "sans-serif"],
    "font.size": 10,
    "axes.linewidth": 0.8,
    "axes.labelsize": 12,
    "axes.titlesize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "figure.dpi": 300,
    "figure.figsize": (10, 6),
    "lines.linewidth": 2,
    "lines.markersize": 5,
})

# Colour per benchmark (the first three matplotlib "tab10" colours).
_COLOR_MAPS = dict(
    CEVAL="#1f77b4",  # blue
    CSQA="#ff7f0e",   # orange
    MMLU="#2ca02c",   # green
)

# Marker per training approach.
_MARKER_STYLES = dict(
    Mixed="o",   # circle for mixed approach
    Single="s",  # square for single approach
)

# Matplotlib line-style names, each keyed by itself.
_LINE_STYLES = {style: style for style in ("solid", "dashed", "dotted", "dashdot")}

def parse_metric_value(value_str):
    """Parse a metric cell such as ``'0.880195 ± 0.001234'`` into ``(mean, std)``.

    The ``±`` separator is accepted with or without surrounding whitespace.
    A NaN (any capitalisation) or empty standard deviation is reported as
    ``0.0`` so the result can be used directly as an error-bar length. A value
    without ``±`` is returned with a std of ``0.0``.

    Args:
        value_str: Text of the CSV cell.

    Returns:
        Tuple ``(mean, std)`` of floats.

    Raises:
        ValueError: If the mean, or a non-empty std, is not a valid float.
    """
    mean_part, sep, std_part = value_str.partition('±')
    mean_val = float(mean_part)  # float() tolerates surrounding whitespace
    if not sep:
        return mean_val, 0.0

    std_part = std_part.strip()
    std_val = float(std_part) if std_part else 0.0
    if np.isnan(std_val):
        # A single-run std is reported as 'nan'; draw it as a zero-length error bar.
        std_val = 0.0
    return mean_val, std_val

def load_benchmark_data(file_path):
    """Load one benchmark's comparison CSV.

    The CSV must have a ``Train_Ratio`` column. Every other column is parsed
    cell-by-cell with ``parse_metric_value``.

    Args:
        file_path: Path to the CSV file.

    Returns:
        Tuple ``(train_ratios, metrics)``. ``train_ratios`` is the
        ``Train_Ratio`` column as an array. ``metrics`` maps each other column
        name to ``{'mean': ndarray, 'std': ndarray}``. If the file is missing,
        empty, or has no ``Train_Ratio`` column, a warning is printed and
        ``(np.array([]), {})`` is returned.
    """
    if not os.path.exists(file_path):
        print(f"Warning: File not found - {file_path}")
        return np.array([]), {}

    try:
        df = pd.read_csv(file_path)
    except pd.errors.EmptyDataError:
        print(f"Warning: Empty file - {file_path}")
        return np.array([]), {}

    if 'Train_Ratio' not in df.columns:
        print(f"Warning: 'Train_Ratio' column missing - {file_path}")
        return np.array([]), {}

    # Parse all metric columns
    metrics = {}
    for col in df.columns:
        if col == 'Train_Ratio':
            continue
        parsed = [parse_metric_value(str(val)) for val in df[col]]
        metrics[col] = {
            'mean': np.array([mean for mean, _ in parsed]),
            'std': np.array([std for _, std in parsed]),
        }

    return df['Train_Ratio'].values, metrics

def plot_metric_comparison(benchmarks_data, metric_name, y_label, save_path):
    """Plot one metric for every benchmark, Mixed vs Single approach, and save it.

    Args:
        benchmarks_data: Mapping of benchmark name -> ``(train_ratios, metrics)``
            as returned by ``load_benchmark_data``. Benchmarks with no train
            ratios are skipped.
        metric_name: Metric name used in the column keys, e.g. ``"AUC"`` for
            ``"Mixed Approach (AUC)"``.
        y_label: Label for the y-axis.
        save_path: Output image path.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Plot data for each benchmark
    for bench_name, (train_ratios, metrics) in benchmarks_data.items():
        if len(train_ratios) == 0:
            continue

        color = _COLOR_MAPS.get(bench_name, "#000000")  # Default to black if not found

        # The colour identifies the benchmark; the marker identifies the approach.
        for approach in ("Mixed", "Single"):
            key = f'{approach} Approach ({metric_name})'
            if key not in metrics:
                continue
            ax.errorbar(train_ratios, metrics[key]['mean'],
                        yerr=metrics[key]['std'],
                        color=color,
                        marker=_MARKER_STYLES[approach],
                        linestyle=_LINE_STYLES["solid"],
                        label=f'{bench_name} - {approach}',
                        linewidth=2, markersize=6, capsize=3)

    # Customize plot
    ax.set_xlabel('Training Data Ratio')
    ax.set_ylabel(y_label)
    ax.set_title(f'{metric_name} Comparison: Mixed vs Single Approach Across Benchmarks')
    ax.grid(True, alpha=0.3, linestyle="--")

    # Only draw a legend when something was plotted; otherwise matplotlib
    # warns "No artists with labels found to put in legend".
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="upper right", frameon=True, fancybox=True, shadow=True, ncol=2)

    # Save plot and release this figure explicitly.
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

def plot_combined_metrics(benchmarks_data, save_dir):
    """Plot AUC (left) and Accuracy (right) side by side with one shared legend.

    The figure is saved as ``combined_metrics_comparison_rep3.png`` in
    ``save_dir``.

    Args:
        benchmarks_data: Mapping of benchmark name -> ``(train_ratios, metrics)``
            as returned by ``load_benchmark_data``. Benchmarks with no train
            ratios are skipped.
        save_dir: Directory to write the image into.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    for ax, metric_name in ((ax1, 'AUC'), (ax2, 'Accuracy')):
        for bench_name, (train_ratios, metrics) in benchmarks_data.items():
            if len(train_ratios) == 0:
                continue

            color = _COLOR_MAPS.get(bench_name, "#000000")

            # The colour identifies the benchmark; the marker identifies the approach.
            for approach in ("Mixed", "Single"):
                key = f'{approach} Approach ({metric_name})'
                if key not in metrics:
                    continue
                ax.errorbar(train_ratios, metrics[key]['mean'],
                            yerr=metrics[key]['std'],
                            color=color,
                            marker=_MARKER_STYLES[approach],
                            linestyle=_LINE_STYLES["solid"],
                            label=f'{bench_name} - {approach}',
                            linewidth=2, markersize=6, capsize=3)

        ax.set_xlabel('Training Data Ratio')
        ax.set_ylabel(metric_name)
        ax.set_title(f'{metric_name} Comparison: Mixed vs Single Approach')
        ax.grid(True, alpha=0.3, linestyle="--")

    # Shared legend: merge both panels so a series present only in the
    # Accuracy panel is still listed. Entries are de-duplicated by label,
    # keeping the first handle seen (AUC panel first).
    legend_entries = {}
    for ax in (ax1, ax2):
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    if legend_entries:
        fig.legend(list(legend_entries.values()), list(legend_entries.keys()),
                   loc="upper center", bbox_to_anchor=(0.5, 0.95),
                   frameon=True, fancybox=True, shadow=True, ncol=3)

    # Leave the top 10% of the figure free for the shared legend.
    fig.tight_layout(rect=[0, 0, 1, 0.90])
    save_path = os.path.join(save_dir, "combined_metrics_comparison_rep3.png")
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

def main(comparison_dir="yourpath/comparison_results_rep3_improved"):
    """Load each benchmark's comparison CSV and write all plots.

    Args:
        comparison_dir: Directory containing the
            ``{bench}_metrics_mix_vs_single_rep3_improved.csv`` files. Plots
            are written to its ``visualization`` subdirectory. The default is
            a placeholder path that must be adapted to the local setup.
    """
    save_dir = os.path.join(comparison_dir, "visualization")

    # Load data for each benchmark
    benchmarks = ["ceval", "csqa", "mmlu"]
    benchmarks_data = {}

    for bench in benchmarks:
        file_path = os.path.join(comparison_dir, f"{bench}_metrics_mix_vs_single_rep3_improved.csv")
        train_ratios, metrics = load_benchmark_data(file_path)
        if len(train_ratios) > 0:
            benchmarks_data[bench.upper()] = (train_ratios, metrics)
            print(f"Loaded data for {bench.upper()}")
        else:
            print(f"Warning: No data loaded for {bench.upper()}")

    # Nothing to draw: skip creating empty figures and the output directory.
    if not benchmarks_data:
        print(f"Warning: No benchmark data found in {comparison_dir} - nothing to plot")
        return

    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Plot individual metric comparisons
    plot_metric_comparison(benchmarks_data, "AUC", "AUC",
                           os.path.join(save_dir, "auc_comparison_rep3.png"))
    plot_metric_comparison(benchmarks_data, "Accuracy", "Accuracy",
                           os.path.join(save_dir, "accuracy_comparison_rep3.png"))

    # Plot combined metrics
    plot_combined_metrics(benchmarks_data, save_dir)

    print(f"Visualizations saved to: {save_dir}")

if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the default comparison directory.
    main(*sys.argv[1:2])