#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Visualize benchmark comparison results with publication-quality plots
Updated version: Same color for same benchmark, different markers for mixed vs single
"""

import os
import shutil
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Configure matplotlib for publication-quality plots.
# LaTeX text rendering needs a LaTeX installation; when no `latex` executable
# is on PATH, rendering a figure with usetex enabled fails, so fall back to
# matplotlib's built-in text renderer in that case.
plt.rcParams.update({
    "text.usetex": shutil.which("latex") is not None,
    "font.family": ["serif"],
    "font.serif": ["Times New Roman"],
    "font.size": 10,
    "axes.linewidth": 0.8,
    "axes.labelsize": 12,
    "axes.titlesize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "figure.dpi": 300,
    "figure.figsize": (10, 6),
    "lines.linewidth": 2,
    "lines.markersize": 5,
})

# Styling lookup tables shared by the plotting functions.

# Benchmark name -> colour used for that benchmark's curves and bands.
_COLOR_MAPS = {
    "CEVAL": "#1f77b4",  # blue
    "CSQA": "#ff7f0e",  # orange
    "MMLU": "#2ca02c",  # green
}

# Approach name -> marker; curves of one benchmark share a colour, so the
# marker is what tells the two approaches apart.
_MARKER_STYLES = {
    "Mixed": "o",  # circle
    "Single": "s",  # square
}

# Line-style names, each keyed by itself.
_LINE_STYLES = {
    style: style for style in ("solid", "dashed", "dotted", "dashdot")
}

def parse_metric_value(value_str):
    """Parse a metric cell such as '0.880195 ± nan' into ``(mean, std)``.

    Args:
        value_str: Cell value to parse. Accepted forms are a string
            ``'<mean> ± <std>'`` (whitespace around ``±`` is optional), a
            string holding a bare number, or an already-numeric value
            (pandas yields floats for columns that contain only plain
            numbers).

    Returns:
        tuple: ``(mean, std)`` as floats. ``std`` is ``0.0`` when it is
        missing or NaN, so it can be used directly as an error-band width.

    Raises:
        ValueError: If the mean or std part is not a valid float literal.
    """
    # str() lets numeric cells through; the original `'±' in value` test
    # raised TypeError for them.
    text = str(value_str)
    mean_part, separator, std_part = text.partition('±')
    mean_val = float(mean_part)  # float() ignores surrounding whitespace

    if not separator:
        return mean_val, 0.0

    std_val = float(std_part)
    # Any spelling of NaN ('nan', 'NaN', padded, ...) means "no spread known".
    if np.isnan(std_val):
        std_val = 0.0
    return mean_val, std_val

def load_benchmark_data(file_path):
    """Read a benchmark comparison CSV and split it into ratios and metrics.

    Every column other than 'Train_Ratio' is treated as a metric column and
    each of its cells is parsed with ``parse_metric_value``.

    Args:
        file_path: Location of the CSV file, passed straight to
            ``pandas.read_csv``.

    Returns:
        tuple: ``(train_ratios, metrics)`` where ``train_ratios`` holds the
        values of the 'Train_Ratio' column and ``metrics`` maps each metric
        column name to ``{'mean': ndarray, 'std': ndarray}``.
    """
    table = pd.read_csv(file_path)

    parsed = {}
    for column in table.columns:
        if column == 'Train_Ratio':
            continue
        # One (mean, std) pair per row of this metric column.
        pairs = [parse_metric_value(cell) for cell in table[column]]
        parsed[column] = {
            'mean': np.array([mean for mean, _ in pairs]),
            'std': np.array([std for _, std in pairs]),
        }

    return table['Train_Ratio'].values, parsed

def plot_metric_comparison(benchmarks_data, metric_name, y_label, save_path):
    """Draw one metric for every benchmark on a single axes and save a PDF.

    Each benchmark contributes up to two curves (Mixed and Single approach)
    in the benchmark's colour, told apart by marker. A translucent band of
    mean ± std is shaded around every curve.

    Args:
        benchmarks_data: Mapping of benchmark name to
            ``(train_ratios, metrics)``; ``metrics[column]`` must provide
            'mean' and 'std' arrays.
        metric_name: Metric label as it appears inside the column names,
            e.g. 'AUC' for the column 'Mixed Approach (AUC)'.
        y_label: Text for the y axis.
        save_path: Destination of the PDF file.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    for bench_name, (train_ratios, metrics) in benchmarks_data.items():
        # Benchmarks without a configured colour are drawn in black.
        color = _COLOR_MAPS.get(bench_name, "#000000")

        # (metrics column, legend label, marker key), Mixed first then Single.
        curves = (
            (f'Mixed Approach ({metric_name})', f'{bench_name} - Mixed', "Mixed"),
            (f'Single Approach ({metric_name})', f'{bench_name} - Single', "Single"),
        )
        for column, legend_label, marker_key in curves:
            if column not in metrics:
                continue
            stats = metrics[column]
            ax.plot(
                train_ratios,
                stats['mean'],
                color=color,
                marker=_MARKER_STYLES[marker_key],
                linestyle=_LINE_STYLES["solid"],
                label=legend_label,
                linewidth=2,
                markersize=6,
            )
            # Shaded band spanning one standard deviation either side.
            ax.fill_between(
                train_ratios,
                stats['mean'] - stats['std'],
                stats['mean'] + stats['std'],
                color=color,
                alpha=0.2,
            )

    # Axes decoration
    ax.set_xlabel('Training Data Ratio')
    ax.set_ylabel(y_label)
    ax.set_title(f'{metric_name} Comparison: Mixed vs Single Approach Across Benchmarks')
    ax.grid(True, alpha=0.3, linestyle="--")
    ax.legend(loc="upper right", frameon=True, fancybox=True, shadow=True, ncol=2)

    # Write the figure out and release it
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight", format='pdf')
    plt.close()

def plot_combined_metrics(benchmarks_data, save_path):
    """Draw AUC and Accuracy side by side and save the figure as a PDF.

    The left panel shows AUC, the right panel Accuracy. In each panel every
    benchmark contributes up to two curves (Mixed and Single approach) in
    the benchmark's colour, each with a shaded mean ± std band. A single
    legend, built from the left panel's curves, is placed above both panels.

    Args:
        benchmarks_data: Mapping of benchmark name to
            ``(train_ratios, metrics)``; ``metrics[column]`` must provide
            'mean' and 'std' arrays.
        save_path: Destination of the PDF file.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # (axes, y label, title, Mixed column, Single column) for each panel.
    panels = (
        (ax1, 'AUC', 'AUC Comparison: Mixed vs Single Approach',
         'Mixed Approach (AUC)', 'Single Approach (AUC)'),
        (ax2, 'Accuracy', 'Accuracy Comparison: Mixed vs Single Approach',
         'Mixed Approach (Accuracy)', 'Single Approach (Accuracy)'),
    )

    for axis, metric_label, panel_title, mixed_column, single_column in panels:
        for bench_name, (train_ratios, metrics) in benchmarks_data.items():
            # Benchmarks without a configured colour are drawn in black.
            color = _COLOR_MAPS.get(bench_name, "#000000")

            # (metrics column, legend label, marker key), Mixed then Single.
            curves = (
                (mixed_column, f'{bench_name} - Mixed', "Mixed"),
                (single_column, f'{bench_name} - Single', "Single"),
            )
            for column, legend_label, marker_key in curves:
                if column not in metrics:
                    continue
                stats = metrics[column]
                axis.plot(
                    train_ratios,
                    stats['mean'],
                    color=color,
                    marker=_MARKER_STYLES[marker_key],
                    linestyle=_LINE_STYLES["solid"],
                    label=legend_label,
                    linewidth=2,
                    markersize=6,
                )
                # Shaded band spanning one standard deviation either side.
                axis.fill_between(
                    train_ratios,
                    stats['mean'] - stats['std'],
                    stats['mean'] + stats['std'],
                    color=color,
                    alpha=0.2,
                )

        axis.set_xlabel('Training Data Ratio')
        axis.set_ylabel(metric_label)
        axis.set_title(panel_title)
        axis.grid(True, alpha=0.3, linestyle="--")

    # Shared legend: entries are taken from the left (AUC) panel only.
    handles, labels = ax1.get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.95),
               frameon=True, fancybox=True, shadow=True, ncol=3)

    # Keep the top 10% of the figure free for the legend, then write it out.
    plt.tight_layout(rect=[0, 0, 1, 0.90])
    plt.savefig(save_path, dpi=300, bbox_inches="tight", format='pdf')
    plt.close()

def main(comparison_dir="yourpath/comparison_results"):
    """Load the per-benchmark comparison CSVs and render all comparison plots.

    For each of the benchmarks ceval, csqa and mmlu the file
    ``<benchmark>_metrics_mix_vs_single.csv`` is read from ``comparison_dir``
    if it exists; missing files are reported and skipped. Three PDFs (AUC,
    Accuracy, and a combined two-panel figure) are then written to the
    ``visualization`` sub-directory of ``comparison_dir``.

    Args:
        comparison_dir: Directory holding the comparison CSV files. Defaults
            to the placeholder path that was previously hard-coded.

    Returns:
        None. When no benchmark file could be loaded, nothing is created or
        plotted and the function returns after printing a message.
    """
    save_dir = os.path.join(comparison_dir, "visualization")

    # Load data for each benchmark
    benchmarks = ["ceval", "csqa", "mmlu"]
    benchmarks_data = {}

    for bench in benchmarks:
        file_path = os.path.join(comparison_dir, f"{bench}_metrics_mix_vs_single.csv")
        if os.path.exists(file_path):
            train_ratios, metrics = load_benchmark_data(file_path)
            benchmarks_data[bench.upper()] = (train_ratios, metrics)
            print(f"Loaded data for {bench.upper()}")
        else:
            print(f"Warning: File not found - {file_path}")

    # Without any data the plots would be empty, yet success would still be
    # reported; stop before creating the output directory.
    if not benchmarks_data:
        print(f"No benchmark data found in {comparison_dir}; nothing to plot.")
        return

    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Plot individual metric comparisons
    plot_metric_comparison(benchmarks_data, "AUC", "AUC",
                           os.path.join(save_dir, "auc_comparison_updated.pdf"))
    plot_metric_comparison(benchmarks_data, "Accuracy", "Accuracy",
                           os.path.join(save_dir, "accuracy_comparison_updated.pdf"))

    # Plot combined metrics
    plot_combined_metrics(benchmarks_data,
                          os.path.join(save_dir, "combined_metrics_comparison_updated.pdf"))

    print(f"Updated visualizations saved to: {save_dir}")

# Run the pipeline only when this file is executed as a script, so that
# importing the module does not trigger loading and plotting.
if __name__ == "__main__":
    main()