# %%
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import t as student_t
from scipy.stats import ttest_ind
from statsmodels.stats.multitest import multipletests

# Okabe–Ito colour-blind-safe colours for the two bar series: [X series, Y series].
OKABE = ["#0072B2", "#E69F00"]

# Publication-style defaults applied globally to every figure drawn by this script.
plt.rcParams.update({
    "font.family": "serif",
    # Times New Roman is preferred; the remaining entries are fallbacks so matplotlib
    # does not warn "findfont: Font family ... not found" where it is not installed
    # (STIXGeneral and DejaVu Serif ship with matplotlib itself).
    "font.serif": ["Times New Roman", "Times", "Nimbus Roman", "STIXGeneral", "DejaVu Serif"],
    # Typeset mathtext (e.g. the "$R$" in the y-label) in the Times-like STIX face so
    # it matches the serif body text instead of the default sans-serif math font.
    "mathtext.fontset": "stix",
    "font.size": 14,              # base size for text without a more specific setting
    "axes.labelsize": 12,
    "axes.titlesize": 12,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "lines.linewidth": 1.2,
    "axes.linewidth": 0.8,
    # Horizontal dashed grid lines only, to help read bar heights.
    "axes.grid": True,
    "axes.grid.axis": "y",
    "grid.linestyle": "--",
    "grid.alpha": 0.35,
})

def ci95(a):
    """Return the half-width of the two-sided 95% confidence interval for the mean of *a*.

    The half-width is ``t(0.975, n - 1) * s / sqrt(n)``, where ``s`` is the sample
    standard deviation (``ddof=1``) and ``n`` is the number of values. The
    Student-t critical value replaces the normal 1.96 because the samples plotted
    here are small (n = 10). At that size 1.96 yields an interval that covers
    noticeably less than 95% (t(0.975, 9) ~= 2.262).

    *a* is flattened, so any array-like of numbers is accepted. Returns NaN when
    fewer than two values are given, since the sample standard deviation is
    undefined.
    """
    a = np.asarray(a, float).ravel()
    n = a.size
    if n < 2:
        return np.nan
    sem = a.std(ddof=1) / np.sqrt(n)
    return student_t.ppf(0.975, n - 1) * sem

def star(p):
    """Map a p-value to its significance label.

    Thresholds (strict): p < 0.001 -> "***", p < 0.01 -> "**", p < 0.05 -> "*",
    otherwise "ns". A NaN p-value fails every comparison and so yields "ns".
    """
    for cutoff, label in ((1e-3, "***"), (1e-2, "**"), (5e-2, "*")):
        if p < cutoff:
            return label
    return "ns"

def bar_tops(means, cis):
    """Return the element-wise upper whisker heights ``means + cis`` as a NumPy array."""
    upper = np.add(np.asarray(means), np.asarray(cis))
    return upper

def add_stars_over_bars(ax, xs, tops, pvals, star_vertical_offset=0.01, color='black'):
    """Write significance stars above bars.

    For each bar at x-position ``xs[i]`` whose top is ``tops[i]`` (in data units),
    the label ``star(pvals[i])`` is drawn centred horizontally, with its bottom
    ``star_vertical_offset`` data units above the top. Bars whose label is "ns"
    get no text. Text is drawn unclipped, in 11-pt type, with colour *color*.
    """
    labelled = ((x, top, star(p)) for x, top, p in zip(xs, tops, pvals))
    for x, top, label in labelled:
        if label != "ns":
            ax.text(
                x,
                top + star_vertical_offset,
                label,
                ha="center",
                va="bottom",
                fontsize=11,
                clip_on=False,
                color=color,
            )


def _welch_fdr_vs(samples, ref_idx):
    """Welch-test each sample against ``samples[ref_idx]``; return BH-adjusted p-values.

    The result is aligned with *samples*. The entry at ``ref_idx`` (the reference
    itself) is NaN, and so is any comparison whose raw p-value is NaN. All other raw
    p-values form a single Benjamini–Hochberg (``fdr_bh``) family.
    """
    # Normalise negative indices so the reference is skipped correctly; an
    # out-of-range index raises IndexError here.
    ref_idx = range(len(samples))[ref_idx]
    reference = samples[ref_idx]
    raw = [
        np.nan if i == ref_idx else ttest_ind(s, reference, equal_var=False).pvalue
        for i, s in enumerate(samples)
    ]
    adjusted = [np.nan] * len(raw)
    tested = [i for i, p in enumerate(raw) if not np.isnan(p)]
    if tested:
        corrected = multipletests([raw[i] for i in tested], method="fdr_bh")[1]
        for i, q in zip(tested, corrected):
            adjusted[i] = q
    return adjusted


def barplot_vs_baseline(decoders, X_list, Y_list, baseline=0, outfile="figure.pdf",
                        show_table=False, second_baseline_name=None, second_star_color='red'):
    """Save a grouped bar chart of X/Y correlations per decoder with significance stars.

    Each decoder gets two bars (X, then Y) showing the sample mean with a 95% CI
    error bar from ``ci95``. Every decoder other than ``decoders[baseline]`` is
    compared with the baseline using Welch's (unpaired, unequal-variance) t-test,
    separately for X and for Y. Each series' p-values are Benjamini–Hochberg
    corrected and drawn as black stars (see ``star``) above the bars.

    Parameters
    ----------
    decoders : list or tuple of str
        Bar-group labels, aligned with *X_list* and *Y_list*.
    X_list, Y_list : sequence of array-like
        Per-decoder samples for the X and Y series.
    baseline : int
        Index of the reference decoder for the black stars.
    outfile : str, path-like or file-like
        Destination passed to ``Figure.savefig``. For path destinations, missing
        parent directories are created.
    show_table : bool
        Accepted for backward compatibility; no table is drawn.
    second_baseline_name : str or None
        If given and present in *decoders*, a second, independently corrected set
        of comparisons against that decoder is drawn above the first, in
        *second_star_color*. If it is not in *decoders*, a warning is printed and
        the second set is skipped.
    second_star_color : matplotlib colour
        Colour of the second set of stars.
    """
    means_x = [np.mean(v) for v in X_list]
    means_y = [np.mean(v) for v in Y_list]
    ci_x = [ci95(v) for v in X_list]
    ci_y = [ci95(v) for v in Y_list]
    tops_x = bar_tops(means_x, ci_x)
    tops_y = bar_tops(means_y, ci_y)

    # Y limits. Leave y_axis_top_buffer of headroom above the tallest whisker so
    # the star rows fit inside the axes. Go below 0 only if some lower whisker is
    # negative, since Pearson R can be negative. NaN whiskers are ignored.
    y_axis_top_buffer = 0.3
    all_means = means_x + means_y
    all_cis = ci_x + ci_y
    upper = [m + c for m, c in zip(all_means, all_cis)]
    lower = [m - c for m, c in zip(all_means, all_cis)]
    max_data_value = max((v for v in upper if np.isfinite(v)), default=0.0)
    min_data_value = min((v for v in lower if np.isfinite(v)), default=0.0)
    y_bottom = min(0.0, min_data_value)
    y_top = max(1.0, max_data_value + y_axis_top_buffer)

    # Welch's t-tests (unpaired) vs the primary baseline, BH-corrected per series.
    p_x_final = _welch_fdr_vs(X_list, baseline)
    p_y_final = _welch_fdr_vs(Y_list, baseline)

    # Optional second family of comparisons, corrected independently of the first.
    second_p = None
    if second_baseline_name:
        if second_baseline_name not in decoders:
            print(f"Warning: Second baseline '{second_baseline_name}' not found in decoders. Skipping second set of stars.")
        else:
            second_idx = decoders.index(second_baseline_name)
            second_p = (_welch_fdr_vs(X_list, second_idx), _welch_fdr_vs(Y_list, second_idx))

    first_offset = 0.01       # data units between a whisker top and the first star row
    star_row_gap_pt = 8.0     # gap between the two star rows, in points

    fig, ax = plt.subplots(figsize=(5.5, 3.5))
    try:
        x_pos = np.arange(len(decoders))
        w = 0.35
        x_left = x_pos - w / 2
        x_right = x_pos + w / 2

        ax.bar(x_left, means_x, width=w, yerr=ci_x, capsize=3,
               color=OKABE[0], edgecolor="black", label="X correlation")
        ax.bar(x_right, means_y, width=w, yerr=ci_y, capsize=3,
               color=OKABE[1], edgecolor="black", label="Y correlation")

        ax.set_ylabel("Pearson $R$ (mean ± 95% CI)")
        ax.set_xticks(x_pos)
        ax.set_xticklabels(decoders, rotation=45, ha="right")
        ax.set_ylim(bottom=y_bottom, top=y_top)
        ax.legend(frameon=False, loc="upper left")

        add_stars_over_bars(ax, x_left, tops_x, p_x_final,
                            star_vertical_offset=first_offset, color='black')
        add_stars_over_bars(ax, x_right, tops_y, p_y_final,
                            star_vertical_offset=first_offset, color='black')

        # show_table is accepted for API compatibility only; no p-value table is drawn.
        fig.tight_layout()

        if second_p is not None:
            # Place the second row a fixed number of points above the first. A fixed
            # data-unit offset scales with the y-range and can leave the two rows of
            # asterisks touching. The conversion runs after tight_layout so it uses
            # the final axes height (display pixels -> points via fig.dpi).
            y_lo, y_hi = ax.get_ylim()
            axes_height_pt = ax.bbox.height * 72.0 / fig.dpi
            second_offset = first_offset + star_row_gap_pt * (y_hi - y_lo) / axes_height_pt
            add_stars_over_bars(ax, x_left, tops_x, second_p[0],
                                star_vertical_offset=second_offset, color=second_star_color)
            add_stars_over_bars(ax, x_right, tops_y, second_p[1],
                                star_vertical_offset=second_offset, color=second_star_color)

        # savefig does not create directories (e.g. "figures/"); do it for path outputs.
        if isinstance(outfile, (str, os.PathLike)):
            out_dir = os.path.dirname(os.fspath(outfile))
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
        fig.savefig(outfile, bbox_inches="tight", pad_inches=0.02)
    finally:
        # Always release the figure, even on error, so repeated calls don't leak figures.
        plt.close(fig)

# Cell 1 data: ten Pearson R values per decoder, separately for the X and Y
# correlation (presumably one value per run/fold — confirm against the source).
true_online_x = [0.796852, 0.628121, 0.582131, 0.461622, 0.652083, 0.675291, 0.511745, 0.638775, 0.575411, 0.602889]
true_online_y = [0.649087, 0.764147, 0.799070, 0.619791, 0.754308, 0.706818, 0.736061, 0.705147, 0.714538, 0.689790]

windowed_x = [0.634120, 0.647772, 0.558672, 0.673100, 0.617063, 0.638310, 0.567807, 0.666757, 0.531551, 0.603142]
windowed_y = [0.755520, 0.807930, 0.655794, 0.797858, 0.680766, 0.629724, 0.602778, 0.791688, 0.803241, 0.719227]

bptt_x = [0.774077, 0.619250, 0.402692, 0.559531, 0.673314, 0.633008, 0.632742, 0.670352, 0.669159, 0.691254]
bptt_y = [0.787519, 0.763156, 0.619457, 0.699719, 0.787003, 0.733751, 0.751782, 0.763180, 0.704626, 0.635428]

kf_x = [0.586157, 0.377500, 0.468556, 0.336909, 0.373175, 0.397722, 0.375365, 0.393699, 0.403917, 0.442011]
kf_y = [0.522012, 0.476830, 0.506385, 0.371867, 0.445722, 0.524522, 0.489962, 0.447903, 0.277717, 0.494375]

lstm_x = [0.885787, 0.725801, 0.618703, 0.698620, 0.716149, 0.746294, 0.749629, 0.779819, 0.768514, 0.822811]
lstm_y = [0.855324, 0.856735, 0.804183, 0.835848, 0.848814, 0.836300, 0.860286, 0.875546, 0.809811, 0.771913]

# Bar order on the plot; X_list / Y_list must follow the same order. windowed_* is
# labelled 'Batched Online SNN'. Index 0 ('True Online SNN') is the primary
# baseline (black stars); 'Batched Online SNN' is the second baseline (red stars).
decoders = ['True Online SNN', 'Batched Online SNN', 'Kalman Filter', 'LSTM', 'BPTT SNN']
X_list = [true_online_x, windowed_x, kf_x, lstm_x, bptt_x]
Y_list = [true_online_y, windowed_y, kf_y, lstm_y, bptt_y]

barplot_vs_baseline(decoders, X_list, Y_list, baseline=0,
                    outfile="figures/decoder_comparison_cell1.pdf",
                    show_table=False, second_baseline_name='Batched Online SNN', second_star_color='red') # red = stars vs the second baseline

# %%
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind
from statsmodels.stats.multitest import multipletests

# Requires the first cell to have been run: barplot_vs_baseline and its helpers
# (ci95, star, bar_tops, add_stars_over_bars), OKABE and the rcParams are set there.

# Cell 2 data (a second dataset): these assignments overwrite the Cell 1 lists of the
# same names (lstm_*, kf_*, bptt_*). snn_* is plotted as 'True Online SNN' and
# snn_windowed_* as 'Batched Online SNN'.
lstm_x = [0.952153, 0.956655, 0.953523, 0.956431, 0.956649, 0.954146, 0.958331, 0.949800, 0.952872, 0.957016]
lstm_y = [0.933290, 0.929391, 0.929139, 0.936171, 0.924867, 0.928272, 0.936700, 0.929923, 0.936982, 0.932154]

snn_x = [0.8285, 0.8232, 0.8339, 0.8449, 0.8320, 0.8452, 0.8510, 0.8266, 0.8357, 0.8452]
snn_y = [0.7967, 0.8004, 0.7869, 0.8180, 0.7853, 0.8039, 0.8124, 0.7994, 0.8057, 0.8170]

kf_x = [0.810412, 0.828624, 0.823223, 0.821020, 0.817931, 0.816748, 0.810702, 0.831634, 0.818697, 0.815335]
kf_y = [0.720182, 0.701546, 0.728560, 0.713928, 0.715097, 0.701409, 0.698864, 0.727354, 0.713970, 0.715614]

bptt_x = [0.8828, 0.8884, 0.8808, 0.8760, 0.8870, 0.8853, 0.8749, 0.8861, 0.8864, 0.8846]
bptt_y = [0.8415, 0.8434, 0.8531, 0.8398, 0.8358, 0.8371, 0.8404, 0.8564, 0.8326, 0.8381]

snn_windowed_x = [0.848747, 0.860945, 0.885313, 0.890126, 0.873845, 0.862622, 0.870642, 0.893943, 0.869261, 0.842304]
snn_windowed_y = [0.772264, 0.768282, 0.832796, 0.832304, 0.787622, 0.757756, 0.781899, 0.817539, 0.788946, 0.767053]


# Same bar order and baselines as Cell 1: black stars vs 'True Online SNN' (index 0),
# red stars vs 'Batched Online SNN'.
decoders = ['True Online SNN', 'Batched Online SNN', 'Kalman Filter', 'LSTM', 'BPTT SNN']
X_list = [snn_x, snn_windowed_x, kf_x, lstm_x, bptt_x]
Y_list = [snn_y, snn_windowed_y, kf_y, lstm_y, bptt_y]

barplot_vs_baseline(decoders, X_list, Y_list, baseline=0,
                    outfile="figures/decoder_comparison_cell2.pdf",
                    show_table=False, second_baseline_name='Batched Online SNN', second_star_color='red') # red = stars vs the second baseline


