import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.ndimage import gaussian_filter1d
from statsmodels.stats.multitest import multipletests


# Module-level defaults for the expected data layout:
# number of time points per channel and number of null permutations.
n_time_points, n_permutations = 200, 10


def _permutation_pvalues(real_values, null_values):
    """Return one-tailed permutation p-values for a single channel.

    Parameters
    ----------
    real_values : array of shape (n_time_points,)
        Observed statistic at each time point.
    null_values : array of shape (n_time_points, n_permutations)
        Null statistics; column k holds permutation k.

    Returns
    -------
    numpy.ndarray of shape (n_time_points,)
        ``(count + 1) / (n_permutations + 1)``, where ``count`` is the number of
        null values at least as large as the observed value. The +1 terms count
        the observed data as one of the permutations, so a p-value can never be
        exactly 0. The smallest attainable p-value is ``1 / (n_permutations + 1)``.
    """
    real_values = np.asarray(real_values, dtype=float)
    null_values = np.asarray(null_values, dtype=float)
    n_perm = null_values.shape[1]  # taken from the data, not a module-level constant
    exceed_counts = (null_values >= real_values[:, np.newaxis]).sum(axis=1)
    return (exceed_counts + 1) / (n_perm + 1)


def plot_channels_grid_fdr(channels_range, figure_title, alpha_level=0.05, 
                           time_windows_corrs_array=None,
                           time_windows_null_dist_array=None, prefix_title="", out_folder=""):
    """Plot per-channel time courses with FDR-corrected permutation significance.

    1) Compute a one-tailed permutation p-value for every (channel, time point)
       in ``channels_range``.
    2) Apply Benjamini-Hochberg FDR correction jointly across all of those
       channel/time-point comparisons.
    3) Draw a square grid with each channel's raw and smoothed signal, the
       null mean +/- 1 std band, and a star at every FDR-significant time point.
       The grid is 4x4 for up to 16 channels and grows beyond that.

    The figure is saved as a PNG, shown, and then closed.

    Parameters
    ----------
    channels_range : iterable of int
        Channel indices (last axis of both data arrays) to plot. A ``range`` is
        typical. Its ``start``/``stop`` go into the output file name; for other
        iterables the first index and last index + 1 are used instead.
    figure_title : str
        Figure suptitle.
    alpha_level : float
        FDR level passed to ``multipletests``.
    time_windows_corrs_array : array of shape (n_time_points, n_channels)
        Observed statistic per time point and channel. Required.
    time_windows_null_dist_array : array of shape (n_time_points, n_permutations, n_channels)
        Null statistic per time point, permutation and channel. This is the
        layout the broadcasting in ``_permutation_pvalues`` requires. Required.
    prefix_title : str
        Prefix for the output file name.
    out_folder : str
        Directory for the PNG. ``""`` means the current working directory.

    Raises
    ------
    TypeError
        If either data array is not supplied.
    """
    if time_windows_corrs_array is None or time_windows_null_dist_array is None:
        raise TypeError(
            "time_windows_corrs_array and time_windows_null_dist_array are required"
        )

    corrs = np.asarray(time_windows_corrs_array)
    null_dist = np.asarray(time_windows_null_dist_array)
    channels = list(channels_range)
    n_plots = len(channels)
    n_t = corrs.shape[0]  # time length comes from the data, not a global

    # 1) Permutation p-values for every selected channel -> (n_plots, n_t).
    pvals_all = np.array(
        [_permutation_pvalues(corrs[:, ch], null_dist[:, :, ch]) for ch in channels],
        dtype=float,
    ).reshape(n_plots, n_t)

    # 2) One BH-FDR correction over all channel/time-point tests at once.
    if pvals_all.size:
        reject_flags, _, _, _ = multipletests(
            pvals_all.ravel(), alpha=alpha_level, method="fdr_bh"
        )
        reject_2d = reject_flags.reshape(pvals_all.shape)
    else:
        reject_2d = np.zeros(pvals_all.shape, dtype=bool)

    # 3) Keep the 4x4 layout for <= 16 channels; grow the square grid for more
    #    instead of running out of axes.
    grid_size = max(4, int(np.ceil(np.sqrt(n_plots))))
    fig, axes = plt.subplots(grid_size, grid_size,
                             figsize=(5 * grid_size, 4 * grid_size), dpi=150,
                             squeeze=False)
    fig.suptitle(figure_title, fontsize=14, fontweight="bold")

    time_axis = np.arange(n_t)
    # Equals [0, 50, 100, 150, 199] for 200 time points; stays in range otherwise.
    xticks = [t for t in (0, 50, 100, 150) if t < n_t - 1] + [n_t - 1]

    # axes.flat walks row-major, i.e. panel i -> (i // grid_size, i % grid_size).
    for i, (ch_idx, ax) in enumerate(zip(channels, axes.flat)):
        first = i == 0  # label only the first panel; its handles feed the shared legend

        raw_signal = corrs[:, ch_idx]
        smoothed_signal = gaussian_filter1d(raw_signal, sigma=2)
        ax.plot(raw_signal, color="tab:orange", alpha=0.4,
                label="Encoding (raw)" if first else "")
        ax.plot(smoothed_signal, color="tab:orange", linewidth=2,
                label="Encoding (smoothed)" if first else "")

        # Null band: mean +/- 1 std across permutations (axis 1).
        null_vals_ch = null_dist[:, :, ch_idx]
        null_mean = null_vals_ch.mean(axis=1)
        null_std = null_vals_ch.std(axis=1)
        ax.plot(null_mean, color="tab:blue", label="Null mean" if first else "")
        ax.fill_between(time_axis,
                        null_mean - null_std,
                        null_mean + null_std,
                        color="tab:blue", alpha=0.2,
                        label="Null ±1 std" if first else "")

        # FDR-significant time points, drawn as one marker-only line.
        sig_t = np.where(reject_2d[i])[0]
        if sig_t.size:
            ax.plot(sig_t, np.full(sig_t.shape, 0.28), linestyle="none",
                    marker="*", color="black", markersize=3)

        # NOTE(review): x=50 presumably marks an event onset — confirm.
        ax.axvline(x=50, color='r', linestyle='--', linewidth=1, alpha=0.7)

        ax.set_ylim(-0.3, 0.3)
        ax.set_title(f"Ch {ch_idx}", fontsize=9)
        sns.despine(ax=ax)
        ax.set_xticks(xticks)
        ax.set_yticks([-0.3, 0, 0.3])

    # Blank out panels beyond the number of channels.
    for ax in axes.flat[n_plots:]:
        ax.axis('off')

    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper right")

    plt.tight_layout()

    start = getattr(channels_range, "start", channels[0] if channels else 0)
    stop = getattr(channels_range, "stop", channels[-1] + 1 if channels else 0)
    file_name = f"{prefix_title}_channels_grid_fdr_channels_{start}_{stop}.png"
    # os.path.join avoids writing to "/" when out_folder is the default "".
    fig.savefig(os.path.join(out_folder, file_name), dpi=300)
    plt.show()
    plt.close(fig)  # release the figure so repeated calls do not accumulate figures
