import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re
import argparse

# Directory containing this script; input CSVs and the output plot directory
# are resolved relative to it.
path0 = os.path.dirname(os.path.abspath(__file__))

# Global font configuration: serif text, with candidate fonts listed in order
# of preference, favouring fonts likely to be available on most systems.
_FONT_RC = {
    "font.family": "serif",
    "font.serif": [
        "DejaVu Serif",
        "Liberation Serif",
        "Bitstream Vera Serif",
        "Nimbus Roman",
        "Times New Roman",
    ],
    "mathtext.fontset": "dejavuserif",
    # Keep LaTeX rendering off so no external TeX installation is required.
    "text.usetex": False,
}
plt.rcParams.update(_FONT_RC)

# Shared styling for every error bar and missing-data marker in the figure.
markersize, capsize, linewidth = 10, 6, 2

# Colour per model name; plotting code falls back to black for unknown models.
_MODEL_COLOR_PAIRS = (
    ('HVBLL', 'red'),
    ('VBLL', 'darkred'),
    ('MC-Dropout', 'forestgreen'),
    ('MDN', 'orange'),
    ('DVI', 'purple'),
    ('SWAG', 'brown'),
    ('PNN', 'pink'),
    ('BLL', 'gray'),
    ('Deep-GP', 'cyan'),
)
MODEL_COLORS = dict(_MODEL_COLOR_PAIRS)

def extract_sample_index(dataset_case):
    """Return the sample index encoded in a ``dataset_case`` identifier.

    Identifiers contain ``DS<digits>-C<digits>-S<digits>`` (e.g.
    ``"DS3-C1-S2"`` -> ``2``); the pattern is searched anywhere in the string.

    Args:
        dataset_case: Identifier, normally a value of the ``dataset_case``
            CSV column.

    Returns:
        The trailing ``S`` number as an ``int``, or ``None`` when the value
        is not a string (e.g. NaN from an empty CSV cell) or has no match.
    """
    # pandas reads empty cells as float NaN; re.search would raise TypeError.
    if not isinstance(dataset_case, str):
        return None
    match = re.search(r'DS\d+-C\d+-S(\d+)', dataset_case)
    return int(match.group(1)) if match else None

def load_model_data(model_name, path0):
    """Read ``<path0>/result/results_<model_name>_comparison.csv``.

    A ``sample_index`` column is added by applying ``extract_sample_index``
    to the ``dataset_case`` column.

    Returns:
        The loaded DataFrame, or ``None`` (after printing a warning) when the
        file does not exist.
    """
    csv_path = os.path.join(path0, 'result', f'results_{model_name}_comparison.csv')

    if os.path.exists(csv_path):
        frame = pd.read_csv(csv_path)
        frame['sample_index'] = frame['dataset_case'].apply(extract_sample_index)
        return frame

    print(f"Warning: File {csv_path} not found. Skipping {model_name}.")
    return None

def filter_outliers(data, threshold, metric_name):
    """Drop rows whose train or test mean of *metric_name* exceeds *threshold*.

    Args:
        data: Per-case results for one model, or ``None``. If it carries a
            string ``name`` attribute (``main`` assigns one), that name is used
            in the printed messages; otherwise ``'Model'`` is shown.
        threshold: Inclusive upper bound applied to both
            ``train_<metric>_mean`` and ``test_<metric>_mean``.
        metric_name: Metric suffix used to build column names, e.g. ``'nll'``.

    Returns:
        ``None`` if *data* is ``None``; *data* itself if either mean column is
        missing; otherwise only the rows where both means are ``<= threshold``.
        Rows with a NaN mean are dropped as well, because NaN comparisons are
        False.
    """
    if data is None:
        return None

    # Resolve the display label once. On a DataFrame with a 'name' column,
    # attribute access returns that column (a Series), which cannot be
    # formatted with ':10s' - so only accept a genuine string.
    label = getattr(data, 'name', None)
    if not isinstance(label, str):
        label = 'Model'

    train_col = f'train_{metric_name}_mean'
    test_col = f'test_{metric_name}_mean'

    if train_col not in data.columns or test_col not in data.columns:
        print(f"Warning: Columns {train_col} or {test_col} not found. Skipping filtering for {label}.")
        return data

    initial_count = len(data)
    keep = (data[train_col] <= threshold) & (data[test_col] <= threshold)
    filtered_data = data[keep]
    removed = initial_count - len(filtered_data)
    print(f"> {label:10s}: {removed:2d} data points filtered out of {initial_count:2d}")

    return filtered_data

def plot_models_for_case(ax, case, model_data, model_offsets, model_colors,
                         dataset_y, metric_name, is_first_case=False, is_train=None):
    """Draw every model's mean ± std error bar for one dataset case on *ax*.

    Models that have no row for *case* (absent, or removed by filtering) get a
    faint 'x' marker at x=0 instead.

    Args:
        ax: Target matplotlib Axes.
        case: ``dataset_case`` value used to select each model's row.
        model_data: Mapping model name -> DataFrame of per-case results.
        model_offsets: Mapping model name -> vertical offset added to
            *dataset_y* (unknown models use 0.0).
        model_colors: Mapping model name -> colour (unknown models use black).
        dataset_y: Base y coordinate of this case's row.
        metric_name: Metric suffix of the columns, e.g. ``'nll'``.
        is_first_case: When True, attach the model name as legend label.
        is_train: True to plot ``train_*`` columns, False for ``test_*``.
            When ``None`` (default) this is inferred from whether the axis
            title contains "train" - the title must then already be set.
    """
    # The split does not change between models, so decide it once.
    if is_train is None:
        is_train = 'train' in ax.get_title().lower()
    split = 'train' if is_train else 'test'
    mean_col = f'{split}_{metric_name}_mean'
    std_col = f'{split}_{metric_name}_std'

    for model_name, data in model_data.items():
        color = model_colors.get(model_name, 'black')
        offset = model_offsets.get(model_name, 0.0)
        label = model_name if is_first_case else ""
        row = data[data['dataset_case'] == case]

        if row.empty:
            # No data for this case: faint 'x' at x=0 marks the gap.
            ax.plot(0, dataset_y + offset, 'x', color=color, alpha=0.3,
                    markersize=markersize * 0.6, markeredgewidth=linewidth,
                    label=label)
            continue

        if mean_col not in data.columns or std_col not in data.columns:
            print(f"Warning: Columns {mean_col} or {std_col} not found for {model_name}.")
            continue

        ax.errorbar(row[mean_col].values[0], dataset_y + offset,
                    xerr=row[std_col].values[0],
                    fmt='o', ecolor=color, color=color,
                    capsize=capsize, alpha=0.7, markersize=markersize,
                    linewidth=linewidth, label=label)

def _strip_sample_suffix(case):
    """Return *case* without its trailing ``-S<n>`` sample-index suffix."""
    return re.sub(r'-S\d+$', '', case)


def _build_case_labels(cases, selected_numbers=(1, 5, 10, 15)):
    """Map suffix-free case names to sparse y-axis labels.

    Unique suffix-free names are sorted; the name at 1-based position ``n``
    gets the label ``str(n)`` for each ``n`` in *selected_numbers* (positions
    past the end are skipped), and every other name gets ``""``.
    """
    unique_cases = sorted({_strip_sample_suffix(case) for case in cases})
    labels = dict.fromkeys(unique_cases, "")
    for num in selected_numbers:
        if 1 <= num <= len(unique_cases):
            labels[unique_cases[num - 1]] = str(num)
    return labels


def _plot_dataset_row(ax_train, ax_test, cases, case_labels, model_subset,
                      model_offsets, metric_name, y_spacing):
    """Plot *cases* top-to-bottom on a train/test axes pair; return the y upper limit."""
    top = len(cases) * y_spacing
    for i, case in enumerate(cases):
        # Row i is centred at dataset_y; consecutive rows are y_spacing apart.
        dataset_y = top - i * y_spacing - 0.5
        label = case_labels.get(_strip_sample_suffix(case), "")
        for ax in (ax_train, ax_test):
            plot_models_for_case(ax, case, model_subset, model_offsets, MODEL_COLORS,
                                 dataset_y, metric_name, is_first_case=(i == 0))
            # x in axes coordinates (just left of the axis), y in data coordinates.
            ax.text(-0.02, dataset_y, label, ha='right', va='center', fontsize=20,
                    transform=ax.get_yaxis_transform())
            if i < len(cases) - 1:
                # Separator exactly halfway between this row and the next.
                ax.axhline(y=dataset_y - y_spacing / 2, color='gray',
                           linestyle='--', alpha=0.3)
    return top + 1


def main(model_names=None, threshold=100.0, metric_name='nll'):
    """Build the 2x2 train/test comparison figure for the given models.

    For each model, ``result/results_<model>_comparison.csv`` is loaded (missing
    files are skipped) and rows whose train or test ``<metric>_mean`` exceeds
    *threshold* are dropped. Mean ± std is then plotted per dataset case:
    sample index 1 (small datasets) on the top row, sample index 3 (large
    datasets) on the bottom row, train on the left, test on the right. The
    figure is written as ``plots/<METRIC>_<model1>_<model2>...`` in PNG and
    PDF under the script directory.

    Args:
        model_names: Models to compare; defaults to HVBLL, VBLL and MC-Dropout.
        threshold: Outlier threshold forwarded to ``filter_outliers``.
        metric_name: Metric suffix of the CSV columns, e.g. ``'nll'``.
    """
    if model_names is None:
        model_names = ['HVBLL', 'VBLL', 'MC-Dropout']

    plot_dir = os.path.join(path0, 'plots')
    os.makedirs(plot_dir, exist_ok=True)

    print(f"Loading data for models: {model_names}")

    model_data = {}
    for model_name in model_names:
        data = load_model_data(model_name, path0)
        if data is not None:
            data.name = model_name  # shown by filter_outliers in its messages
            model_data[model_name] = data

    if not model_data:
        print("Error: No valid model data found!")
        return

    # Collect the UNION of cases across models before filtering, so a case
    # whose rows get filtered out still keeps its (gap-marked) row.
    all_small_cases = set()
    all_large_cases = set()
    for data in model_data.values():
        all_small_cases.update(data[data['sample_index'] == 1]['dataset_case'].unique())
        all_large_cases.update(data[data['sample_index'] == 3]['dataset_case'].unique())
    small_dataset_cases = sorted(all_small_cases)
    large_dataset_cases = sorted(all_large_cases)

    print(f"Filtering out data points with {metric_name.upper()} > {threshold}")
    for model_name in model_data:
        model_data[model_name] = filter_outliers(model_data[model_name], threshold, metric_name)

    # Small datasets have sample index 1, large datasets sample index 3.
    model_small = {name: data[data['sample_index'] == 1] for name, data in model_data.items()}
    model_large = {name: data[data['sample_index'] == 3] for name, data in model_data.items()}

    case_labels = _build_case_labels(small_dataset_cases + large_dataset_cases)

    print(f'>>> Small datasets number: {len(small_dataset_cases)}')
    print(f'>>> Large datasets number: {len(large_dataset_cases)}')

    fig, axs = plt.subplots(2, 2, figsize=(30, 20))

    # Spread the models vertically over [-0.2, 0.2] around each row centre.
    n_models = len(model_names)
    if n_models > 1:
        offset_step = 0.4 / (n_models - 1)
        model_offsets = {name: 0.2 - i * offset_step for i, name in enumerate(model_names)}
    else:
        model_offsets = {name: 0.0 for name in model_names}

    # Titles MUST be set before plotting: plot_models_for_case picks the
    # train_* or test_* columns by looking for "train" in the axis title.
    metric_display = metric_name.upper()
    axs[0, 0].set_title(f'Train {metric_display} (Small Datasets)', fontsize=30)
    axs[0, 1].set_title(f'Test {metric_display} (Small Datasets)', fontsize=30)
    axs[1, 0].set_title(f'Train {metric_display} (Large Datasets)', fontsize=30)
    axs[1, 1].set_title(f'Test {metric_display} (Large Datasets)', fontsize=30)

    y_spacing = 1.5  # vertical distance between consecutive dataset cases
    max_y_small = _plot_dataset_row(axs[0, 0], axs[0, 1], small_dataset_cases, case_labels,
                                    model_small, model_offsets, metric_name, y_spacing)
    max_y_large = _plot_dataset_row(axs[1, 0], axs[1, 1], large_dataset_cases, case_labels,
                                    model_large, model_offsets, metric_name, y_spacing)

    # One figure-level legend, taken from the first panel that has labelled
    # artists (normally the top-left one).
    handles, labels = [], []
    for ax in axs.flat:
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            break
    fig.legend(handles, labels, loc='upper center', fontsize=30,
               ncol=len(model_names), bbox_to_anchor=(0.5, 0.98))

    for ax in axs.flat:
        ax.set_xlabel(f'{metric_display} (Mean ± Std)', fontsize=30)
        ax.grid(axis='x', linestyle='--', alpha=0.7)
        ax.set_yticks([])  # rows are labelled by the custom text labels instead
        # Reference line at x=0 (also where missing-data markers are drawn).
        ax.axvline(x=0, color='black', linestyle='-', alpha=0.2)
        ax.tick_params(axis='x', labelsize=20)

    axs[0, 0].set_ylim(0, max_y_small)
    axs[0, 1].set_ylim(0, max_y_small)
    axs[1, 0].set_ylim(0, max_y_large)
    axs[1, 1].set_ylim(0, max_y_large)

    # Leave room at the top for the figure legend.
    fig.subplots_adjust(wspace=0.2, hspace=0.25, top=0.9)

    methods_str = "_".join(model_names)
    output_file = os.path.join(plot_dir, f'{metric_name.upper()}_{methods_str}')
    fig.savefig(output_file + '.png', dpi=50, bbox_inches='tight')
    fig.savefig(output_file + '.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"Plot saved to {output_file}")

def parse_arguments(argv=None):
    """Parse command-line options for the comparison plot.

    Args:
        argv: Arguments to parse, without the program name. ``None``
            (default) parses ``sys.argv[1:]``, as argparse does.

    Returns:
        ``argparse.Namespace`` with ``models`` (list of str), ``threshold``
        (float) and ``metric`` (str).

    NOTE: the CLI default model list (HVBLL, MDN, MC-Dropout) differs from
    ``main``'s own default (HVBLL, VBLL, MC-Dropout).
    """
    parser = argparse.ArgumentParser(description='Generate UCI comparison plots')
    parser.add_argument('--models', nargs='+', default=['HVBLL', 'MDN', 'MC-Dropout'],
                        help='List of model names to plot')
    parser.add_argument('--threshold', type=float, default=100.0,
                        help='Threshold for filtering outliers (default: 100.0)')
    parser.add_argument('--metric', type=str, default='nll',
                        help='Metric to plot (e.g., nll, mse, mae, crps, coverage_95, etc.) (default: nll)')
    return parser.parse_args(argv)


if __name__ == "__main__":
    
    args = parse_arguments()
    main(model_names=args.models, threshold=args.threshold, metric_name=args.metric)

