#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Visualize benchmark comparison results with publication-quality plots
"""

import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Global matplotlib defaults for publication-style figures, set group by
# group. Text is rendered through LaTeX (usetex), with a serif family that
# prefers Times New Roman.
plt.rc("text", usetex=True)
plt.rc("font", family=["serif"], serif=["Times New Roman"], size=10)
plt.rc("axes", linewidth=0.8, labelsize=12, titlesize=14)
plt.rc("xtick", labelsize=10)
plt.rc("ytick", labelsize=10)
plt.rc("legend", fontsize=10)
plt.rc("figure", dpi=300, figsize=(10, 6))
plt.rc("lines", linewidth=2, markersize=5)

# Color, marker, and line style configurations
_COLOR_MAPS = {
    "blue": "#1f77b4",
    "orange": "#ff7f0e",
    "green": "#2ca02c",
    "red": "#d62728",
    "purple": "#9467bd",
    "brown": "#8c564b"
}

_MARKER_STYLES = {
    "circle": "o",
    "square": "s",
    "triangle": "^",
    "diamond": "D",
    "inverted_triangle": "v",
    "cross": "x"
}

_LINE_STYLES = {
    "solid": "solid",
    "dashed": "dashed",
    "dotted": "dotted",
    "dashdot": "dashdot"
}

def parse_metric_value(value_str):
    """Parse a metric cell such as '0.880195 ± nan' into ``(mean, std)``.

    Args:
        value_str: Cell content. Either a string of the form
            ``'<mean> ± <std>'`` (whitespace around ``±`` is optional), a
            plain numeric string, or an actual number (what pandas yields
            when a column contains only numeric values).

    Returns:
        A ``(mean, std)`` tuple of floats. ``std`` is ``0.0`` when it is
        missing, empty, or NaN, so it can be used directly as a band width.

    Raises:
        ValueError: If the mean or std text cannot be converted to float.
    """
    # Numeric cells cannot be searched for '±'; they carry no std at all.
    if not isinstance(value_str, str):
        return float(value_str), 0.0

    # partition() tolerates any amount of whitespace around the separator;
    # float() itself strips surrounding whitespace.
    mean_text, separator, std_text = value_str.partition('±')
    mean_val = float(mean_text)
    if not separator:
        return mean_val, 0.0

    std_text = std_text.strip()
    std_val = float(std_text) if std_text else 0.0
    # A single run has no spread and is reported as nan (any casing).
    if np.isnan(std_val):
        std_val = 0.0
    return mean_val, std_val

def load_benchmark_data(file_path):
    """Read one benchmark CSV and split it into ratios and parsed metrics.

    Every column except ``Train_Ratio`` is treated as a metric column and
    each of its cells is passed through ``parse_metric_value``.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A ``(train_ratios, metrics)`` tuple: the ``Train_Ratio`` column as
        an array, and a dict mapping each metric column name to
        ``{'mean': ndarray, 'std': ndarray}``.
    """
    table = pd.read_csv(file_path)

    parsed = {}
    for column in table.columns:
        if column == 'Train_Ratio':
            continue
        pairs = [parse_metric_value(cell) for cell in table[column]]
        parsed[column] = {
            'mean': np.array([mean for mean, _ in pairs]),
            'std': np.array([std for _, std in pairs]),
        }

    return table['Train_Ratio'].values, parsed

def plot_metric_comparison(benchmarks_data, metric_name, y_label, save_path):
    """Plot one metric for the Mixed and Single approaches of every benchmark.

    Each benchmark gets its own colour and marker. Within a benchmark the
    Mixed approach is drawn solid and the Single approach dashed, each with
    a shaded mean ± std band. The figure is written to ``save_path`` as PDF.

    Args:
        benchmarks_data: Mapping of benchmark name to a
            ``(train_ratios, metrics)`` pair, where ``metrics`` maps a
            column name to ``{'mean': array, 'std': array}``.
        metric_name: Metric to plot, e.g. ``"AUC"``. The columns read are
            ``'Mixed Approach (<metric_name>)'`` and
            ``'Single Approach (<metric_name>)'``; missing columns are
            skipped.
        y_label: Label for the y axis.
        save_path: Output path of the PDF file.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # One (colour, marker) pair per benchmark. The approach is encoded only
    # by the line style, so both curves of a benchmark share a colour.
    bench_styles = [
        (_COLOR_MAPS["blue"], _MARKER_STYLES["circle"]),
        (_COLOR_MAPS["orange"], _MARKER_STYLES["square"]),
        (_COLOR_MAPS["green"], _MARKER_STYLES["triangle"]),
        (_COLOR_MAPS["red"], _MARKER_STYLES["diamond"]),
        (_COLOR_MAPS["purple"], _MARKER_STYLES["inverted_triangle"]),
        (_COLOR_MAPS["brown"], _MARKER_STYLES["cross"]),
    ]
    approach_styles = [
        ("Mixed", _LINE_STYLES["solid"]),
        ("Single", _LINE_STYLES["dashed"]),
    ]

    legend_labels = []
    line_objects = []

    for i, (bench_name, (train_ratios, metrics)) in enumerate(benchmarks_data.items()):
        # Cycle through the palette so extra benchmarks never index past it.
        color, marker = bench_styles[i % len(bench_styles)]

        for approach, linestyle in approach_styles:
            key = f'{approach} Approach ({metric_name})'
            if key not in metrics:
                continue

            mean = metrics[key]['mean']
            std = metrics[key]['std']
            label = f'{bench_name} - {approach}'

            line, = ax.plot(train_ratios, mean,
                            color=color, marker=marker, linestyle=linestyle,
                            label=label, linewidth=2, markersize=6)
            ax.fill_between(train_ratios, mean - std, mean + std,
                            color=color, alpha=0.2)
            line_objects.append(line)
            legend_labels.append(label)

    # Customize plot
    ax.set_xlabel('Training Data Ratio')
    ax.set_ylabel(y_label)
    ax.set_title(f'{metric_name} Comparison: Mixed vs Single Approach Across Benchmarks')
    ax.grid(True, alpha=0.3, linestyle="--")

    # Legend lists only the curves that were actually drawn.
    ax.legend(line_objects, legend_labels, loc="upper right",
              frameon=True, fancybox=True, shadow=True, ncol=2)

    # Operate on this figure explicitly rather than on pyplot's notion of
    # the "current" figure.
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight", format='pdf')
    plt.close(fig)

def plot_approach_comparison(benchmarks_data, approach_type, save_dir):
    """Plot AUC and Accuracy side by side for all benchmarks and approaches.

    The left panel shows AUC and the right panel Accuracy. Each benchmark
    has its own colour and marker; the Mixed approach is drawn solid and
    the Single approach dashed, each with a shaded mean ± std band. A
    shared legend is placed above the panels and the figure is saved as
    ``<approach_type lower-cased>_comparison.pdf`` inside ``save_dir``.

    Args:
        benchmarks_data: Mapping of benchmark name to a
            ``(train_ratios, metrics)`` pair, where ``metrics`` maps a
            column name to ``{'mean': array, 'std': array}``.
        approach_type: Text used in both panel titles and, lower-cased, as
            the prefix of the output file name.
        save_dir: Directory the PDF is written to.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # One (colour, marker) pair per benchmark.
    bench_styles = [
        (_COLOR_MAPS["blue"], _MARKER_STYLES["circle"]),
        (_COLOR_MAPS["orange"], _MARKER_STYLES["square"]),
        (_COLOR_MAPS["green"], _MARKER_STYLES["triangle"]),
        (_COLOR_MAPS["red"], _MARKER_STYLES["diamond"]),
        (_COLOR_MAPS["purple"], _MARKER_STYLES["inverted_triangle"]),
        (_COLOR_MAPS["brown"], _MARKER_STYLES["cross"]),
    ]
    approach_styles = [
        ("Mixed", _LINE_STYLES["solid"]),
        ("Single", _LINE_STYLES["dashed"]),
    ]

    # Both panels are drawn the same way; only the axes and metric differ.
    for ax, metric in ((ax1, 'AUC'), (ax2, 'Accuracy')):
        for i, (bench_name, (train_ratios, metrics)) in enumerate(benchmarks_data.items()):
            # Cycle through the palette so extra benchmarks never index past it.
            color, marker = bench_styles[i % len(bench_styles)]

            for approach, linestyle in approach_styles:
                key = f'{approach} Approach ({metric})'
                if key not in metrics:
                    continue

                mean = metrics[key]['mean']
                std = metrics[key]['std']
                ax.plot(train_ratios, mean,
                        color=color, marker=marker, linestyle=linestyle,
                        label=f'{bench_name} - {approach}',
                        linewidth=2, markersize=6)
                ax.fill_between(train_ratios, mean - std, mean + std,
                                color=color, alpha=0.2)

        ax.set_xlabel('Training Data Ratio')
        ax.set_ylabel(metric)
        ax.set_title(f'{metric} Comparison: {approach_type} Approach')
        ax.grid(True, alpha=0.3, linestyle="--")

    # The legend entries are taken from the AUC panel only and shown once
    # for the whole figure.
    handles, labels = ax1.get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.95),
               frameon=True, fancybox=True, shadow=True, ncol=3)

    # Leave the top 10% of the figure free for the shared legend.
    fig.tight_layout(rect=[0, 0, 1, 0.90])
    save_path = os.path.join(save_dir, f'{approach_type.lower()}_comparison.pdf')
    fig.savefig(save_path, dpi=300, bbox_inches="tight", format='pdf')
    plt.close(fig)

def main(comparison_dir="yourpath/comparison_results"):
    """Load the benchmark CSVs and write all comparison plots.

    For each of the benchmarks ``ceval``, ``csqa`` and ``mmlu`` the file
    ``<bench>_metrics_mix_vs_single.csv`` is read from ``comparison_dir``
    if it exists; missing files only produce a warning. Plots are written
    to the ``visualization`` sub-directory of ``comparison_dir``.

    Args:
        comparison_dir: Directory holding the per-benchmark CSV files.
            The default is a placeholder path that must be adapted.
    """
    save_dir = os.path.join(comparison_dir, "visualization")

    # Load data for each benchmark
    benchmarks = ["ceval", "csqa", "mmlu"]
    benchmarks_data = {}

    for bench in benchmarks:
        file_path = os.path.join(comparison_dir, f"{bench}_metrics_mix_vs_single.csv")
        if os.path.exists(file_path):
            benchmarks_data[bench.upper()] = load_benchmark_data(file_path)
            print(f"Loaded data for {bench.upper()}")
        else:
            print(f"Warning: File not found - {file_path}")

    # Without any data the plots would be empty; stop instead of writing
    # blank PDFs and reporting success.
    if not benchmarks_data:
        print(f"No benchmark data found in {comparison_dir}; nothing to plot.")
        return

    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Plot individual metric comparisons
    plot_metric_comparison(benchmarks_data, "AUC", "AUC",
                           os.path.join(save_dir, "auc_comparison.pdf"))
    plot_metric_comparison(benchmarks_data, "Accuracy", "Accuracy",
                           os.path.join(save_dir, "accuracy_comparison.pdf"))

    # Plot approach comparison
    plot_approach_comparison(benchmarks_data, "Mixed vs Single", save_dir)

    print(f"Visualizations saved to: {save_dir}")


if __name__ == "__main__":
    main()