#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Visualize benchmark comparison results with confidence bands
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Configure matplotlib for publication-quality plots with LaTeX rendering.
# These are process-wide defaults, applied once when this module is loaded.
# NOTE(review): "text.usetex" presumably needs a working LaTeX installation on
# the machine that renders the figures -- confirm against the matplotlib docs.
plt.rcParams.update({
    "text.usetex": True,  # Enable LaTeX rendering
    "font.family": ["serif"],
    "font.serif": ["Times New Roman"],
    "font.size": 10,
    "axes.linewidth": 0.8,
    "axes.labelsize": 12,
    "axes.titlesize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 9,  # Smaller font for legend
    "figure.dpi": 300,
    "figure.figsize": (12, 8),  # Default only: every plt.subplots call in this file passes its own figsize
    "lines.linewidth": 2,
    "lines.markersize": 5,
})

# Color, marker, and line style configurations shared by the plotting functions.

# Line/band color per benchmark. Keys are the upper-cased benchmark names that
# main() uses as dictionary keys; the plotting functions fall back to black
# ("#000000") for any name that is not listed here.
_COLOR_MAPS = {
    "CEVAL": "#1f77b4",    # blue
    "CSQA": "#ff7f0e",     # orange
    "MMLU": "#2ca02c"      # green
}

# Marker per approach, so the two curves of one benchmark (which share a
# color) can still be told apart.
_MARKER_STYLES = {
    "Mixed": "o",          # circle for mixed approach
    "Single": "s"          # square for single approach
}

# Named matplotlib line styles. In the code visible in this file only "solid"
# is actually looked up; the other entries are currently unused.
_LINE_STYLES = {
    "solid": "solid",
    "dashed": "dashed",
    "dotted": "dotted",
    "dashdot": "dashdot"
}

def parse_metric_value(value_str):
    """Split a metric cell such as '0.880195 ± 0.001234' into (mean, std).

    Args:
        value_str: Text of one metric cell. Either ``'<mean> ± <std>'`` with
            any amount of whitespace (including none) around the ``±`` sign,
            or a bare number.

    Returns:
        A ``(mean, std)`` tuple of floats. ``std`` is 0.0 when the cell has no
        ``±`` part, or when the std part parses to NaN (any spelling that
        ``float`` accepts, e.g. ``'nan'`` or ``'NaN'``).

    Raises:
        ValueError: If the mean or std text cannot be parsed by ``float``.
    """
    if '±' not in value_str:
        return float(value_str), 0.0

    # Split on the bare sign rather than on ' ± ' so the exact spacing does
    # not matter; float() ignores surrounding whitespace on its own.
    parts = value_str.split('±')
    mean_val = float(parts[0])
    std_val = float(parts[1])

    # A NaN std would poison the mean ± std band arithmetic downstream, so it
    # is treated as "no spread".
    if np.isnan(std_val):
        std_val = 0.0
    return mean_val, std_val

def load_benchmark_data(file_path):
    """Read one benchmark CSV and split each metric column into mean/std arrays.

    Args:
        file_path: Path of the CSV file. It must contain a ``Train_Ratio``
            column; every other column is treated as a metric whose cells are
            parsed with ``parse_metric_value``.

    Returns:
        A ``(train_ratios, metrics)`` pair. ``train_ratios`` holds the values
        of the ``Train_Ratio`` column and ``metrics`` maps each metric column
        name to ``{'mean': ndarray, 'std': ndarray}``. If the file does not
        exist, a warning is printed and ``(empty array, {})`` is returned.
    """
    if not os.path.exists(file_path):
        print(f"Warning: File not found - {file_path}")
        return np.array([]), {}

    frame = pd.read_csv(file_path)

    metrics = {}
    for column in frame.columns:
        if column == 'Train_Ratio':
            continue
        # Cells are stringified first so numeric and text cells parse alike.
        parsed = [parse_metric_value(str(cell)) for cell in frame[column]]
        metrics[column] = {
            'mean': np.array([mean for mean, _ in parsed]),
            'std': np.array([std for _, std in parsed]),
        }

    return frame['Train_Ratio'].values, metrics

def plot_metric_comparison_with_bands(benchmarks_data, metric_name, y_label, save_path):
    """Plot one metric for every benchmark with mean ± std bands and save it as PDF.

    For each benchmark the "Mixed" and "Single" approach curves are drawn in
    the benchmark's color (circle vs. square markers); the band of the
    "Single" curve is additionally hatched so the two bands stay
    distinguishable.

    Args:
        benchmarks_data: Mapping of benchmark name to a
            ``(train_ratios, metrics)`` pair, where ``metrics[column]`` holds
            ``'mean'`` and ``'std'`` arrays. Benchmarks with no train ratios
            are skipped.
        metric_name: Metric to plot; selects the columns
            ``'Mixed Approach (<metric_name>)'`` and
            ``'Single Approach (<metric_name>)'`` when present.
        y_label: Label of the y axis.
        save_path: Output path. A trailing ``.png`` extension is swapped for
            ``.pdf``; any other path is used unchanged. The file is always
            written in PDF format.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    line_objects = []
    legend_labels = []

    for bench_name, (train_ratios, metrics) in benchmarks_data.items():
        if len(train_ratios) == 0:
            continue

        color = _COLOR_MAPS.get(bench_name, "#000000")  # black for unknown benchmarks

        for approach in ("Mixed", "Single"):
            series = metrics.get(f'{approach} Approach ({metric_name})')
            if series is None:
                continue

            label = f'{bench_name} - {approach}'
            line, = ax.plot(train_ratios, series['mean'],
                            color=color,
                            marker=_MARKER_STYLES[approach],
                            linestyle=_LINE_STYLES["solid"],
                            label=label,
                            linewidth=2, markersize=6)

            # Only the "Single" band is hatched: both bands of a benchmark
            # share one color and would otherwise be indistinguishable.
            band_style = {"hatch": "///", "edgecolor": color} if approach == "Single" else {}
            ax.fill_between(train_ratios,
                            series['mean'] - series['std'],
                            series['mean'] + series['std'],
                            color=color, alpha=0.2, **band_style)

            line_objects.append(line)
            legend_labels.append(label)

    ax.set_xlabel('Training Data Ratio')
    ax.set_ylabel(y_label)
    ax.set_title(f'{metric_name} Comparison: Mixed vs Single Approach Across Benchmarks')
    ax.grid(True, alpha=0.3, linestyle="--")

    # Skip the legend when nothing was plotted, matching the combined and
    # side-by-side plot functions.
    if line_objects:
        ax.legend(line_objects, legend_labels,
                  loc="upper left", bbox_to_anchor=(0, 1),
                  frameon=True, fancybox=True, shadow=True, ncol=2,
                  columnspacing=1.0, handlelength=1.5)

    fig.tight_layout()
    fig.subplots_adjust(top=0.90)  # Make room for title

    # Swap only a trailing ".png" extension; str.replace would also rewrite
    # ".png" occurring inside directory names.
    root, ext = os.path.splitext(save_path)
    pdf_path = f"{root}.pdf" if ext == '.png' else save_path

    # Save and close through the figure object instead of implicit pyplot state.
    fig.savefig(pdf_path, dpi=300, bbox_inches="tight", format='pdf')
    plt.close(fig)

def plot_combined_metrics_with_bands(benchmarks_data, save_dir):
    """Draw AUC (top) and Accuracy (bottom) panels with mean ± std bands into one PDF.

    Args:
        benchmarks_data: Mapping of benchmark name to a
            ``(train_ratios, metrics)`` pair, where ``metrics[column]`` holds
            ``'mean'`` and ``'std'`` arrays. Benchmarks with no train ratios
            are skipped, as are missing metric columns.
        save_dir: Directory that receives
            ``combined_metrics_comparison_rep3_confidence_bands.pdf``.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))  # stacked vertically

    # Each panel is filled the same way; only the axis and the metric differ.
    for axis, metric in ((ax1, 'AUC'), (ax2, 'Accuracy')):
        for bench_name, (train_ratios, metrics) in benchmarks_data.items():
            if len(train_ratios) == 0:
                continue

            color = _COLOR_MAPS.get(bench_name, "#000000")

            for approach in ("Mixed", "Single"):
                column = f'{approach} Approach ({metric})'
                if column not in metrics:
                    continue

                center = metrics[column]['mean']
                spread = metrics[column]['std']

                axis.plot(train_ratios, center,
                          color=color,
                          marker=_MARKER_STYLES[approach],
                          linestyle=_LINE_STYLES["solid"],
                          label=f'{bench_name} - {approach}',
                          linewidth=2, markersize=6)

                # The "Single" band is hatched to tell it apart from the
                # same-colored "Mixed" band.
                hatching = {'hatch': '///', 'edgecolor': color} if approach == "Single" else {}
                axis.fill_between(train_ratios, center - spread, center + spread,
                                  color=color, alpha=0.2, **hatching)

        axis.set_xlabel('Training Data Ratio')
        axis.set_ylabel(metric)
        axis.set_title(f'{metric} Comparison: Mixed vs Single Approach')
        axis.grid(True, alpha=0.3, linestyle="--")

    # One shared legend anchored at the top of the figure, built from the
    # entries of the AUC panel.
    handles, labels = ax1.get_legend_handles_labels()
    if handles and labels:
        fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97),
                   frameon=True, fancybox=True, shadow=True, ncol=3,
                   fontsize=9)

    plt.tight_layout()
    plt.subplots_adjust(top=0.92, hspace=0.3)  # leave room for legend and titles

    save_path = os.path.join(save_dir, "combined_metrics_comparison_rep3_confidence_bands.pdf")
    plt.savefig(save_path, dpi=300, bbox_inches="tight", format='pdf')
    plt.close()

def plot_side_by_side_metrics_with_bands(benchmarks_data, save_dir):
    """Draw AUC (left) and Accuracy (right) panels with mean ± std bands into one PDF.

    Args:
        benchmarks_data: Mapping of benchmark name to a
            ``(train_ratios, metrics)`` pair, where ``metrics[column]`` holds
            ``'mean'`` and ``'std'`` arrays. Benchmarks with no train ratios
            are skipped, as are missing metric columns.
        save_dir: Directory that receives
            ``side_by_side_metrics_comparison_rep3_confidence_bands.pdf``.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Both panels share the same drawing recipe; only axis and metric change.
    panels = ((ax1, 'AUC'), (ax2, 'Accuracy'))
    for axis, metric in panels:
        for bench_name, (train_ratios, metrics) in benchmarks_data.items():
            if len(train_ratios) == 0:
                continue

            color = _COLOR_MAPS.get(bench_name, "#000000")

            for approach in ("Mixed", "Single"):
                column = f'{approach} Approach ({metric})'
                if column not in metrics:
                    continue

                values = metrics[column]
                axis.plot(train_ratios, values['mean'],
                          color=color,
                          marker=_MARKER_STYLES[approach],
                          linestyle=_LINE_STYLES["solid"],
                          label=f'{bench_name} - {approach}',
                          linewidth=2, markersize=6)

                band_top = values['mean'] + values['std']
                band_bottom = values['mean'] - values['std']
                if approach == "Single":
                    # Hatched so it can be told apart from the same-colored
                    # "Mixed" band.
                    axis.fill_between(train_ratios, band_bottom, band_top,
                                      color=color, alpha=0.2, hatch='///', edgecolor=color)
                else:
                    axis.fill_between(train_ratios, band_bottom, band_top,
                                      color=color, alpha=0.2)

        axis.set_xlabel('Training Data Ratio')
        axis.set_ylabel(metric)
        axis.set_title(f'{metric} Comparison: Mixed vs Single Approach')
        axis.grid(True, alpha=0.3, linestyle="--")

    # One shared legend above both panels, built from the AUC panel's entries.
    handles, labels = ax1.get_legend_handles_labels()
    if handles and labels:
        fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.98),
                   frameon=True, fancybox=True, shadow=True, ncol=3,
                   fontsize=9)

    plt.tight_layout()
    plt.subplots_adjust(top=0.85)  # leave room for the legend

    save_path = os.path.join(save_dir, "side_by_side_metrics_comparison_rep3_confidence_bands.pdf")
    plt.savefig(save_path, dpi=300, bbox_inches="tight", format='pdf')
    plt.close()

def main(comparison_dir="yourpath/comparison_results_rep3_improved"):
    """Load the per-benchmark CSVs and write all band plots as PDFs.

    Args:
        comparison_dir: Directory containing the
            ``<bench>_metrics_mix_vs_single_rep3_improved.csv`` files for the
            ``ceval``, ``csqa`` and ``mmlu`` benchmarks. Plots are written to
            its ``visualization_confidence_bands`` sub-directory. The default
            is the placeholder path the script originally hard-coded.
    """
    # Load data for each benchmark; keys are upper-cased to match _COLOR_MAPS.
    benchmarks_data = {}
    for bench in ("ceval", "csqa", "mmlu"):
        file_path = os.path.join(comparison_dir, f"{bench}_metrics_mix_vs_single_rep3_improved.csv")
        train_ratios, metrics = load_benchmark_data(file_path)
        if len(train_ratios) > 0:
            benchmarks_data[bench.upper()] = (train_ratios, metrics)
            print(f"Loaded data for {bench.upper()}")
        else:
            print(f"Warning: No data loaded for {bench.upper()}")

    # Without any data the plots would be empty, and reporting them as
    # "saved" would be misleading, so stop before touching the file system.
    if not benchmarks_data:
        print(f"Warning: No benchmark data found in {comparison_dir}; nothing to plot")
        return

    save_dir = os.path.join(comparison_dir, "visualization_confidence_bands")
    os.makedirs(save_dir, exist_ok=True)

    # Individual metric plots. The ".png" names are turned into ".pdf" by
    # plot_metric_comparison_with_bands.
    plot_metric_comparison_with_bands(benchmarks_data, "AUC", "AUC",
                                      os.path.join(save_dir, "auc_comparison_rep3_confidence_bands.png"))
    plot_metric_comparison_with_bands(benchmarks_data, "Accuracy", "Accuracy",
                                      os.path.join(save_dir, "accuracy_comparison_rep3_confidence_bands.png"))

    # Combined figure, panels stacked vertically.
    plot_combined_metrics_with_bands(benchmarks_data, save_dir)

    # Combined figure, panels side by side.
    plot_side_by_side_metrics_with_bands(benchmarks_data, save_dir)

    print(f"Confidence bands visualizations saved to: {save_dir}")

# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()