import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle

from stats_utils import compute_scale_comparison_stats, compute_within_scale_comparison_stats

# Most helpers below expect results arrays indexed as
#   data[init_scale_idx, training_ntasks_idx, seed, point]
# where the last axis holds either training epochs or within-trial timepoints.
# The "maximal training" timepoint helpers also accept arrays with the
# training axis already removed: data[init_scale_idx, seed, timepoint].

# X tick labels for 6-point within-trial plots.
timepoint_labels_6 = ['Rule', 'Stim1', 'Dly1', 'Stim2', 'Dly2', 'Resp']
# X tick labels for 10-point plots: each stimulus/delay period is split into two
# sub-points, suffixed _e / _l (presumably early / late).
timepoint_labels_10 = (
    ['Rule']
    + [f'{period}_{part}' for period in ('Stim1', 'Dly1', 'Stim2', 'Dly2') for part in ('e', 'l')]
    + ['Resp']
)

def plot_epochs_results(data, init_scales, training_ntasks, ylabel, y_min, y_max, n_epochs=1000, epoch_unit=1, use_log_scale=False):
    """
    Plot per-epoch curves (mean ± std across seeds), one subplot per initial
    scale and one line per training task count.

    Args:
        data: Array indexed as data[scale_idx, ntask_idx, seed, epoch_point],
              i.e. shape (n_scales, n_training_ntasks, n_seeds, n_points).
        init_scales: Initial scale values (one subplot each, two per row).
        training_ntasks: Training task counts (one line per subplot each).
        ylabel: Y-axis label.
        y_min, y_max: Y-axis limits.
        n_epochs: Number of training epochs.
        epoch_unit: Spacing between recorded epochs. When != 1 the x-axis is
            0, epoch_unit, ... up to n_epochs (inclusive); when == 1 it is
            0 .. n_epochs - 1. Must match the length of data's last axis.
        use_log_scale: Whether to use a log-scaled y-axis.

    Returns:
        matplotlib figure object
    """
    n_scales = len(init_scales)
    # Two subplots per row. Keep at least the original 3x2 grid (and its 20x15
    # figure) so layouts with <= 6 scales are unchanged, but grow for more
    # scales instead of raising IndexError on gs[3, ...].
    n_rows = max(3, (n_scales + 1) // 2)
    fig = plt.figure(figsize=(20, 5 * n_rows))
    gs = gridspec.GridSpec(n_rows, 2, figure=fig)

    # One color per training task count
    colors = sns.color_palette("husl", len(training_ntasks))

    # X-axis values (epochs) are identical for every line, so build them once.
    if epoch_unit != 1:
        epochs = np.arange(0, n_epochs + 1, epoch_unit)
    else:
        epochs = np.arange(n_epochs)

    for i, scale in enumerate(init_scales):
        ax = fig.add_subplot(gs[i // 2, i % 2])

        for j, ntask in enumerate(training_ntasks):
            # All seeds and epochs for this (scale, ntask) pair
            scale_ntask_data = data[i, j, :, :]

            # Mean and std across seeds
            mean_data = np.mean(scale_ntask_data, axis=0)
            std_data = np.std(scale_ntask_data, axis=0)

            ax.plot(epochs, mean_data, linewidth=2.5, color=colors[j],
                    label=f'ntasks={ntask}')
            ax.fill_between(epochs, mean_data - std_data, mean_data + std_data,
                            alpha=0.3, color=colors[j])

        if use_log_scale:
            ax.set_yscale('log')

        ax.set_title(f'Initial Scale = {np.round(scale,2)}', fontsize=32)
        ax.set_xlabel('Epochs', fontsize=26)
        ax.set_ylabel(ylabel, fontsize=26)
        ax.set_ylim(y_min, y_max)

        # Thicker, larger ticks and axes
        ax.tick_params(axis='both', which='major', labelsize=24, width=2, length=6)
        for spine in ax.spines.values():
            spine.set_linewidth(2)

        # Show the legend once, beside the last subplot, so it appears for any
        # number of scales (a fixed index would drop it when there are fewer panels).
        if i == n_scales - 1:
            ax.legend(fontsize=24, frameon=True, bbox_to_anchor=(1.05, 1), loc='upper left')

        # Remove top and right spines
        sns.despine(ax=ax)

    plt.tight_layout()

    return fig

def plot_maximal_training_epochs(data, init_scales, training_ntasks, ylabel, y_min, y_max, 
                                n_epochs=1000, epoch_unit=1, use_log_scale=False,
                                specific_scale=None, figsize=(12, 8), title=None):
    """
    Plot per-epoch curves (mean ± std across seeds) for the maximal training
    condition only, one line per initial scale.

    The maximal condition is assumed to be the last entry of training_ntasks.

    Args:
        data: Array indexed as data[scale_idx, training_idx, seed, epoch_point],
              i.e. shape (n_scales, n_training_modes, n_seeds, n_points).
        init_scales: Initial scale values.
        training_ntasks: Training task counts; only its length is used, to
            locate the maximal (last) condition.
        ylabel: Y-axis label.
        y_min, y_max: Y-axis limits.
        n_epochs: Number of training epochs.
        epoch_unit: Spacing between recorded epochs. When != 1 the x-axis is
            0, epoch_unit, ... up to n_epochs (inclusive); when == 1 it is
            0 .. n_epochs - 1.
        use_log_scale: Whether to use a log-scaled y-axis.
        specific_scale: If given, plot only the init_scales entry closest to it.
        figsize: Figure size tuple (width, height).
        title: Optional title; a default one is generated when falsy.

    Returns:
        matplotlib figure object
    """
    last_condition = len(training_ntasks) - 1

    fig, ax = plt.subplots(figsize=figsize)

    # Pair each selected scale index with its value; pick matching line colors.
    if specific_scale is None:
        selection = list(enumerate(init_scales))
        palette = sns.color_palette("viridis", len(init_scales))
    else:
        gaps = [abs(value - specific_scale) for value in init_scales]
        closest = gaps.index(min(gaps))
        selection = [(closest, init_scales[closest])]
        palette = ['#1f77b4']

    # Every line shares the same epoch grid.
    x_values = np.arange(0, n_epochs + 1, epoch_unit) if epoch_unit != 1 else np.arange(n_epochs)

    for line_color, (scale_pos, scale_value) in zip(palette, selection):
        seed_runs = data[scale_pos, last_condition, :, :]
        centre = np.mean(seed_runs, axis=0)
        spread = np.std(seed_runs, axis=0)

        ax.plot(x_values, centre, linewidth=4, color=line_color,
                label=f'Scale = {np.round(scale_value, 2)}')
        ax.fill_between(x_values, centre - spread, centre + spread,
                        alpha=0.3, color=line_color)

    if use_log_scale:
        ax.set_yscale('log')

    if title:
        ax.set_title(title, fontsize=40, pad=20)
    elif specific_scale is None:
        ax.set_title('Maximal Training - All Scales (Epochs)', fontsize=24, pad=20)
    else:
        ax.set_title(f'Maximal Training - Scale {np.round(specific_scale, 2)} (Epochs)', fontsize=24, pad=20)

    ax.set_xlabel('Epochs', fontsize=32)
    ax.set_ylabel(ylabel, fontsize=32)
    ax.set_ylim(y_min, y_max)

    # Heavy ticks and frame for presentation-size figures
    ax.tick_params(axis='both', which='major', labelsize=28, width=4, length=6)
    for edge in ax.spines.values():
        edge.set_linewidth(4)

    ax.legend(fontsize=24, frameon=True, loc='best')

    # Drop the top and right spines
    sns.despine(ax=ax)

    plt.tight_layout()

    return fig

def plot_results(data, init_scales, training_ntasks, ylabel, y_min, y_max, n_timepoints=10):
    """
    Plot per-timepoint curves (mean ± std across seeds), one subplot per
    initial scale and one line per training task count.

    Args:
        data: Array indexed as data[scale_idx, ntask_idx, seed, timepoint],
              i.e. shape (n_scales, n_training_ntasks, n_seeds, n_timepoints).
        init_scales: Initial scale values (one subplot each, two per row).
        training_ntasks: Training task counts (one line per subplot each).
        ylabel: Y-axis label.
        y_min, y_max: Y-axis limits.
        n_timepoints: Number of timepoints. 6 and 10 get the module's named
            tick labels; any other value is labelled by timepoint index.

    Returns:
        matplotlib figure object
    """
    if n_timepoints == 6:
        timepoint_labels = timepoint_labels_6
    elif n_timepoints == 10:
        timepoint_labels = timepoint_labels_10
    else:
        # No named labels for this resolution: fall back to indices instead of
        # failing later with an unbound `timepoint_labels`.
        timepoint_labels = [str(t) for t in range(n_timepoints)]

    n_scales = len(init_scales)
    # Two subplots per row. Keep at least the original 3x2 grid (and 20x15
    # figure) so layouts with <= 6 scales are unchanged, but grow for more.
    n_rows = max(3, (n_scales + 1) // 2)
    fig = plt.figure(figsize=(20, 5 * n_rows))
    gs = gridspec.GridSpec(n_rows, 2, figure=fig)

    # One color per training task count
    colors = sns.color_palette("husl", len(training_ntasks))

    # X-axis values (timepoints) are identical for every line
    timepoints = np.arange(n_timepoints)

    for i, scale in enumerate(init_scales):
        ax = fig.add_subplot(gs[i // 2, i % 2])

        for j, ntask in enumerate(training_ntasks):
            # All seeds and timepoints for this (scale, ntask) pair
            scale_ntask_data = data[i, j, :, :]

            # Mean and std across seeds
            mean_data = np.mean(scale_ntask_data, axis=0)
            std_data = np.std(scale_ntask_data, axis=0)

            ax.plot(timepoints, mean_data, linewidth=2.5, color=colors[j],
                    label=f'ntasks={ntask}')
            ax.fill_between(timepoints, mean_data - std_data, mean_data + std_data,
                            alpha=0.3, color=colors[j])

        ax.set_title(f'Initial scale = {np.round(scale,2)}', fontsize=32)
        ax.set_xlabel('Timepoints', fontsize=26)
        ax.set_ylabel(ylabel, fontsize=26)
        ax.set_ylim(y_min, y_max)

        ax.set_xticks(timepoints)
        ax.set_xticklabels(timepoint_labels, fontsize=20, rotation=-45)

        # Thicker, larger ticks and axes
        ax.tick_params(axis='both', which='major', labelsize=24, width=2, length=6)
        for spine in ax.spines.values():
            spine.set_linewidth(2)

        # Show the legend once, beside the last subplot, so it appears for any
        # number of scales (a fixed index would drop it when there are fewer panels).
        if i == n_scales - 1:
            ax.legend(fontsize=24, frameon=True, bbox_to_anchor=(1.05, 1), loc='upper left')

        # Remove top and right spines
        sns.despine(ax=ax)

    plt.tight_layout()

    return fig


def plot_maximal_training_results(data, init_scales, training_ntasks, ylabel, y_min, y_max, 
                                 timepoint_labels=None, n_timepoints=10, specific_scale=None,
                                 figsize=(12, 8), title=None):
    """
    Plot per-timepoint results for the maximal training case only, with one
    line (mean ± std across seeds) per initial scale.

    Args:
        data: Either (n_scales, n_seeds, n_timepoints) for maximal-only data, or
              (n_scales, n_training_modes, n_seeds, n_timepoints) for the full
              dataset, in which case the last training mode is treated as maximal.
        init_scales: List of initial scale values.
        training_ntasks: Training task counts; only its length is used, and only
            for 4-D data (to locate the maximal = last condition).
        ylabel: Y-axis label.
        y_min, y_max: Y-axis limits.
        timepoint_labels: X tick labels. If None, the module's named labels are
            used for n_timepoints of 6 or 10, otherwise the timepoint indices.
        n_timepoints: Number of timepoints.
        specific_scale: If provided, plot only the init_scales entry closest to it.
        figsize: Figure size tuple (width, height).
        title: Optional plot title; a default one is generated when falsy.

    Returns:
        matplotlib figure object

    Raises:
        ValueError: if data is neither 3-D nor 4-D.
    """
    # Only fill in default labels when the caller did not supply any.
    if timepoint_labels is None:
        if n_timepoints == 6:
            timepoint_labels = timepoint_labels_6
        elif n_timepoints == 10:
            timepoint_labels = timepoint_labels_10
        else:
            timepoint_labels = [str(t) for t in range(n_timepoints)]

    data = np.asarray(data)

    # Detect the data layout from its rank
    if data.ndim == 3:
        # (n_scales, n_seeds, n_timepoints) - maximal-only format
        data_contains_only_maximal = True
        maximal_idx = None
    elif data.ndim == 4:
        # (n_scales, n_training_modes, n_seeds, n_timepoints) - full dataset;
        # maximal training is assumed to be the last training mode.
        data_contains_only_maximal = False
        maximal_idx = len(training_ntasks) - 1
    else:
        raise ValueError(f"Expected data shape (n_scales, n_seeds, n_timepoints) or "
                        f"(n_scales, n_training_modes, n_seeds, n_timepoints), "
                        f"but got shape {data.shape}")

    fig, ax = plt.subplots(figsize=figsize)

    # Determine which scales to plot
    if specific_scale is not None:
        # Closest available scale to the requested one
        scale_diffs = [abs(scale - specific_scale) for scale in init_scales]
        scale_idx = scale_diffs.index(min(scale_diffs))
        scales_to_plot = [scale_idx]
        scale_values = [init_scales[scale_idx]]
        colors = ['#1f77b4']  # Single color for single scale
    else:
        scales_to_plot = range(len(init_scales))
        scale_values = init_scales
        colors = sns.color_palette("viridis", len(init_scales))

    # X-axis values (timepoints) are identical for every line
    timepoints = np.arange(n_timepoints)

    for i, scale_idx in enumerate(scales_to_plot):
        scale = scale_values[i]

        if data_contains_only_maximal:
            scale_maximal_data = data[scale_idx, :, :]
        else:
            scale_maximal_data = data[scale_idx, maximal_idx, :, :]

        # Mean and std across seeds
        mean_data = np.mean(scale_maximal_data, axis=0)
        std_data = np.std(scale_maximal_data, axis=0)

        ax.plot(timepoints, mean_data, linewidth=3, color=colors[i],
                label=f'Scale = {np.round(scale, 2)}')
        ax.fill_between(timepoints, mean_data - std_data, mean_data + std_data,
                        alpha=0.3, color=colors[i])

    if title:
        ax.set_title(title, fontsize=24, pad=20)
    elif specific_scale is not None:
        ax.set_title(f'Maximal Training - Scale {np.round(specific_scale, 2)}', fontsize=24, pad=20)
    else:
        ax.set_title('Maximal Training - All Scales', fontsize=24, pad=20)

    ax.set_xlabel('Timepoints', fontsize=20)
    ax.set_ylabel(ylabel, fontsize=20)
    ax.set_ylim(y_min, y_max)

    ax.set_xticks(timepoints)
    ax.set_xticklabels(timepoint_labels, fontsize=16, rotation=-45)

    # Thicker, larger ticks and axes
    ax.tick_params(axis='both', which='major', labelsize=18, width=2, length=6)
    for spine in ax.spines.values():
        spine.set_linewidth(2)

    ax.legend(fontsize=16, frameon=True, bbox_to_anchor=(1.05, 1), loc='upper left')

    # Remove top and right spines
    sns.despine(ax=ax)

    plt.tight_layout()

    return fig


def plot_multi_dimensionality_comparison(*data_arrays, init_scales, training_modes,
                                        ylabel_list, label_list, color_list=None,
                                        scales_to_plot='both',  # 'both', 'rich', 'lazy'
                                        rich_scale=0.01, lazy_scale=3.0,
                                        timepoint_labels=None, n_timepoints=10,
                                        figsize=(12, 8), title=None, 
                                        use_common_scale=False,
                                        test_significance=False,
                                        alpha=0.05, correction_method='fdr',
                                        test_type='cross_scale'):
    """
    Plot 1-3 dimensionality measures over timepoints for the rich and/or lazy
    initial scale, optionally annotated with significance markers.

    Args:
        *data_arrays: 1 to 3 arrays, each of shape (n_scales, n_seeds, n_timepoints).
        init_scales: Initial scale values; the rich/lazy rows are the entries
            closest to rich_scale / lazy_scale.
        training_modes: Not used; kept for call-site compatibility.
        ylabel_list: Y-axis labels, one per data array. With a shared y-axis
            (use_common_scale=True or a single measure) only ylabel_list[0] is
            used; with separate axes, axis k uses ylabel_list[k] (falling back
            to ylabel_list[0] if the list is shorter).
        label_list: Line labels, one per data array.
        color_list: Colors, one per data array (default: blue, red, green).
        scales_to_plot: 'both' (rich+lazy), 'rich' (rich only), 'lazy' (lazy only).
        rich_scale: Scale value for rich initialization.
        lazy_scale: Scale value for lazy initialization.
        timepoint_labels: X tick labels. If None, the module's named labels are
            used for n_timepoints of 6 or 10, otherwise the timepoint indices.
        n_timepoints: Number of timepoints.
        figsize: Figure size tuple (width, height).
        title: Optional plot title; a default one is generated when falsy.
        use_common_scale: If True, use a single y-axis; if False, one twin
            y-axis per measure.
        test_significance: If True, run the test selected by test_type and draw
            markers above the curves (for any scales_to_plot value).
        alpha: Significance level passed to the stats helpers.
        correction_method: Passed to the stats helpers; 'fdr' is also noted
            in the default title.
        test_type: 'cross_scale' (Rich vs Lazy, default) or 'within_scale'
            (comparisons between data arrays within each scale).

    Returns:
        (fig, ax1): the figure and its primary axis. If significance testing
        was performed, the results are also attached as fig.significance_results.

    Raises:
        ValueError: for 0 or more than 3 data arrays, label/color lists of the
            wrong length, an unknown scales_to_plot, or (when testing) an
            unknown test_type.
    """
    # ---- Resolve defaults and validate inputs ----
    # Only fill in default labels when the caller did not supply any.
    if timepoint_labels is None:
        if n_timepoints == 6:
            timepoint_labels = timepoint_labels_6
        elif n_timepoints == 10:
            timepoint_labels = timepoint_labels_10
        else:
            timepoint_labels = [str(t) for t in range(n_timepoints)]

    n_measures = len(data_arrays)
    if n_measures < 1 or n_measures > 3:
        raise ValueError("Function supports 1 to 3 dimensionality measures")

    if len(label_list) != n_measures:
        raise ValueError("label_list must have same length as number of data arrays")

    if color_list is None:
        color_list = ['blue', 'red', 'green', 'orange', 'purple'][:n_measures]

    if len(color_list) != n_measures:
        raise ValueError("color_list must have same length as number of data arrays")

    if scales_to_plot not in ('both', 'rich', 'lazy'):
        raise ValueError(f"scales_to_plot must be 'both', 'rich' or 'lazy', got {scales_to_plot}")

    if test_significance and test_type not in ('cross_scale', 'within_scale'):
        raise ValueError(f"test_type must be 'cross_scale' or 'within_scale', got {test_type}")

    # ---- Small local helpers ----
    def find_scale_index(target_scale):
        """Index of the init_scales entry closest to target_scale."""
        scale_diffs = [abs(scale - target_scale) for scale in init_scales]
        return scale_diffs.index(min(scale_diffs))

    def mean_std(data, scale_idx):
        """Per-timepoint mean and std across seeds for one scale."""
        runs = data[scale_idx, :, :]
        return np.mean(runs, axis=0), np.std(runs, axis=0)

    def padded_limits(curves):
        """Y-limits spanning mean ± std of all (mean, std) pairs, padded by 5%."""
        lower = np.concatenate([mean - std for mean, std in curves])
        upper = np.concatenate([mean + std for mean, std in curves])
        lo, hi = np.min(lower), np.max(upper)
        span = hi - lo
        pad = 0.05 * span if span > 0 else 0.1
        return lo - pad, hi + pad

    def significance_marker(p_value, symbol):
        """symbol repeated 3/2/1 times for p < 0.001 / 0.01 / 0.05, else ''."""
        if p_value < 0.001:
            return symbol * 3
        if p_value < 0.01:
            return symbol * 2
        if p_value < 0.05:
            return symbol
        return ''

    rich_scale_idx = find_scale_index(rich_scale)
    lazy_scale_idx = find_scale_index(lazy_scale)
    show_rich = scales_to_plot in ('both', 'rich')
    show_lazy = scales_to_plot in ('both', 'lazy')
    # Separate twin y-axes only with several measures and no shared scale.
    multi_axis = not use_common_scale and n_measures > 1

    # ---- Figure and axes ----
    fig, ax1 = plt.subplots(figsize=figsize)
    timepoints = np.arange(n_timepoints)

    if multi_axis:
        axes = [ax1]
        for extra in range(1, n_measures):
            twin = ax1.twinx()
            if extra == 2:
                # Push the third y-axis outward so it does not overlap the second.
                twin.spines['right'].set_position(('outward', 60))
            axes.append(twin)
    else:
        axes = [ax1] * n_measures

    # ---- Plot curves; keep (mean, std) per measure for the y-limits ----
    measure_curves = []
    for data_idx, data in enumerate(data_arrays):
        ax = axes[data_idx]
        color = color_list[data_idx]
        label = label_list[data_idx]
        curves = []

        if show_rich:
            mean_rich, std_rich = mean_std(data, rich_scale_idx)
            ax.plot(timepoints, mean_rich, linewidth=4, color=color,
                    linestyle='-', label=f'{label} (Rich, scale={rich_scale})')
            ax.fill_between(timepoints, mean_rich - std_rich, mean_rich + std_rich,
                            alpha=0.3, color=color)
            curves.append((mean_rich, std_rich))

        if show_lazy:
            mean_lazy, std_lazy = mean_std(data, lazy_scale_idx)
            ax.plot(timepoints, mean_lazy, linewidth=4, color=color,
                    linestyle='--', label=f'{label} (Lazy, scale={lazy_scale})')
            # Lighter band when it is drawn on top of the rich band
            ax.fill_between(timepoints, mean_lazy - std_lazy, mean_lazy + std_lazy,
                            alpha=0.2 if scales_to_plot == 'both' else 0.3, color=color)
            curves.append((mean_lazy, std_lazy))

        measure_curves.append(curves)

        if multi_axis:
            # Each measure owns its axis: label and tick colors match its lines.
            ylabel = ylabel_list[data_idx] if data_idx < len(ylabel_list) else ylabel_list[0]
            ax.set_ylabel(ylabel, fontsize=20, color=color)
            ax.tick_params(axis='y', labelcolor=color, labelsize=18, width=2, length=6)
        elif data_idx == 0:
            # Shared axis (or single measure): label it once
            ax.set_ylabel(ylabel_list[0], fontsize=32, color='black')
            ax.tick_params(axis='y', labelcolor='black', labelsize=28, width=4, length=6)

    # ---- Statistical testing ----
    significance_results = None
    if test_significance:
        stats_fn = (compute_scale_comparison_stats if test_type == 'cross_scale'
                    else compute_within_scale_comparison_stats)
        significance_results = stats_fn(
            *data_arrays,
            rich_scale_idx=rich_scale_idx,
            lazy_scale_idx=lazy_scale_idx,
            alpha=alpha,
            correction_method=correction_method
        )

    # ---- Y-axis limits ----
    if multi_axis:
        for ax, curves in zip(axes, measure_curves):
            ax.set_ylim(*padded_limits(curves))
    else:
        ax1.set_ylim(padded_limits([c for curves in measure_curves for c in curves]))

    # ---- Significance markers above the curves ----
    if significance_results is not None:
        current_ylim = ax1.get_ylim()
        y_range = current_ylim[1] - current_ylim[0]

        # One row of markers per measure (cross_scale) or per scale (within_scale)
        marker_spacing = 0.04 * y_range
        marker_start = current_ylim[1] + 0.08 * y_range

        if test_type == 'cross_scale':
            n_marker_rows = n_measures
            for data_idx in range(n_measures):
                marker_y = marker_start + data_idx * marker_spacing
                color = color_list[data_idx]
                data_key = f'data_{data_idx}'
                if data_key not in significance_results:
                    continue
                rejected = significance_results[data_key]['rejected']
                corrected_p = significance_results[data_key]['corrected_p']
                for t in range(n_timepoints):
                    if rejected[t]:
                        marker = significance_marker(corrected_p[t], '*')
                        if marker:
                            ax1.text(t, marker_y, marker, ha='center', va='center',
                                     color=color, fontweight='bold', fontsize=20)
        else:  # 'within_scale'
            marker_types = {'both': ['rich', 'lazy'],
                            'rich': ['rich'],
                            'lazy': ['lazy']}[scales_to_plot]
            symbols = {'rich': '▲', 'lazy': '●'}
            n_marker_rows = len(marker_types)
            for row, scale_type in enumerate(marker_types):
                marker_y = marker_start + row * marker_spacing
                if scale_type not in significance_results:
                    continue
                # NOTE(review): every pairwise comparison for a scale is drawn on
                # the same row, so with 3 measures markers may overlap.
                comparisons = significance_results[scale_type]['pairwise_comparisons']
                for comparison_data in comparisons.values():
                    significant_timepoints = comparison_data['significant']
                    corrected_p_values = comparison_data['corrected_p_values']
                    for t, (is_sig, p_val) in enumerate(zip(significant_timepoints, corrected_p_values)):
                        if is_sig:
                            marker = significance_marker(p_val, symbols[scale_type])
                            if marker:
                                ax1.text(t, marker_y, marker, ha='center', va='bottom',
                                         fontsize=12, fontweight='bold')

        # Extend the y-axis so every marker row fits inside the plot
        extension = 0.15 * y_range + (n_marker_rows - 1) * marker_spacing
        ax1.set_ylim(current_ylim[0], current_ylim[1] + extension)

        # Give the twin axes the same proportional headroom
        if multi_axis:
            for ax in axes[1:]:
                lo, hi = ax.get_ylim()
                ax.set_ylim(lo, hi + extension * ((hi - lo) / y_range))

    # ---- X axis ----
    ax1.set_xlabel('Timepoints', fontsize=32)
    ax1.tick_params(axis='x', labelsize=32, width=4, length=6)
    ax1.set_xticks(timepoints)
    ax1.set_xticklabels(timepoint_labels, fontsize=28, rotation=-45)

    # ---- Title ----
    if title:
        ax1.set_title(title, fontsize=40, pad=20)
    else:
        scale_desc = {'both': f'Rich vs Lazy ({rich_scale} vs {lazy_scale})',
                      'rich': f'Rich Scale ({rich_scale})',
                      'lazy': f'Lazy Scale ({lazy_scale})'}[scales_to_plot]
        correction_str = ''
        if test_significance and correction_method == 'fdr':
            correction_str = f" ({correction_method.upper()} corrected)"
        ax1.set_title(f'{scale_desc} - Maximal Training{correction_str}', fontsize=24, pad=20)

    # ---- Spines ----
    for spine in ax1.spines.values():
        spine.set_linewidth(4)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    if multi_axis:
        for twin in axes[1:]:
            for spine in twin.spines.values():
                spine.set_linewidth(4)
            # Twin axes draw their y-axis on the right spine, so keep that one
            # and hide the left spine, which would duplicate ax1's.
            twin.spines['top'].set_visible(False)
            twin.spines['left'].set_visible(False)

    plt.tight_layout()

    # Expose the statistics to callers
    if significance_results is not None:
        fig.significance_results = significance_results

    return fig, ax1


def plot_heatmap(data, init_scales, training_ntasks, title):
    """
    Plot the seed-averaged values of ``data`` as an annotated heatmap.

    Args:
        data: Array that is averaged over axis 2 before plotting. Presumably
            shaped (init_scales, training_ntasks, seeds) so the average is a
            2-D scale-by-ntasks grid — TODO confirm against callers.
        init_scales: Row tick labels (rounded to 2 decimals for display).
        training_ntasks: Column tick labels.
        title: Figure title.

    Returns:
        The newly created matplotlib Figure.
    """
    grid = np.mean(data, axis=2)

    fig = plt.figure(figsize=(12, 10))

    # 'viridis' is perceptually uniform; 'plasma', 'magma', 'cividis',
    # 'RdBu_r' or 'coolwarm' are drop-in alternatives.
    ax = sns.heatmap(
        grid,
        annot=True,
        fmt='.2f',
        cmap='viridis',
        linewidths=1,
        linecolor='white',
        xticklabels=training_ntasks,
        yticklabels=np.round(init_scales, 2),
        annot_kws={"size": 18},
    )

    ax.set_xlabel('Number of Training Tasks', fontsize=24, labelpad=15)
    ax.set_ylabel('Initial scale', fontsize=24)
    ax.set_title(title, fontsize=28, pad=24)

    # Enlarge the existing tick labels; y labels are kept horizontal.
    for label in ax.get_xticklabels():
        label.set_fontsize(20)
    for label in ax.get_yticklabels():
        label.set_fontsize(20)
        label.set_rotation(0)

    # Heavier major ticks and frame.
    ax.tick_params(axis='both', which='major', width=2, length=6)
    for edge in ax.spines.values():
        edge.set_linewidth(2)

    # The colorbar is reached through the heatmap's first collection.
    colorbar = ax.collections[0].colorbar
    colorbar.ax.tick_params(labelsize=20)
    colorbar.outline.set_linewidth(2)

    fig.tight_layout()
    return fig

# Example usage
# fig = plot_heatmap(your_data, init_scales, training_ntasks, title='Your title')
# plt.savefig('heatmap.png', dpi=300, bbox_inches='tight')
# plt.show()

def plot_single_proximity_barplot(data, init_scales, measure_name, 
                                 rich_scale=0.01, lazy_scale=3.0,
                                 comparison='stim1_dly1',
                                 figsize=(4, 3), save_path=None):
    """
    Create a single proximity bar plot for one dimensionality measure.

    For each seed, values are averaged over the two stimulus timepoints and
    the two delay timepoints of the chosen comparison. The figure shows, for
    the rich and lazy initializations, the across-seed mean (bar) and
    population standard deviation (error bar) of those averages. One line per
    seed connects its stimulus and delay averages.

    Side effect: sets the global seaborn style to "whitegrid" and
    ``plt.rcParams['font.size']`` to 12.

    Args:
        data: Dimensionality array (n_scales, n_seeds, n_timepoints). The
            timepoint indices used (1-8) follow the 10-timepoint layout
            ('Rule', 'Stim1_e', 'Stim1_l', 'Dly1_e', 'Dly1_l', 'Stim2_e',
            'Stim2_l', 'Dly2_e', 'Dly2_l', 'Resp').
        init_scales: List of initial scale values, aligned with axis 0 of data.
        measure_name: Name of the measure (e.g., 'Task', 'Stimulus').
        rich_scale: Scale value for rich initialization; the closest entry in
            init_scales is used (first one on ties).
        lazy_scale: Scale value for lazy initialization; the closest entry in
            init_scales is used (first one on ties).
        comparison: 'stim1_dly1' or 'stim2_dly2'.
        figsize: Figure size tuple (width, height).
        save_path: Optional path to save the figure.

    Returns:
        matplotlib figure object

    Raises:
        ValueError: If ``comparison`` is not 'stim1_dly1' or 'stim2_dly2'.
    """
    # (stimulus timepoint indices, delay timepoint indices, title suffix)
    comparison_specs = {
        'stim1_dly1': ([1, 2], [3, 4], 'Stim1 vs Delay1'),  # Stim1_e/l vs Dly1_e/l
        'stim2_dly2': ([5, 6], [7, 8], 'Stim2 vs Delay2'),  # Stim2_e/l vs Dly2_e/l
    }
    # Validate before touching global plotting state. Previously any
    # unrecognised value silently fell through to the Stim2/Delay2 indices.
    if comparison not in comparison_specs:
        raise ValueError(
            f"comparison must be one of {sorted(comparison_specs)}, got {comparison!r}"
        )
    stim_indices, dly_indices, title_suffix = comparison_specs[comparison]

    # Set seaborn style (global side effect, kept for compatibility)
    sns.set_style("whitegrid")
    plt.rcParams.update({'font.size': 12})

    # Nearest entry in init_scales for each requested scale (argmin returns the
    # first minimum, matching list.index(min(...))).
    scales = np.asarray(init_scales, dtype=float)
    rich_scale_idx = int(np.argmin(np.abs(scales - rich_scale)))
    lazy_scale_idx = int(np.argmin(np.abs(scales - lazy_scale)))

    rich_data = data[rich_scale_idx, :, :]
    lazy_data = data[lazy_scale_idx, :, :]

    # Per-seed period averages -> 1-D arrays of length n_seeds.
    rich_stim = np.mean(rich_data[:, stim_indices], axis=1)
    rich_dly = np.mean(rich_data[:, dly_indices], axis=1)
    lazy_stim = np.mean(lazy_data[:, stim_indices], axis=1)
    lazy_dly = np.mean(lazy_data[:, dly_indices], axis=1)

    fig, ax = plt.subplots(figsize=figsize)

    # Bar order: Rich_Stim, Rich_Dly, Lazy_Stim, Lazy_Dly (pairs placed close together).
    x_positions = [0, 0.4, 1, 1.4]
    per_seed = [rich_stim, rich_dly, lazy_stim, lazy_dly]
    bar_means = [np.mean(vals) for vals in per_seed]
    bar_errors = [np.std(vals) for vals in per_seed]  # population std across seeds

    # Colors: grey for rich, white for lazy
    rich_color = '#808080'
    lazy_color = 'white'
    bar_colors = [rich_color, rich_color, lazy_color, lazy_color]

    bar_width = 0.3
    ax.bar(x_positions, bar_means, width=bar_width, yerr=bar_errors,
           color=bar_colors, alpha=1.0, capsize=4,
           edgecolor='black', linewidth=1.2)

    # One black line per seed connecting its Stim and Dly values in each condition.
    # NOTE(review): zorder=0 places these beneath the opaque bars, so only the
    # parts above each bar's top (and in the gap between bars) are visible —
    # confirm this is intended.
    for r_stim, r_dly, l_stim, l_dly in zip(rich_stim, rich_dly, lazy_stim, lazy_dly):
        ax.plot(x_positions[0:2], [r_stim, r_dly],
                color='black', alpha=0.4, linewidth=0.8, zorder=0)
        ax.plot(x_positions[2:4], [l_stim, l_dly],
                color='black', alpha=0.4, linewidth=0.8, zorder=0)

    # Y-range must cover every individual seed value AND the error-bar extents:
    # mean +/- std can lie outside the per-seed min/max (e.g. seeds [0, 1, 1, 1]
    # give mean + std ~= 1.18), which previously clipped the error-bar caps.
    means = np.asarray(bar_means)
    errors = np.asarray(bar_errors)
    all_values = np.concatenate(per_seed + [means - errors, means + errors])
    y_min, y_max = np.min(all_values), np.max(all_values)
    y_range = y_max - y_min
    y_padding = 0.1 * y_range if y_range > 0 else 0.1
    ax.set_ylim(y_min - y_padding, y_max + y_padding)

    # Formatting
    ax.set_xticks(x_positions)
    ax.set_xticklabels(['Stim', 'Dly', 'Stim', 'Dly'], fontsize=10)
    ax.set_ylabel(f'{measure_name} Dim.', fontsize=14, fontweight='bold')
    ax.set_title(f'{measure_name}: {title_suffix}', fontsize=16, fontweight='bold', pad=15)
    ax.tick_params(axis='y', labelsize=12)

    sns.despine(ax=ax)

    # Rich/Lazy legend, always drawn, placed outside the axes on the right.
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor=rich_color, edgecolor='black', label='Rich'),
        Patch(facecolor=lazy_color, edgecolor='black', label='Lazy')
    ]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.05,1), loc='upper left', fontsize=12, frameon=True)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Figure saved to: {save_path}")

    return fig


def plot_all_proximity_barplots(task_data, stimulus_data, motor_data, rule_data,
                               init_scales, rich_scale=0.01, lazy_scale=3.0,
                               save_dir=None):
    """
    Generate all 8 individual proximity bar plots for manual assembly.

    One figure is produced per (measure, comparison) pair: 4 measures x
    {'stim1_dly1', 'stim2_dly2'}, each drawn by plot_single_proximity_barplot
    at figsize (4, 3).

    Args:
        task_data: Task dimensionality array (n_scales, n_seeds, n_timepoints)
        stimulus_data: Stimulus dimensionality array (n_scales, n_seeds, n_timepoints)
        motor_data: Motor response dimensionality array (n_scales, n_seeds, n_timepoints)
        rule_data: Rule satisfaction dimensionality array (n_scales, n_seeds, n_timepoints)
        init_scales: List of initial scale values
        rich_scale: Scale value for rich initialization
        lazy_scale: Scale value for lazy initialization
        save_dir: Directory to save individual plots (created if missing).
            Files are named "<measure>_<comparison>.png", where <measure> is
            the measure name lower-cased, spaces -> underscores, dots removed
            (e.g. "rule_satis_stim1_dly1.png").

    Returns:
        Dictionary mapping "<measure_name>_<comparison>" (e.g. "Task_stim1_dly1",
        "Rule Satis._stim2_dly2") to matplotlib figure objects.
    """
    import os

    data_arrays = [task_data, stimulus_data, motor_data, rule_data]
    measure_names = ['Task', 'Stimulus', 'Response', 'Rule Satis.']

    # savefig does not create missing directories, so make sure it exists up front.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    figures = {}

    for measure_name, data in zip(measure_names, data_arrays):
        # Filesystem-friendly stem, e.g. 'Rule Satis.' -> 'rule_satis'.
        file_stem = measure_name.lower().replace(' ', '_').replace('.', '')

        for comparison in ('stim1_dly1', 'stim2_dly2'):
            save_path = None
            if save_dir:
                save_path = os.path.join(save_dir, f"{file_stem}_{comparison}.png")

            fig = plot_single_proximity_barplot(
                data, init_scales, measure_name,
                rich_scale=rich_scale, lazy_scale=lazy_scale,
                comparison=comparison, figsize=(4, 3),
                save_path=save_path
            )

            figures[f"{measure_name}_{comparison}"] = fig

    return figures


# # Example usage
# def example_usage():
#     """Example of how to use the proximity bar plots functions."""
    
#     # Create synthetic data
#     np.random.seed(42)
#     n_scales, n_seeds, n_timepoints = 5, 10, 10
    
#     task_data = np.random.randn(n_scales, n_seeds, n_timepoints) + 5
#     stimulus_data = np.random.randn(n_scales, n_seeds, n_timepoints) + 3
#     motor_data = np.random.randn(n_scales, n_seeds, n_timepoints) + 2
#     rule_data = np.random.randn(n_scales, n_seeds, n_timepoints) + 1.5
    
#     init_scales = [0.01, 0.1, 1.0, 3.0, 10.0]
    
#     # Option 1: Individual plots for manual assembly
#     print("Creating individual plots...")
#     figures = plot_all_proximity_barplots(
#         task_data, stimulus_data, motor_data, rule_data,
#         init_scales=init_scales, save_dir='./proximity_plots'
#     )
    
#     # Option 2: Comprehensive figure
#     print("Creating comprehensive figure...")
#     comprehensive_fig = plot_comprehensive_proximity_barplots(
#         task_data, stimulus_data, motor_data, rule_data,
#         init_scales=init_scales, save_path='comprehensive_proximity.png'
#     )
    
#     plt.show()
#     return figures, comprehensive_fig

# # Uncomment to test:
# # figures, comprehensive_fig = example_usage()