import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def plot_errorbar_per_layer_pyz(df, projections, color_map, ci_scale=1.96, ax=None, ylabel=None, xlabel=None, labelsize=18):
    """
    Plot error bars for a single metric (either average or worst-group accuracy)

    Args:
        df: DataFrame with the data for one metric; the columns 'proj_method',
            'p_y_z', 'mean' and 'se' are read
        projections: List of projection methods to plot
        color_map: Dictionary mapping projection methods to colors
        ci_scale: Scale factor applied to 'se' to get the error-bar half-width
        ax: Matplotlib axis to plot on; a new 12x4 figure is created when None
        ylabel: Y-axis label; defaults to 'Accuracy'
        xlabel: X-axis label; defaults to p(y=a | z=a)
        labelsize: Font size of both axis labels
    Returns:
        fig, ax: The figure and axis that were drawn on
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 4))
    else:
        fig = ax.figure

    # one error-bar series per projection method
    for projection in projections:
        df_selected = df[df['proj_method'] == projection]
        ax.errorbar(df_selected['p_y_z'],
                    df_selected['mean'],
                    yerr=df_selected['se'] * ci_scale,
                    label=f'{projection}',
                    color=color_map[projection],
                    markersize=6,
                    capsize=3)

    # axis labels; the y-axis shows the metric, so it must not reuse the
    # x-axis label as its default
    ax.set_xlabel(r'$p(y=a \mid z=a)$' if xlabel is None else xlabel, fontsize=labelsize)
    ax.set_ylabel('Accuracy' if ylabel is None else ylabel, fontsize=labelsize)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.0%}'))

    # x-ticks every 0.1 starting at the smallest p_y_z until the largest is covered.
    # The number of steps is computed explicitly because np.arange with a float
    # step can emit an extra trailing tick due to rounding.
    min_p_y_z = df['p_y_z'].min()
    max_p_y_z = df['p_y_z'].max()
    if np.isfinite(min_p_y_z) and np.isfinite(max_p_y_z):  # skipped for empty data
        interval = 0.1
        n_steps = int(np.ceil((max_p_y_z - min_p_y_z) / interval - 1e-9))
        ax.set_xticks(min_p_y_z + interval * np.arange(n_steps + 1))

    return fig, ax


def _pyz_aligned_values(df_selected, p_y_z_values, metric, ci_scale):
    """Return (values, errors) of `metric` ordered like p_y_z_values; missing entries become 0."""
    has_se = 'se' in df_selected.columns
    values = []
    errors = []
    for p in p_y_z_values:
        rows = df_selected[df_selected['p_y_z'] == p]
        if len(rows) > 0:
            # when several rows share a p_y_z value, the first one is used
            values.append(rows[metric].values[0])
            errors.append(rows['se'].values[0] * ci_scale if has_se else 0)
        else:
            values.append(0)
            errors.append(0)
    return values, errors


def plot_barplot_per_pyz(df_avg, df_wg, projections, color_map, ci_scale=1.96, fig=None, ax=None, metric='mean', xlabel=None, ylabel1=None, ylabel2=None, text=False):
    """
    Plot bar plots for two metrics (average and worst-group accuracy) across different projections

    The left axis (ax[0]) shows df_avg, the right axis (ax[1]) shows df_wg. Bars are
    grouped by the p_y_z values found in df_wg. The y-axis is formatted as a
    percentage and its limits are clamped to [0, 1], so values are expected to be
    fractions.

    Args:
        df_avg: DataFrame with average accuracy data ('proj_method', 'p_y_z', the
            metric column and optionally 'se')
        df_wg: DataFrame with worst-group accuracy data (same columns)
        projections: List of projection methods to plot
        color_map: Dictionary mapping projection methods to colors
        ci_scale: Scale factor for confidence intervals (multiplies 'se')
        fig: Matplotlib figure to plot on; derived from ax when only ax is given
        ax: Pair of Matplotlib axes to plot on; created when None
        metric: Column name for the metric to plot
        xlabel: X-axis label; defaults to p(y=a | z=a)
        ylabel1: Y-axis label of ax[0]; defaults to 'Average accuracy'
        ylabel2: Y-axis label of ax[1]; defaults to 'Worst-group accuracy'
        text: If True, write the value as a percentage above each bar
    Returns:
        fig, ax: The figure and axes objects
    """
    # resolve figure/axes so that both are always defined afterwards
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1, 2, figsize=(12, 6))
        else:
            ax = fig.subplots(1, 2)
    elif fig is None:
        fig = ax[0].figure

    # Get unique p_y_z values and number of projections
    p_y_z_values = sorted(df_wg['p_y_z'].unique())
    num_projections = len(projections)
    # Adjust bar width based on number of projections (guard against an empty list)
    bar_width = 0.8 / num_projections if num_projections else 0.8

    # Group centres on the x-axis; independent of the projection, so computed once
    positions = np.arange(len(p_y_z_values))

    # Lower/upper extents (value -/+ error) used for the y-limits
    avg_min_values = []
    wg_min_values = []
    avg_max_values = []
    wg_max_values = []

    # Bar centres and heights per projection, reused for the text labels
    all_bar_positions = []
    avg_heights = []
    wg_heights = []

    for i, projection in enumerate(projections):
        df_wg_selected = df_wg[df_wg['proj_method'] == projection]
        df_avg_selected = df_avg[df_avg['proj_method'] == projection]

        # This formula centers the bars within each p_y_z group
        bar_positions = positions + (i - (num_projections - 1) / 2) * bar_width

        # Data in the same order as p_y_z_values (a missing 'se' column means no error bars)
        wg_values, wg_errors = _pyz_aligned_values(df_wg_selected, p_y_z_values, metric, ci_scale)
        avg_values, avg_errors = _pyz_aligned_values(df_avg_selected, p_y_z_values, metric, ci_scale)

        # Track extents (accounting for error bars)
        wg_min_values.extend(v - e for v, e in zip(wg_values, wg_errors))
        avg_min_values.extend(v - e for v, e in zip(avg_values, avg_errors))
        wg_max_values.extend(v + e for v, e in zip(wg_values, wg_errors))
        avg_max_values.extend(v + e for v, e in zip(avg_values, avg_errors))

        # Create bar plots with error bars
        ax[0].bar(bar_positions, avg_values, width=bar_width,
                  label=f'{projection}',
                  color=color_map[projection], yerr=avg_errors, capsize=3)
        ax[1].bar(bar_positions, wg_values, width=bar_width,
                  label=f'{projection}',
                  color=color_map[projection], yerr=wg_errors, capsize=3)

        all_bar_positions.append(bar_positions)
        avg_heights.append(avg_values)
        wg_heights.append(wg_values)

    # Set x-ticks at positions of p_y_z values
    tick_labels = [f'{p:.1f}' for p in p_y_z_values]
    for axis in (ax[0], ax[1]):
        axis.set_xticks(positions)
        axis.set_xticklabels(tick_labels)

    # Create labels
    if xlabel is None:
        xlabel = r'$p(y=a \mid z=a)$'
    if ylabel1 is None:
        ylabel1 = 'Average accuracy'
    if ylabel2 is None:
        ylabel2 = 'Worst-group accuracy'
    ax[0].set_xlabel(xlabel)
    ax[1].set_xlabel(xlabel)
    ax[0].set_ylabel(ylabel1)
    ax[1].set_ylabel(ylabel2)

    # Format y-axis as percentages
    ax[0].yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.1%}'))
    ax[1].yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.1%}'))

    # Set y-axis limits: round the extremes to two decimals, pad by one
    # percentage point and clamp to [0, 1]
    adj = 0.01
    extents = (
        (ax[0], avg_min_values, avg_max_values),
        (ax[1], wg_min_values, wg_max_values),
    )
    for axis, lows, highs in extents:
        if lows:
            lower = max(0, np.round(min(lows), 2) - adj)
            upper = min(1, np.round(max(highs), 2) + adj)
            axis.set_ylim(bottom=lower, top=upper)

    # Add percentage text on top of each bar
    if text:
        for bar_positions, avg_values, wg_values in zip(all_bar_positions, avg_heights, wg_heights):
            for x, avg_height, wg_height in zip(bar_positions, avg_values, wg_values):
                # Average accuracy bars
                ax[0].text(x,
                           avg_height + 0.01,
                           f'{avg_height:.1%}',
                           ha='center', va='bottom',
                           fontsize=8, rotation=0)
                # Worst-group accuracy bars
                ax[1].text(x,
                           wg_height + 0.01,
                           f'{wg_height:.1%}',
                           ha='center', va='bottom',
                           fontsize=8, rotation=0)

    return fig, ax

def plot_errorbar_per_pyz(df_wg, df_avg, projections, color_map, ci_scale=1.96, ax=None, xlabel=None, ylabel1=None, ylabel2=None, labelsize=18):
    """
    Plot error bars for worst-group (ax[0]) and average (ax[1]) accuracy per projection

    Args:
        df_wg: DataFrame with worst-group accuracy data; the columns 'proj_method',
            'p_y_z', 'mean' and 'se' are read. Drawn on ax[0]
        df_avg: DataFrame with average accuracy data (same columns). Drawn on ax[1]
        projections: List of projection methods to plot
        color_map: Dictionary mapping projection methods to colors
        ci_scale: Scale factor applied to 'se' to get the error-bar half-width
        ax: Pair of Matplotlib axes to plot on; a new 1x2 figure is created when None
        xlabel: X-axis label of both axes; defaults to p(y=a | z=a)
        ylabel1: Y-axis label of ax[0]; defaults to 'Worst-group accuracy'
        ylabel2: Y-axis label of ax[1]; defaults to 'Average accuracy'
        labelsize: Font size of the axis labels
    Returns:
        fig, ax: The figure and axes objects
    """
    if ax is None:
        fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    else:
        fig = ax[0].figure

    # one error-bar series per projection on each axis
    for projection in projections:
        df_wg_selected = df_wg[df_wg['proj_method'] == projection]
        df_avg_selected = df_avg[df_avg['proj_method'] == projection]
        for axis, df_selected in ((ax[0], df_wg_selected), (ax[1], df_avg_selected)):
            axis.errorbar(df_selected['p_y_z'],
                          df_selected['mean'],
                          yerr=df_selected['se'] * ci_scale,
                          label=f'{projection}',
                          color=color_map[projection],
                          markersize=6,
                          capsize=3)

    # create labels; the default y-labels follow the data drawn on each axis
    # (ax[0] holds df_wg, ax[1] holds df_avg)
    if xlabel is None:
        xlabel = r'$p(y=a \mid z=a)$'
    if ylabel1 is None:
        ylabel1 = 'Worst-group accuracy'
    if ylabel2 is None:
        ylabel2 = 'Average accuracy'
    ax[0].set_xlabel(xlabel, fontsize=labelsize)
    ax[1].set_xlabel(xlabel, fontsize=labelsize)
    ax[0].set_ylabel(ylabel1, fontsize=labelsize)
    ax[1].set_ylabel(ylabel2, fontsize=labelsize)

    # show the y-axis as whole percentages
    ax[1].yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.0%}'))
    ax[0].yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.0%}'))

    # set x-ticks every 0.1 over the p_y_z range of df_wg (used for both axes).
    # The number of steps is computed explicitly because np.arange with a float
    # step can emit an extra trailing tick due to rounding.
    min_p_y_z = df_wg['p_y_z'].min()
    max_p_y_z = df_wg['p_y_z'].max()
    if np.isfinite(min_p_y_z) and np.isfinite(max_p_y_z):  # skipped for empty data
        interval = 0.1
        n_steps = int(np.ceil((max_p_y_z - min_p_y_z) / interval - 1e-9))
        ticks = min_p_y_z + interval * np.arange(n_steps + 1)
        ax[0].set_xticks(ticks)
        ax[1].set_xticks(ticks)

    return fig, ax


def select_from_val(df_test,  df_val, param, match_cols, metric, higher=True):
    """
    Select rows from df_test based on the values in df_val.

    For every combination of match_cols, the value of param with the best
    validation metric is determined from df_val. Only the rows of df_test that
    carry this selected param value for their match_cols combination are kept.

    Args:
        df_test: DataFrame with test results; must contain match_cols and param
        df_val: DataFrame with validation results; must contain match_cols,
            param and metric
        param: Name of the column holding the parameter to select
        match_cols: Column name(s) defining the groups within which param is selected
        metric: Name of the validation column used to rank the param values
        higher: If True the highest metric wins, otherwise the lowest
    Returns:
        df_test: DataFrame with selected rows
    """
    match_cols = [match_cols] if isinstance(match_cols, str) else list(match_cols)

    # first, get the candidate observations from the validation set. Rows without
    # a metric cannot be ranked, and a fresh unique index guarantees that the
    # labels returned by idxmax/idxmin below each identify exactly one row.
    df_val_match = (
        df_val[match_cols + [param, metric]]
        .drop_duplicates()
        .dropna(subset=[metric])
        .reset_index(drop=True)
    )

    # second, per match column, locate the best row based on the metric
    # (ties resolve to the first occurrence)
    grouped_metric = df_val_match.groupby(match_cols)[metric]
    best_idx = grouped_metric.idxmax() if higher else grouped_metric.idxmin()

    # third, per match column, keep the selected value for param; selecting the
    # rows by label leaves the column dtypes untouched for the merge
    df_val_selected_param = df_val_match.loc[best_idx.to_numpy(), match_cols + [param]]

    # fourth, keep the test rows that carry the selected param
    df_test_selected = df_test.merge(df_val_selected_param, on=match_cols + [param], how='inner')

    # return the selected test rows
    return df_test_selected


def create_df_agg(df, metric, added_cols=None, multiply=1):
    """
    Aggregate a metric over the rows of each group into a mean and standard error.

    Args:
        df: DataFrame containing the columns 'seed', 'proj_method',
            'classifier_type', 'p_y_z', the metric column and any added_cols
        metric: Name of the column to aggregate
        added_cols: Optional list of extra columns to group by
        multiply: Factor applied to both the mean and the standard error
    Returns:
        DataFrame with one row per group holding the group columns, 'mean' and 'se'
    """
    extra_cols = [] if added_cols is None else added_cols

    # restrict the frame to the columns that take part in the aggregation
    needed_cols = ['seed', 'proj_method', 'classifier_type', metric, 'p_y_z'] + extra_cols
    subset = df[needed_cols]

    # grouping keys: everything except the seed and the metric itself
    group_keys = ['proj_method', 'classifier_type', 'p_y_z'] + extra_cols

    # mean, standard deviation and number of observations within each group
    stats = subset.groupby(group_keys)[metric].agg(['mean', 'std', 'count'])

    # standard error of the mean
    stats['se'] = stats['std'] / np.sqrt(stats['count'])

    # back to flat columns, keeping only the group keys, mean and se
    result = stats.reset_index()[group_keys + ['mean', 'se']]

    # rescale both statistics
    for stat_col in ('mean', 'se'):
        result[stat_col] = result[stat_col] * multiply

    return result