import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from pathlib import Path

# Grid shape (rows, cols) for each supported number of task panels.
_GRID_SHAPES = {3: (1, 3), 4: (2, 2), 6: (2, 3)}

_VALID_MODES = ('train', 'test')


def _format_task_name(task):
    """Turn a task id such as 'permuted_MNIST' into a panel title ('Permuted MNIST').

    Unlike ``str.title()``, only the first character of each underscore-separated
    word is upper-cased, so acronyms such as 'MNIST' are preserved.
    """
    return ' '.join(word[:1].upper() + word[1:] for word in task.split('_') if word)


def _smooth(series, ema_alpha):
    """EMA-smooth a pandas Series.

    ``ema_alpha`` is the weight kept from the previous smoothed value, so pandas
    receives ``alpha = 1 - ema_alpha`` (0 means no smoothing).
    """
    return series.ewm(alpha=1 - ema_alpha).mean()


def _create_axes(fig, num_tasks):
    """Create one subplot per task on a single GridSpec, returned in row-major order."""
    nrows, ncols = _GRID_SHAPES[num_tasks]
    gs = GridSpec(nrows, ncols, wspace=0.3, hspace=0.4, figure=fig)
    return [fig.add_subplot(gs[i // ncols, i % ncols]) for i in range(num_tasks)]


def _plot_task(ax, task, mode, legend_name_dict, color_dict, base_path,
               ema_alpha, max_steps, legend_handles):
    """Plot every baseline column of one task's CSV onto ``ax``.

    Reads ``<base_path>/<task>_<mode>.csv`` (means) and, if present,
    ``<base_path>/<task>_<mode>_variance.csv`` (variances, drawn as a +/-1 std
    band). A missing mean CSV prints a warning and leaves the panel empty.
    The first line drawn for each legend name (across all panels) is stored in
    ``legend_handles`` (mutated in place).
    """
    base = Path(base_path)
    csv_path = base / f"{task}_{mode}.csv"
    var_csv_path = base / f"{task}_{mode}_variance.csv"

    if not csv_path.exists():
        print(f"Warning: File not found - {csv_path}")
        return

    var_df = None
    if var_csv_path.exists():
        var_df = pd.read_csv(var_csv_path)
    else:
        # Continue without variance shading.
        print(f"Warning: Variance file not found - {var_csv_path}")

    df = pd.read_csv(csv_path)

    if max_steps is not None:
        df = df.iloc[:max_steps]
        if var_df is not None:
            var_df = var_df.iloc[:max_steps]

    # The x-axis is the row position, not the 'Step' column.
    x_values = np.arange(len(df))

    for col in df.columns:
        if col == 'Step':
            continue

        legend_name = legend_name_dict.get(col, col)
        if legend_name not in color_dict:
            print(f"Warning: No color defined for {legend_name}, skipping")
            continue
        color = color_dict[legend_name]

        smoothed_data = _smooth(df[col], ema_alpha)
        (line,) = ax.plot(x_values, smoothed_data, color=color, linewidth=1.5)
        # First occurrence across *all* panels wins, so a baseline missing from
        # the first task (or a missing first CSV) still gets a legend entry.
        legend_handles.setdefault(legend_name, line)

        if var_df is not None and col in var_df.columns:
            # Clip tiny negative variances (numerical noise) so sqrt gives 0, not NaN.
            std_dev = np.sqrt(var_df[col].clip(lower=0))
            # Align to the mean series' index so a variance file with a different
            # row count cannot produce arrays of mismatched length.
            smoothed_std = _smooth(std_dev, ema_alpha).reindex(smoothed_data.index)
            ax.fill_between(
                x_values,
                smoothed_data - smoothed_std,
                smoothed_data + smoothed_std,
                color=color,
                alpha=0.15,
            )


def _style_axis(ax, task, mode, ylim):
    """Apply the shared background, grid, labels, limits and spine style to one panel."""
    ax.set_facecolor('white')
    ax.grid(True, color='#E0E0E0', linestyle='-', linewidth=0.5)
    ax.set_title(_format_task_name(task), fontsize=12, pad=10)
    ax.set_xlabel('Steps', fontsize=12, labelpad=8)
    ax.set_ylabel(f'{mode.title()} Accuracy', fontsize=12, labelpad=8)
    ax.tick_params(axis='both', which='major', labelsize=10)
    if ylim is not None:
        ax.set_ylim(*ylim)
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(1)


def create_multi_task_plot(
    tasks,
    mode,
    legend_name_dict,
    color_dict,
    base_path="processed_results",
    figsize=(15, 10),
    ema_alpha=0.7,
    max_steps=None,
    ylim=(0, 1),
):
    """
    Create a multi-panel plot for different tasks, showing either train or test accuracy.

    Parameters:
    -----------
    tasks : list
        List of task names (e.g., 'continual_cifar100', 'permuted_MNIST').
        Must contain 3, 4 or 6 entries (laid out as 1x3, 2x2 or 2x3).
    mode : str
        'train' or 'test'; selects ``<task>_<mode>.csv`` and
        ``<task>_<mode>_variance.csv`` under ``base_path``.
    legend_name_dict : dict
        Dictionary mapping CSV column names to legend names
        (unmapped columns use the column name itself).
    color_dict : dict
        Dictionary mapping legend names to colors; columns whose legend name
        has no color are skipped with a warning.
    base_path : str
        Base path to the results directory.
    figsize : tuple
        Figure size (width, height).
    ema_alpha : float
        EMA smoothing factor in [0, 1): the weight kept from the previous
        smoothed value (0 disables smoothing).
    max_steps : int, optional
        Maximum number of rows (steps) to include in the plot.
    ylim : tuple or None, optional
        y-axis limits for every panel; ``None`` keeps matplotlib's autoscaling.

    Returns:
    --------
    matplotlib.figure.Figure
        The generated figure.

    Raises:
    -------
    ValueError
        If the number of tasks, ``mode`` or ``ema_alpha`` is invalid.
    """
    num_tasks = len(tasks)
    if num_tasks not in _GRID_SHAPES:
        raise ValueError("Number of tasks must be 3, 4, or 6")
    if mode not in _VALID_MODES:
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
    # pandas ewm requires 0 < alpha <= 1, i.e. 0 <= ema_alpha < 1.
    if not 0 <= ema_alpha < 1:
        raise ValueError(f"ema_alpha must be in [0, 1), got {ema_alpha!r}")

    fig = plt.figure(figsize=figsize)
    axes = _create_axes(fig, num_tasks)

    legend_handles = {}  # legend name -> first Line2D drawn with that name
    for ax, task in zip(axes, tasks):
        _plot_task(ax, task, mode, legend_name_dict, color_dict, base_path,
                   ema_alpha, max_steps, legend_handles)
        # Style after plotting so set_ylim is the final word on the y-range;
        # panels with missing data are still labelled.
        _style_axis(ax, task, mode, ylim)

    # fig.legend with ncol=0 fails, so only add a legend when something was drawn.
    if legend_handles:
        fig.legend(
            list(legend_handles.values()),
            list(legend_handles.keys()),
            loc='lower center',
            bbox_to_anchor=(0.5, 0.0),
            ncol=min(4, len(legend_handles)),
            frameon=False,
            fontsize=12,
            handletextpad=0.8,
            columnspacing=1.0,
            labelspacing=0.2,
        )

    # Operate on this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.12)  # leave room for the bottom legend

    return fig

# Example usage: six continual-learning tasks, laid out as a 2x3 grid.
tasks = [
    'random_MNIST',
    'random_label_cifar10',
    'shuffle_cifar10',
    'permuted_MNIST',
    'continual_cifar100',
    'continual_imagenet',
]

# The CSV column names already equal the legend labels, so the mapping is the identity.
legend_name_dict = {
    method: method
    for method in (
        'Base', 'CBP', 'CReLU', 'DeepF', 'EWC', 'L2',
        'L2Init', 'LayerNorm', 'NeuroSync', 'PReLU', 'ReDo', 'Scratch',
    )
}

mode = 'train'

# One distinct color per method (keys must match the legend labels above).
color_dict = {
    # Established baselines
    'Base': '#000000',       # black
    'Scratch': '#d62728',    # red
    'CReLU': '#9467bd',      # purple
    'PReLU': '#8c564b',      # brown
    'L2': '#e377c2',         # pink
    'DeepF': '#7f7f7f',      # gray (Deep Fourier)
    # Proposed method, highlighted in blue
    'NeuroSync': '#1f77b4',  # blue
    # Additional baselines, chosen to avoid clashing with the colors above
    'CBP': '#2ca02c',        # green
    'EWC': '#bcbd22',        # yellow-green
    'L2Init': '#17becf',     # teal
    'LayerNorm': '#c49c94',  # light brown
    'ReDo': '#ff9896',       # salmon / light red
}

fig = create_multi_task_plot(tasks, mode, legend_name_dict, color_dict)
plt.savefig(f'normal_{mode}_plot.png', dpi=300, bbox_inches='tight')
plt.show()