import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
import os

# Chinese font support: prefer SimHei, falling back to Arial Unicode MS and then
# DejaVu Sans. Minus signs are drawn as ASCII hyphens because the Unicode minus
# glyph may be missing from CJK fonts.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# A signed decimal number with an optional exponent, e.g. "0.87", "-1.", ".5", "1.5e-05".
_FLOAT_PATTERN = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
# "mean ± error", tolerating whitespace around the "±" sign.
_VALUE_WITH_ERROR_RE = re.compile(rf'({_FLOAT_PATTERN})\s*±\s*({_FLOAT_PATTERN})')


def parse_value_with_error(value_str):
    """
    Parse a "mean ± error" cell such as "0.871788 ± 0.004913".

    Signed values and scientific notation are accepted
    (e.g. "-0.5 ± 0.1", "1.5e-05 ± 2e-06"). A plain number without an
    error part is returned with an error of 0.0.

    Parameters:
    value_str: Cell value (string, number, None or NaN)

    Returns:
    (mean, error) tuple of floats; (nan, nan) for empty or unparseable input
    """
    if pd.isna(value_str) or value_str == '':
        return np.nan, np.nan

    match = _VALUE_WITH_ERROR_RE.match(str(value_str).strip())
    if match:
        # The pattern only captures well-formed numbers, so float() cannot fail here.
        return float(match.group(1)), float(match.group(2))

    # No "± error" part: accept a bare number with zero error.
    try:
        return float(value_str), 0.0
    except (TypeError, ValueError, OverflowError):
        return np.nan, np.nan

def load_and_parse_csv(file_path):
    """
    Read a metrics CSV and split every metric cell into mean and error columns.

    Parameters:
    file_path: Path to a CSV with a 'Train_Ratio' column plus metric columns

    Returns:
    DataFrame with 'Train_Ratio' followed by '<metric>_mean' and
    '<metric>_error' for each metric column, in the original column order

    Raises:
    FileNotFoundError: if file_path does not exist
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    raw = pd.read_csv(file_path)

    # Insertion order decides the output column order: Train_Ratio first,
    # then a mean/error pair per metric.
    columns = {'Train_Ratio': raw['Train_Ratio']}
    for name in (c for c in raw.columns if c != 'Train_Ratio'):
        pairs = [parse_value_with_error(cell) for cell in raw[name]]
        columns[f'{name}_mean'] = [mean for mean, _ in pairs]
        columns[f'{name}_error'] = [err for _, err in pairs]

    return pd.DataFrame(columns)

def visualize_metrics(df, columns_to_plot=None, save_path=None, figsize=(15, 10), title=None):
    """
    Plot each selected metric against Train_Ratio as a line with a ±error band.

    One subplot is drawn per metric, in a grid of at most 3 columns.

    Parameters:
    df: DataFrame shaped like the output of load_and_parse_csv: a
        'Train_Ratio' column plus '<metric>_mean' / '<metric>_error' pairs
    columns_to_plot: Metric names to plot, with or without the '_mean'
        suffix; if None, every '<metric>_mean' column is plotted. Names not
        present in df are skipped with a printed warning.
    save_path: Path to save the image (missing parent directories are
        created); if None the image is not saved
    figsize: Figure size
    title: Overall chart title

    Returns:
    None
    """
    suffix_mean = '_mean'
    mean_columns = [
        col for col in df.columns
        if isinstance(col, str) and col.endswith(suffix_mean)
    ]

    if columns_to_plot is None:
        columns_to_plot = mean_columns
    else:
        # Accept bare metric names as well as '<metric>_mean'.
        requested = [
            col if col.endswith(suffix_mean) else f'{col}{suffix_mean}'
            for col in columns_to_plot
        ]
        available = set(mean_columns)
        missing = [col for col in requested if col not in available]
        if missing:
            print(f"Warning: skipping columns not found in data: {missing}")
        columns_to_plot = [col for col in requested if col in available]

    n_plots = len(columns_to_plot)
    if n_plots == 0:
        print("No columns to plot")
        return

    # Grid layout: at most 3 columns, as many rows as needed.
    n_cols = min(3, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of Axes, so ravel() yields a
    # flat sequence regardless of grid shape (including a single subplot).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    for ax, column in zip(axes, columns_to_plot):
        # Strip only the trailing suffix; str.replace would also rewrite any
        # '_mean' inside the metric name itself (e.g. 'rolling_mean_mean').
        metric = column[:-len(suffix_mean)]
        error_column = f'{metric}_error'

        # Without an error column, plot the mean with a zero-width band.
        if error_column in df.columns:
            errors = df[error_column]
        else:
            errors = pd.Series(0.0, index=df.index)

        # Drop rows with a missing mean/error and sort by x so the line and
        # the band are drawn left-to-right even if the CSV rows are unordered.
        series = pd.DataFrame({
            'ratio': df['Train_Ratio'],
            'mean': df[column],
            'error': errors,
        }).dropna(subset=['mean', 'error']).sort_values('ratio')

        x = series['ratio']
        ax.plot(x, series['mean'], marker='o', linewidth=2, markersize=6, label=metric)
        ax.fill_between(
            x,
            series['mean'] - series['error'],
            series['mean'] + series['error'],
            alpha=0.3,
        )

        ax.set_xlabel('Train Ratio')
        ax.set_ylabel('Metric Value')
        ax.set_title(metric, fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=8)

    # Hide the unused cells in the last grid row.
    for ax in axes[n_plots:]:
        ax.set_visible(False)

    if title:
        fig.suptitle(title, fontsize=16)

    fig.tight_layout()

    if save_path:
        # Create the target directory so savefig does not fail on a fresh checkout.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Chart saved to: {save_path}")

    plt.show()

def main():
    """
    Load the CEVAL metrics CSV and write four comparison charts under results/.

    Load errors and chart-generation errors are reported separately. A
    failure while writing an output image is therefore not reported as a
    missing input CSV.
    """
    csv_file_path = "data/ceval_metrics_mix_vs_single_rep3_improved.csv"
    output_dir = "results"

    # Each entry: (progress message, metric columns, output file name, chart title).
    chart_specs = [
        (
            "Generating mixed approach vs single approach comparison chart...",
            [
                'Mixed Approach (AUC)',
                'Mixed Approach (Accuracy)',
                'Single Approach (AUC)',
                'Single Approach (Accuracy)',
            ],
            "ceval_mix_vs_single.png",
            "CEVAL: Mixed Approach vs Single Approach",
        ),
        (
            "Generating MSE related metrics chart...",
            [
                'Global Mean (MSE)',
                'Model Mean (MSE)',
                'Question Mean (MSE)',
            ],
            "ceval_mse_metrics.png",
            "CEVAL: MSE Metrics",
        ),
        (
            "Generating AUC related metrics chart...",
            [
                'Mixed Approach (AUC)',
                'Single Approach (AUC)',
                'Global Mean (AUC)',
                'Model Mean (AUC)',
                'Question Mean (AUC)',
            ],
            "ceval_auc_metrics.png",
            "CEVAL: AUC Metrics",
        ),
        (
            "Generating Accuracy related metrics chart...",
            [
                'Mixed Approach (Accuracy)',
                'Single Approach (Accuracy)',
                'Global Mean (Accuracy)',
                'Model Mean (Accuracy)',
                'Question Mean (Accuracy)',
            ],
            "ceval_accuracy_metrics.png",
            "CEVAL: Accuracy Metrics",
        ),
    ]

    # Step 1: load. Only a FileNotFoundError raised here is about the CSV path.
    try:
        print("Loading and parsing data...")
        df = load_and_parse_csv(csv_file_path)
        print("Data loading completed!")
    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please ensure the CSV file path is correct")
        return
    except Exception as e:
        print(f"An error occurred: {e}")
        return

    # Step 2: plot. Ensure the output directory exists before any chart is saved.
    try:
        os.makedirs(output_dir, exist_ok=True)
        for message, columns, file_name, chart_title in chart_specs:
            print(f"\n{message}")
            visualize_metrics(
                df,
                columns,
                save_path=f"{output_dir}/{file_name}",
                title=chart_title,
            )
        print("\nAll charts generated successfully!")
    except Exception as e:
        print(f"An error occurred: {e}")

# Generate the charts only when run as a script, not when imported as a module.
if __name__ == "__main__":
    main()