import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
import matplotlib
import os

# --- 1. Module-level Configuration ---
# These settings are read by create_figure (defined below). Edit them here to
# change the subplot grid, the legend entries, or the x-axis sampling.

# Prefer a Times-like serif font, as commonly used in academic papers.
# Prepend it to the existing serif fallback list instead of replacing the
# list, so a machine without Times New Roman can still fall back to another
# serif font rather than dropping out of the serif family entirely.
matplotlib.rcParams['font.serif'] = ["Times New Roman"] + [
    name for name in matplotlib.rcParams['font.serif'] if name != "Times New Roman"
]
matplotlib.rcParams['font.family'] = "serif"
# Larger default font sizes for readability. Note: create_figure passes
# explicit sizes for titles, axis labels, tick labels and the legend, so these
# defaults only affect text it does not size explicitly (or other plots made
# in the same process).
matplotlib.rcParams['axes.titlesize'] = 18
matplotlib.rcParams['axes.labelsize'] = 16
matplotlib.rcParams['xtick.labelsize'] = 14
matplotlib.rcParams['ytick.labelsize'] = 14
matplotlib.rcParams['legend.fontsize'] = 16

# Subplot titles for the 3x3 grid, in row-major order. These are also the
# keys create_figure looks up in its `data` argument.
SUBPLOT_TITLES = [
    'ResNet-IN-1k', 'ResNet-IN-1k_v2', 'ResNet-CT-256',
    'RegNetY-IN-1k', 'RegNetY-IN-1k_v2', 'RegNetY-CT-256',
    'EfficientNetV2-IN-1k', 'EfficientNetV2-IN-1k_v2', 'EfficientNetV2-CT-256'
]

# Legend entries as (name, marker) pairs. Colours are assigned from the
# 'plasma' colormap in this order, one distinct colour per entry.
_LEGEND_STYLE = [
    ('Min', 's'),
    ('Max', '^'),
    ('AD', 'o'),
    ('Avg', 'd'),
    ('Zero', 'p'),
]
# plt.get_cmap(name, lut) is used because matplotlib.cm.get_cmap (which is
# what plt.cm.get_cmap refers to) was removed in matplotlib 3.9. The lookup
# table size follows the number of legend entries, so adding an entry above
# automatically gets its own colour.
cmap = plt.get_cmap('plasma', len(_LEGEND_STYLE))
LEGEND_MODELS = [
    {'name': name, 'color': cmap(idx), 'marker': marker}
    for idx, (name, marker) in enumerate(_LEGEND_STYLE)
]

# x-axis values ("Confidence Threshold"): 6 evenly spaced points in [0, 1].
# Every y-series passed to create_figure must have one value per point.
CONFIDENCE_THRESHOLDS = np.linspace(0.0, 1.0, 6)


# --- 2. Importable Plotting Function ---

def _padded_ylim(values, padding_fraction):
    """Return padded ``(low, high)`` y-limits for ``values``, or ``None``.

    Non-finite entries (NaN/inf) are ignored, because ``Axes.set_ylim``
    rejects non-finite limits. Returns ``None`` when no finite value remains.
    When all finite values are identical, a pad of at least 1.0 is used so
    the resulting range is never empty.
    """
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return None
    y_min = float(finite.min())
    y_max = float(finite.max())
    if y_max == y_min:
        pad = max(1.0, abs(y_max) * padding_fraction)
    else:
        pad = (y_max - y_min) * padding_fraction
    return y_min - pad, y_max + pad


def _legend_handles(marker_size):
    """Build one marker-only proxy artist per entry in LEGEND_MODELS."""
    return [
        Line2D([0], [0],
               marker=model_info['marker'],
               color=model_info['color'],
               label=model_info['name'],
               linestyle='none',
               markersize=marker_size)
        for model_info in LEGEND_MODELS
    ]


def _save_figure(fig, save_path_base, dpi):
    """Save ``fig`` as '<base>.png' and '<base>.pdf', printing any failure."""
    try:
        # Create the output directory if needed so a fresh path does not fail.
        out_dir = os.path.dirname(os.fspath(save_path_base))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f'{save_path_base}.png', dpi=dpi, bbox_inches='tight', transparent=True)
        fig.savefig(f'{save_path_base}.pdf', bbox_inches='tight', transparent=True)
        print(f"Figure saved to {save_path_base}.png and {save_path_base}.pdf")
    except Exception as e:
        # Best-effort contract kept from the original: report, do not raise.
        print(f"Error saving figure: {e}")


def create_figure(
    data,
    save_path_base,
    fig_size=(22, 15),
    sharey=False,
    x_lim=(0.0, 1.0),
    y_lim=None,
    y_padding_fraction=0.05,
    title_fontsize=32,
    label_fontsize=24,
    tick_fontsize=24,
    legend_fontsize=24,
    line_width=4,
    marker_size=16,
    dpi=900,
):
    """
    Generate a 3x3 grid of line plots and save it as PNG and PDF.

    Each subplot corresponds to one entry of SUBPLOT_TITLES (row-major order)
    and draws one line per entry of LEGEND_MODELS against
    CONFIDENCE_THRESHOLDS. This function is designed to be imported and used
    in other scripts.

    Args:
        data (dict): Maps subplot titles (as in SUBPLOT_TITLES) to dicts that
            map legend model names (the 'name' fields of LEGEND_MODELS) to
            y-values. Each y-series must contain exactly one value per entry
            of CONFIDENCE_THRESHOLDS; NaN entries are allowed and are ignored
            when auto-zooming. Missing titles/models are skipped with a
            printed warning.
        save_path_base (str or os.PathLike): Output path without extension;
            '.png' and '.pdf' are appended. Missing parent directories are
            created.
        fig_size (tuple): Figure size in inches, passed to plt.subplots.
        sharey (bool): Share the y-axis across all subplots. When True and
            y_lim is None, no per-subplot y-limits are computed.
        x_lim (tuple or None): x-axis limits for every subplot; None leaves
            matplotlib's autoscaling in place.
        y_lim (tuple or None): y-axis limits for every subplot. When None and
            sharey is False, each subplot is zoomed to the range of its own
            finite data, padded by y_padding_fraction.
        y_padding_fraction (float): Fraction of the data range added above
            and below when auto-zooming.
        title_fontsize, label_fontsize, tick_fontsize, legend_fontsize:
            Font sizes for subplot titles, axis labels, tick labels and the
            figure legend.
        line_width (float): Line width of each plotted series.
        marker_size (float): Marker size on the plotted lines; legend markers
            are drawn 2 larger.
        dpi (int): Resolution of the PNG output (not passed for the PDF).

    Raises:
        ValueError: If a y-series length does not match CONFIDENCE_THRESHOLDS.

    Errors while saving are printed rather than raised. The figure is always
    closed, even if plotting fails.
    """
    x_values = np.asarray(CONFIDENCE_THRESHOLDS, dtype=float)
    fig, axes = plt.subplots(3, 3, figsize=fig_size, sharey=sharey)
    try:
        for i, (ax, title) in enumerate(zip(axes.flatten(), SUBPLOT_TITLES)):
            ax.set_title(title, fontsize=title_fontsize, weight='bold')
            ax.grid(True, linestyle='-', alpha=0.3, color='gray')

            # y-series actually drawn in this subplot, used for auto-zoom.
            plotted_series = []

            for model_info in LEGEND_MODELS:
                model_name = model_info['name']
                if title not in data or model_name not in data[title]:
                    print(f"Warning: No data found for '{model_name}' in subplot '{title}'.")
                    continue

                y_values = np.asarray(data[title][model_name], dtype=float).ravel()
                if y_values.size != x_values.size:
                    raise ValueError(
                        f"Subplot '{title}', model '{model_name}': expected "
                        f"{x_values.size} y-values (one per confidence "
                        f"threshold), got {y_values.size}."
                    )
                plotted_series.append(y_values)

                ax.plot(x_values, y_values,
                        marker=model_info['marker'],
                        color=model_info['color'],
                        linestyle='-',
                        linewidth=line_width,
                        markersize=marker_size,
                        label=model_name)

            # y-axis label only on the left-most column of the grid.
            if i % 3 == 0:
                ax.set_ylabel('Background Robustness', fontsize=label_fontsize)
            ax.set_xlabel('Confidence Threshold', fontsize=label_fontsize)
            ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

            # Axis limits (zooming).
            if x_lim is not None:
                ax.set_xlim(x_lim)
            if y_lim is not None:
                ax.set_ylim(y_lim)
            elif not sharey and plotted_series:
                limits = _padded_ylim(np.concatenate(plotted_series), y_padding_fraction)
                if limits is not None:
                    ax.set_ylim(limits)

        # Horizontal, untitled legend centred below the grid.
        fig.legend(handles=_legend_handles(marker_size + 2),
                   loc='lower center',
                   bbox_to_anchor=(0.5, 0.01),
                   ncol=len(LEGEND_MODELS),
                   fontsize=legend_fontsize)

        # Reserve the bottom 5% of the figure for the legend.
        fig.tight_layout(rect=[0, 0.05, 1, 1])  # rect=[left, bottom, right, top]

        _save_figure(fig, save_path_base, dpi)
    finally:
        # Free the figure even if plotting raised.
        plt.close(fig)


# --- 3. Example Usage ---
# This block demonstrates how to use the create_figure function.
# It will only run when the script is executed directly.
if __name__ == '__main__':

    print("Running example usage...")

    # --- MOCK DATA GENERATION ---
    # Fixed seed so the example figure is reproducible between runs.
    np.random.seed(42)
    n_points = len(CONFIDENCE_THRESHOLDS)
    mock_data = {}
    for title in SUBPLOT_TITLES:
        mock_data[title] = {}
        # Generate some plausible, decreasing data for each line
        start_value = 90 + np.random.rand() * 10
        for model in LEGEND_MODELS:
            # Create a descending trend with some noise
            y_values = start_value - np.linspace(0, 60, n_points) - 5 * np.random.rand(n_points)
            y_values = np.clip(y_values, 0, 100)  # Ensure values are within a plausible range
            mock_data[title][model['name']] = y_values
            start_value -= 10  # Stagger the start points for different models

    # --- FUNCTION CALL ---
    # Base path (without extension) for the saved figure.
    output_path_base = 'robustness_vs_confidence_example'

    # Call the function with mock data and the desired save path
    create_figure(data=mock_data, save_path_base=output_path_base)

    print(f"Example plot generated and saved as '{output_path_base}.png/.pdf'.")
    print("You can now import 'create_figure' from 'scientific_plot' in another script.")

