import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re

# Matplotlib font setup so Chinese characters can be rendered: sans-serif text
# tries SimHei, then Arial Unicode MS, then DejaVu Sans. Minus signs are drawn
# with the ASCII hyphen instead of the Unicode minus sign
# (NOTE(review): presumably because these fonts may lack that glyph — confirm).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Unsigned decimal number with optional fraction and exponent, e.g. "5", "5.",
# ".5", "0.871788", "1e-3".
_UNSIGNED_NUMBER_PATTERN = r'(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
# "<mean> ± <error>" (or "+/-"), matched at the start of the string. The mean
# may carry a sign; the error is a magnitude.
_VALUE_WITH_ERROR_RE = re.compile(
    rf'([-+]?{_UNSIGNED_NUMBER_PATTERN})\s*(?:±|\+/-)\s*({_UNSIGNED_NUMBER_PATTERN})'
)


def parse_value_with_error(value_str):
    """
    Parse a value string with error, e.g. "0.871788 ± 0.004913".

    Accepts "±" or "+/-" as the separator; the mean may be negative and both
    numbers may use scientific notation. Text after the error is ignored.

    Parameters:
    value_str: Cell value (string, number, or a missing value)

    Returns:
    (mean, error) tuple of floats:
    - (mean, error) when the "mean ± error" form matches,
    - (value, 0.0) when the input is a plain number without an error part,
    - (nan, nan) for missing/empty input or anything unparseable.
    """
    if pd.isna(value_str) or value_str == '':
        return np.nan, np.nan

    match = _VALUE_WITH_ERROR_RE.match(str(value_str).strip())
    if match:
        return float(match.group(1)), float(match.group(2))

    # No error information: treat the whole value as the mean.
    try:
        return float(value_str), 0.0
    except (TypeError, ValueError):
        return np.nan, np.nan

def load_and_parse_csv(file_path):
    """
    Read a metrics CSV and split every "mean ± error" cell into two columns.

    The 'Train_Ratio' column is copied unchanged. Every other column <name>
    becomes '<name>_mean' and '<name>_error', filled by parse_value_with_error.

    Parameters:
    file_path: Path of the CSV file; it must contain a 'Train_Ratio' column

    Returns:
    DataFrame with 'Train_Ratio' followed by the mean/error column pairs,
    in the same order as the source columns
    """
    raw = pd.read_csv(file_path)

    columns = {'Train_Ratio': raw['Train_Ratio']}
    metric_names = [name for name in raw.columns if name != 'Train_Ratio']
    for name in metric_names:
        pairs = [parse_value_with_error(cell) for cell in raw[name]]
        columns[f'{name}_mean'] = [mean for mean, _ in pairs]
        columns[f'{name}_error'] = [error for _, error in pairs]

    return pd.DataFrame(columns)

def visualize_metrics(df, columns_to_plot=None, save_path=None):
    """
    Plot each selected metric's mean against Train_Ratio with a ±error band.

    One subplot per metric, laid out in a two-column grid. Rows where either
    the mean or the error is missing are left out of that metric's curve.

    Parameters:
    df: DataFrame produced by load_and_parse_csv: a 'Train_Ratio' column plus
        '<metric>_mean' / '<metric>_error' column pairs
    columns_to_plot: List of metrics to plot; each entry may be the metric
        name (as listed by list_available_columns) or its '<metric>_mean'
        column. Unknown names are reported and skipped. If None, every
        '_mean' column is plotted.
    save_path: Path to save the image (missing parent directories are
        created). If None, the image is not saved.

    If no column matches, a message is printed and nothing is drawn.
    """
    mean_suffix = '_mean'
    error_suffix = '_error'

    mean_columns = [col for col in df.columns if col.endswith(mean_suffix)]

    if columns_to_plot is None:
        selected = mean_columns
    else:
        available = set(mean_columns)
        selected = []
        for name in columns_to_plot:
            # Prefer the metric-name interpretation; fall back to a name that
            # already is a '<metric>_mean' column.
            if f'{name}{mean_suffix}' in available:
                selected.append(f'{name}{mean_suffix}')
            elif name in available:
                selected.append(name)
            else:
                print(f"Column not found, skipped: {name}")

    n_plots = len(selected)
    if n_plots == 0:
        print("No columns to plot.")
        return

    n_cols = 2
    n_rows = (n_plots + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D array of Axes, so flattening works for
    # every grid size, including a single plot.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    train_ratios = df['Train_Ratio']

    for ax, column in zip(axes, selected):
        # Strip only the trailing suffix; '_mean' may also occur inside a name.
        metric = column[:-len(mean_suffix)]
        means = df[column]
        errors = df[f'{metric}{error_suffix}']

        valid = means.notna() & errors.notna()
        x = train_ratios[valid]
        y = means[valid]
        err = errors[valid]

        # Metric curve plus a shaded mean ± error band
        ax.plot(x, y, marker='o', linewidth=2, markersize=6, label=metric)
        ax.fill_between(x, y - err, y + err, alpha=0.3)

        ax.set_xlabel('Train Ratio')
        ax.set_ylabel('Metric Value')
        ax.set_title(metric)
        ax.grid(True, alpha=0.3)
        ax.legend()

    # Hide the unused cell when the number of plots is odd
    for ax in axes[n_plots:]:
        ax.set_visible(False)

    plt.tight_layout()

    if save_path:
        import os

        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Chart saved to: {save_path}")

    plt.show()

def list_available_columns(df):
    """
    Print and return the metric names that have a parsed mean column.

    A metric name is a '<name>_mean' column with only its trailing '_mean'
    removed, so the names can be passed straight to visualize_metrics.

    Parameters:
    df: DataFrame produced by load_and_parse_csv

    Returns:
    List of metric names, in DataFrame column order
    """
    suffix = '_mean'
    # Strip only the trailing suffix: str.replace would also delete '_mean'
    # occurring inside a name (e.g. 'f1_mean_x_mean' -> 'f1_x').
    original_columns = [col[:-len(suffix)] for col in df.columns if col.endswith(suffix)]
    print("Available column names:")
    for i, col in enumerate(original_columns, 1):
        print(f"{i}. {col}")
    return original_columns

def interactive_visualization(file_path):
    """
    Load a metrics CSV, ask the user which metrics to plot, and plot them.

    The user answers with 1-based numbers from the printed list, separated by
    commas (empty items are ignored), or 'all'. Malformed input, or a
    selection with no valid number, falls back to plotting every metric.

    Parameters:
    file_path: Path of the CSV file passed to load_and_parse_csv
    """
    print("Loading and parsing data...")
    df = load_and_parse_csv(file_path)
    print("Data loading completed!")

    available_columns = list_available_columns(df)

    print("\nPlease select columns to plot (enter numbers separated by commas, e.g.: 1,3,5):")
    print("Or enter 'all' to plot all columns:")

    user_input = input().strip()

    columns_to_plot = None  # None means "plot all columns"
    if user_input.lower() == 'all':
        print("Will plot all columns...")
    else:
        try:
            # int() tolerates surrounding whitespace; skip empty items such as
            # the one produced by a trailing comma.
            selected_indices = [int(x) - 1 for x in user_input.split(',') if x.strip()]
        except ValueError:
            print("Input format is incorrect, will plot all columns...")
        else:
            columns_to_plot = [
                available_columns[i]
                for i in selected_indices
                if 0 <= i < len(available_columns)
            ]
            if columns_to_plot:
                print(f"Will plot the following columns: {', '.join(columns_to_plot)}")
            else:
                # An empty selection would leave nothing to draw.
                print("No valid column numbers given, will plot all columns...")
                columns_to_plot = None

    visualize_metrics(df, columns_to_plot)

# Main function
if __name__ == "__main__":
    csv_file_path = "data/ceval_metrics_mix_vs_single_rep3_improved.csv"

    print("Loading and parsing data...")
    df = load_and_parse_csv(csv_file_path)
    print("Data loading completed!")

    # Called for its printed listing of the available metric names
    available_columns = list_available_columns(df)

    # Example runs: (banner, metrics to plot -- None means all, output image)
    plot_jobs = [
        ("\nGenerating visualization for all columns...",
         None,
         "results/ceval_metrics_all.png"),
        ("\nGenerating visualization for specific columns...",
         ["mix_IRT", "single_IRT"],
         "results/ceval_metrics_comparison.png"),
    ]
    for banner, columns, image_path in plot_jobs:
        print(banner)
        visualize_metrics(df, columns_to_plot=columns, save_path=image_path)