import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
import os

# Font setup for Chinese text: list fonts that carry CJK glyphs first, and draw
# negative tick labels with an ASCII hyphen because fonts such as SimHei
# commonly lack the Unicode minus glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# A decimal number with optional exponent, e.g. "12", "0.5", ".5", "1e-3".
_UNSIGNED_NUMBER = r'(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'

# "mean ± error" (also accepts the ASCII spelling "+/-"). The mean may carry a
# sign; the error is a magnitude and therefore unsigned.
_VALUE_WITH_ERROR_RE = re.compile(
    rf'([-+]?{_UNSIGNED_NUMBER})\s*(?:±|\+/-)\s*({_UNSIGNED_NUMBER})'
)


def parse_value_with_error(value_str):
    """
    Parse value string with error, e.g. "0.871788 ± 0.004913"

    Parameters:
    value_str: Cell value to parse. May be a "mean ± error" string, a plain
               number (or numeric string), or a missing value.

    Returns (mean, error) tuple:
    - (mean, error) when the text starts with "mean ± error"; negative means
      and scientific notation (e.g. "-1.5e-3 ± 2e-4") are supported
    - (value, 0.0) when the input is a plain number without error information
    - (nan, nan) when the input is missing, empty, or cannot be parsed
    """
    if pd.isna(value_str) or value_str == '':
        return np.nan, np.nan

    # Anchored at the start only, so trailing annotations after the error
    # term are tolerated just like before.
    match = _VALUE_WITH_ERROR_RE.match(str(value_str).strip())
    if match:
        return float(match.group(1)), float(match.group(2))

    # No error information: fall back to interpreting the value as a number.
    try:
        return float(value_str), 0.0
    except (TypeError, ValueError, OverflowError):
        return np.nan, np.nan

def load_and_parse_csv(file_path):
    """
    Read a metrics CSV and split every "mean ± error" column in two.

    Parameters:
    file_path: Path of the CSV file; it must contain a 'Train_Ratio' column.

    Returns:
    DataFrame holding 'Train_Ratio' unchanged plus, for every other source
    column X, the columns 'X_mean' and 'X_error' (in source column order).

    Raises:
    FileNotFoundError: if file_path does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    raw = pd.read_csv(file_path)

    # Train_Ratio is carried over as-is; every other column is parsed.
    columns = {'Train_Ratio': raw['Train_Ratio']}
    for name in raw.columns:
        if name != 'Train_Ratio':
            # One parsing pass per column, then unzip into the two outputs.
            pairs = [parse_value_with_error(cell) for cell in raw[name]]
            columns[f'{name}_mean'] = [mean for mean, _ in pairs]
            columns[f'{name}_error'] = [error for _, error in pairs]

    return pd.DataFrame(columns)

def visualize_metrics(df, columns_to_plot=None, save_path=None, figsize=(15, 10)):
    """
    Visualize specified metric columns

    One subplot is drawn per metric (at most three per row): the mean as a
    line over 'Train_Ratio' with a shaded mean ± error band.

    Parameters:
    df: DataFrame containing parsed data ('Train_Ratio' plus '<name>_mean' /
        '<name>_error' column pairs)
    columns_to_plot: List of column names to plot, with or without the
        '_mean' suffix; if None plot all columns. Unknown names are reported
        and skipped.
    save_path: Path to save the image, if None don't save. Missing parent
        directories are created.
    figsize: Figure size

    Returns:
    None. Prints "No columns to plot" and returns early when nothing matches.
    """
    mean_suffix = '_mean'
    error_suffix = '_error'

    # Get all mean columns
    mean_columns = [col for col in df.columns if col.endswith(mean_suffix)]

    if columns_to_plot is None:
        columns_to_plot = mean_columns
    else:
        # Accept either the parsed column name or the original metric name.
        # The exact name wins so metric names that themselves end in '_mean'
        # remain selectable.
        resolved = []
        unknown = []
        for name in columns_to_plot:
            if name in mean_columns:
                resolved.append(name)
            elif f'{name}{mean_suffix}' in mean_columns:
                resolved.append(f'{name}{mean_suffix}')
            else:
                unknown.append(name)
        if unknown:
            print(f"Skipping unknown columns: {', '.join(map(str, unknown))}")
        columns_to_plot = resolved

    n_plots = len(columns_to_plot)
    if n_plots == 0:
        print("No columns to plot")
        return

    # Calculate rows and columns to fit subplots
    n_cols = min(3, n_plots)  # Maximum 3 columns
    n_rows = (n_plots + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of axes, so a single flatten()
    # covers the 1x1, single-row and grid cases alike.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    train_ratios = df['Train_Ratio']

    for ax, column in zip(axes, columns_to_plot):
        # Strip only the trailing suffix; str.replace would also rewrite any
        # '_mean' occurring inside the metric name itself.
        metric_name = column[:-len(mean_suffix)]
        error_column = f'{metric_name}{error_suffix}'

        means = df[column]
        errors = df[error_column]

        # Keep only rows where both mean and error are present. notna() also
        # works for object-dtype columns, where np.isnan raises TypeError.
        valid = means.notna() & errors.notna()
        valid_train_ratios = train_ratios[valid]
        valid_means = means[valid]
        valid_errors = errors[valid]

        # Plot metric curve and confidence band
        ax.plot(valid_train_ratios, valid_means, marker='o', linewidth=2,
                markersize=6, label=metric_name)
        ax.fill_between(valid_train_ratios, valid_means - valid_errors,
                        valid_means + valid_errors, alpha=0.3)

        # Set chart properties
        ax.set_xlabel('Train Ratio')
        ax.set_ylabel('Metric Value')
        ax.set_title(metric_name, fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=8)

    # Hide the unused cells of the last grid row
    for ax in axes[n_plots:]:
        ax.set_visible(False)

    fig.tight_layout()

    # Save image (if path specified)
    if save_path:
        if isinstance(save_path, (str, os.PathLike)):
            # savefig does not create directories, so make sure the target
            # directory exists first.
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Chart saved to: {save_path}")

    plt.show()

def list_available_columns(df):
    """
    List all available column names

    Prints a numbered list of the metric names found in df (every column
    ending in '_mean', with that suffix removed) and returns them.

    Parameters:
    df: DataFrame containing parsed data

    Returns:
    List of original metric names, in DataFrame column order.
    """
    suffix = '_mean'
    # Slice off only the trailing suffix: str.replace would also delete any
    # '_mean' inside the metric name and yield a name that no longer maps
    # back to a real column.
    original_columns = [col[:-len(suffix)] for col in df.columns if col.endswith(suffix)]
    print("Available columns:")
    for i, col in enumerate(original_columns, 1):
        print(f"{i:2d}. {col}")
    return original_columns

# Predicates deciding which columns belong to each predefined group.
_GROUP_FILTERS = {
    'mix': lambda col: 'Mixed Approach' in col,
    'single': lambda col: 'Single Approach' in col,
    'mse': lambda col: 'MSE' in col,
    'auc': lambda col: 'AUC' in col and 'MSE' not in col,
    'accuracy': lambda col: 'Accuracy' in col and 'MSE' not in col,
}


def _columns_from_range(text, available_columns):
    """Return the columns named by a 1-based inclusive range such as "2-5".

    An end beyond the last column is clamped. Raises ValueError when text is
    not "start-end" with 1 <= start <= end.
    """
    start_text, _, end_text = text.partition('-')
    start = int(start_text.strip())
    end = int(end_text.strip())
    if start < 1 or end < start:
        raise ValueError(f"invalid range: {text!r}")
    return available_columns[start - 1:end]


def _columns_from_indices(text, available_columns):
    """Return the columns named by comma-separated 1-based indices.

    Out-of-range indices are ignored. Raises ValueError on non-integer tokens.
    """
    indices = [int(token.strip()) - 1 for token in text.split(',')]
    return [available_columns[i] for i in indices if 0 <= i < len(available_columns)]


def select_columns_interactive(available_columns):
    """
    Interactive column selection

    Prompts on stdin until a valid selection is made. Accepted inputs:
    - 'all'                      -> every column
    - '1,3,5' or '2'             -> 1-based indices
    - '1-5', or 'range' followed by '1-5' at the next prompt
                                 -> 1-based inclusive range
    - 'group' followed by a group name, or a group name directly
      (mix, single, mse, auc, accuracy)

    Parameters:
    available_columns: List of selectable column names

    Returns:
    None when all columns were selected, otherwise a non-empty list of the
    selected column names.
    """
    print("\nPlease select columns to plot:")
    print("1. Enter numeric indices separated by commas (e.g.: 1,3,5)")
    print("2. Enter 'all' to plot all columns")
    print("3. Enter 'range' and specify range (e.g.: 1-5)")
    print("4. Enter 'group' to select predefined groups (e.g.: mix, single, mse, auc, accuracy)")

    while True:
        user_input = input("\nPlease enter your choice: ").strip().lower()

        if user_input == 'all':
            return None  # Return None to indicate all columns selected

        if user_input == 'group' or user_input in _GROUP_FILTERS:
            group_input = user_input
            if group_input == 'group':
                print("\nPredefined groups:")
                print("1. mix - Mixed Approach related metrics")
                print("2. single - Single Approach related metrics")
                print("3. mse - All MSE metrics")
                print("4. auc - All AUC metrics")
                print("5. accuracy - All Accuracy metrics")
                group_input = input("Please select group (e.g.: mix): ").strip().lower()

            group_filter = _GROUP_FILTERS.get(group_input)
            if group_filter is None:
                print("Invalid group selection")
                continue

            selected_columns = [col for col in available_columns if group_filter(col)]
            if selected_columns:
                print(f"Selected group '{group_input}': {', '.join(selected_columns)}")
                return selected_columns
            print("No matching columns found")
            continue

        if user_input == 'range' or '-' in user_input:
            # Any input containing '-' is treated as a range. The previous
            # check also required the remaining characters to be non-digits,
            # which rejected the documented form "1-5".
            range_input = user_input
            if range_input == 'range':
                range_input = input("Please enter range (e.g.: 1-5): ").strip()
            try:
                selected_columns = _columns_from_range(range_input, available_columns)
            except ValueError:
                print("Invalid range format, please try again")
                continue
            if selected_columns:
                print(f"Selected range: {', '.join(selected_columns)}")
                return selected_columns
            print("No valid columns selected, please try again")
            continue

        # Handle numeric index input
        try:
            selected_columns = _columns_from_indices(user_input, available_columns)
        except ValueError:
            print("Invalid input format, please try again")
            continue
        if selected_columns:
            print(f"Selected: {', '.join(selected_columns)}")
            return selected_columns
        print("No valid columns selected, please try again")

def main():
    """
    Interactive entry point: load the metrics CSV, let the user choose which
    metrics to plot, and show the resulting figure.

    Errors are reported on stdout instead of being propagated.
    """
    csv_file_path = "data/ceval_metrics_mix_vs_single_rep3_improved.csv"

    try:
        print("Loading and parsing data...")
        df = load_and_parse_csv(csv_file_path)
        print("Data loading completed!")

        # Show what can be selected, then ask the user.
        choices = list_available_columns(df)
        picked = select_columns_interactive(choices)

        if picked is not None:
            print(f"\nGenerating visualization for selected columns...")
            # Three plots per row, five inches of height per row.
            grid_rows = (len(picked) + 2) // 3
            visualize_metrics(df, picked, figsize=(15, 5 * grid_rows))
            return

        # None means "all columns".
        print("\nGenerating visualization for all columns...")
        visualize_metrics(df, figsize=(20, 15))

    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please ensure the CSV file path is correct")
    except Exception as e:
        # Top-level boundary: report anything unexpected rather than crash.
        print(f"An error occurred: {e}")

# Scripted (non-interactive) demonstration
def example_usage():
    """
    Run three fixed plotting examples without prompting the user:
    all columns, the approach AUC/Accuracy metrics, and the MSE metrics.
    The last two figures are also written to PNG files.
    """
    source_csv = "data/ceval_metrics_mix_vs_single_rep3_improved.csv"

    print("Loading and parsing data...")
    parsed = load_and_parse_csv(source_csv)
    print("Data loading completed!")

    # Example 1 plots everything on a larger canvas and is not saved.
    print("\nExample 1: Plot all columns...")
    visualize_metrics(parsed, figsize=(20, 15))

    # Examples 2 and 3: (heading, metrics to plot, output image).
    saved_examples = (
        (
            "\nExample 2: Plot mixed approach related metrics...",
            [
                'Mixed Approach (AUC)',
                'Mixed Approach (Accuracy)',
                'Single Approach (AUC)',
                'Single Approach (Accuracy)',
            ],
            "results/ceval_metrics_mix_vs_single.png",
        ),
        (
            "\nExample 3: Plot MSE related metrics...",
            [
                'Global Mean (MSE)',
                'Model Mean (MSE)',
                'Question Mean (MSE)',
            ],
            "results/ceval_metrics_mse.png",
        ),
    )
    for heading, metric_names, output_png in saved_examples:
        print(heading)
        visualize_metrics(parsed, metric_names, save_path=output_png)

if __name__ == "__main__":
    # Script entry point: run the interactive version by default.
    main()
    
    # To run the scripted examples instead of the interactive selection,
    # comment out main() above and uncomment example_usage() below.
    # example_usage()