import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec, ticker
from collections import defaultdict
from scipy.stats import wilcoxon
from scipy.io import loadmat

def _counts_per_point(x, y):
    """Return, for every (x, y) pair, how many pairs share its rounded position.

    Each coordinate is rounded to the nearest integer before comparison, so
    points that land on the same integer grid cell are counted together.
    The result is an integer array with one entry per input pair, in input
    order.
    """
    xs, ys = np.asarray(x), np.asarray(y)

    # Snap every point to its integer grid cell.
    cells = []
    for px, py in zip(xs, ys):
        cells.append((int(round(px)), int(round(py))))

    # Tally how often each cell occurs.
    tally = {}
    for cell in cells:
        tally[cell] = tally.get(cell, 0) + 1

    # Look the tally back up for each original point.
    return np.fromiter((tally[cell] for cell in cells), dtype=int, count=len(cells))

def _wilcoxon_greater_p(sample, reference):
    """Return the one-sided Wilcoxon signed-rank p-value, or None on failure.

    Runs ``wilcoxon(sample, reference, alternative='greater')``. The test is
    only used to annotate the plot, so a failure must not abort plotting:
    the exception is reported as a ``RuntimeWarning`` and None is returned.
    """
    import warnings  # local import so this block does not depend on file-level imports

    try:
        _, p_value = wilcoxon(sample, reference, alternative='greater')
    except Exception as exc:  # deliberate best-effort, but no longer silent
        warnings.warn(f"Wilcoxon test failed, p-value annotation skipped: {exc}",
                      RuntimeWarning, stacklevel=2)
        return None
    return p_value


def _annotate_p_value(ax, p_value, fs):
    """Write ``p = ...`` in a bordered box at the bottom-right of ``ax`` (no-op if None)."""
    if p_value is None:
        return
    ax.text(0.93, 0.05, f"p = {p_value:.3e}",
            transform=ax.transAxes,
            fontsize=fs,
            verticalalignment='bottom',
            horizontalalignment='right',
            bbox=dict(boxstyle='round,pad=0.4', facecolor='white', edgecolor='black', alpha=0.95),
            zorder=4)


def plot_two_side_by_side_sharedy(t_sfrs, t_accs,
                                  save='./time_lag/time-lag.png',
                                  cmap='plasma',
                                  marker_size=10,
                                  fs=15,
                                  line_kwargs=None,
                                  figsize=(12,6),
                                  dpi=300):
    """
    Draw two scatter plots side by side with a shared y-axis and one colorbar.

    The left panel (titled 'dense') shows group 0 and the right panel (titled
    'sparse') shows group 1. Points are coloured by how many points of the
    same group fall on the same rounded (x, y) position. Each panel also
    shows the y = x line, shaded regions above and below it, and the
    one-sided Wilcoxon p-value for ``y > x`` when it can be computed.

    Parameters
    ----------
    t_sfrs, t_accs : sequence
        Indexed with ``[0]`` and ``[1]``; each entry is flattened and used as
        the x values (``t_sfrs``) and y values (``t_accs``) of one group.
    save : str
        Output path of the combined figure; its directory is created if needed.
    cmap : str
        Colormap name used for the overlap counts.
    marker_size : float
        Scatter marker size.
    fs : int
        Font size for titles, labels, ticks and annotations.
    line_kwargs : dict or None
        Keyword arguments for the y = x line. None selects a thin dashed
        green line.
    figsize : tuple
        Figure size passed to ``plt.figure``.
    dpi : int
        Resolution used when saving.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    if line_kwargs is None:
        line_kwargs = {'color': 'green', 'linestyle': '--', 'linewidth': 0.8}

    # Make sure the output directory exists.
    out_dir = os.path.dirname(save) or '.'
    os.makedirs(out_dir, exist_ok=True)

    # Flatten both groups to 1-D arrays.
    x0 = np.asarray(t_sfrs[0]).ravel(); y0 = np.asarray(t_accs[0]).ravel()
    x1 = np.asarray(t_sfrs[1]).ravel(); y1 = np.asarray(t_accs[1]).ravel()

    # Overlap counts are computed within each group separately.
    counts0 = _counts_per_point(x0, y0)
    counts1 = _counts_per_point(x1, y1)

    # Shared colour range across both panels. Pooling the counts keeps this
    # step from raising when one group is empty.
    pooled = np.concatenate([counts0, counts1])
    if pooled.size:
        vmin, vmax = int(pooled.min()), int(pooled.max())
    else:
        vmin, vmax = 0, 1
    if vmax == vmin:
        # Avoid a degenerate (zero-width) colour range.
        vmax = vmin + 1

    # One-sided tests: is y greater than x within each group?
    p0 = _wilcoxon_greater_p(y0, x0)
    p1 = _wilcoxon_greater_p(y1, x1)

    # Layout: left plot, right plot, narrow colorbar axis on the far right.
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 0.07], wspace=0.18)

    ax0 = fig.add_subplot(gs[0,0])
    ax1 = fig.add_subplot(gs[0,1], sharey=ax0)  # share y-axis with ax0
    cax = fig.add_subplot(gs[0,2])

    # Fixed plotting range, also used for the y = x line and shaded regions.
    xmin, xmax = 8, 64
    ymin_val, ymax_val = 8, 64
    xline = np.linspace(xmin, xmax, 500)

    # Scatter plots coloured by overlap count; sc0 later feeds the colorbar,
    # which is valid for both panels because they share vmin/vmax and cmap.
    sc0 = ax0.scatter(x0, y0, c=counts0, cmap=cmap, s=marker_size,
                      vmin=vmin, vmax=vmax, edgecolor='none', zorder=3)
    ax1.scatter(x1, y1, c=counts1, cmap=cmap, s=marker_size,
                vmin=vmin, vmax=vmax, edgecolor='none', zorder=3)

    forced_ticks = list(range(8, 65, 8))  # exactly [8, 16, ..., 64]
    for ax in (ax0, ax1):
        # y = x reference line
        ax.plot(xline, xline, **line_kwargs, zorder=2)

        # Shaded regions: green above the line, dark gray below it.
        ax.fill_between(xline, xline, ymax_val,
                        where=(ymax_val >= xline),
                        facecolor='green', alpha=0.18, interpolate=True, zorder=1)
        ax.fill_between(xline, ymin_val, xline,
                        where=(xline >= ymin_val),
                        facecolor='darkgray', alpha=0.36, interpolate=True, zorder=1)

    # Equal aspect and fixed limits.
    ax0.set_aspect('equal', adjustable='box')
    ax1.set_aspect('equal', adjustable='box')
    ax0.set_xlim(xmin, xmax); ax0.set_ylim(ymin_val, ymax_val)
    ax1.set_xlim(xmin, xmax); ax1.set_ylim(ymin_val, ymax_val)

    # Force identical ticks on both panels. ax1 shares y with ax0, so setting
    # its y ticks again is redundant but harmless.
    ax0.set_xticks(forced_ticks)
    ax0.set_yticks(forced_ticks)
    ax1.set_xticks(forced_ticks)
    ax1.set_yticks(forced_ticks)
    ax0.tick_params(axis='both', labelsize=fs)
    ax1.tick_params(axis='both', labelsize=fs)

    # Titles and labels; the y label appears on the left panel only.
    ax0.set_title('dense', fontsize=fs)
    ax1.set_title('sparse', fontsize=fs)
    ax0.set_xlabel('MASFR saturation time', fontsize=fs)
    ax1.set_xlabel('MASFR saturation time', fontsize=fs)
    ax0.set_ylabel('accuracy saturation time', fontsize=fs)
    ax1.set_ylabel('')

    # p-values in the bottom-right corner of each panel (when available).
    _annotate_p_value(ax0, p0, fs)
    _annotate_p_value(ax1, p1, fs)

    # Single shared colorbar on the rightmost axis.
    cb = fig.colorbar(sc0, cax=cax, orientation='vertical')
    cb.set_label('overlap count', rotation=270, labelpad=15, fontsize=fs)

    # Colorbar ticks: every integer for small ranges, otherwise up to six
    # evenly spaced integers (duplicates removed after rounding).
    if vmax - vmin <= 10:
        cb_ticks = list(range(vmin, vmax+1))
    else:
        cb_ticks = sorted({int(t) for t in np.round(np.linspace(vmin, vmax, num=6))})
    cb.set_ticks(cb_ticks)
    cax.tick_params(labelsize=fs)
    # Pin the same tick positions on the colorbar axis locator as well.
    cax.yaxis.set_major_locator(ticker.FixedLocator(cb_ticks))

    # Lay out this figure explicitly (not whatever pyplot considers current).
    fig.tight_layout()
    fig.savefig(save, dpi=dpi, bbox_inches='tight')
    plt.close(fig)


def violin(data, p, fs, save='./time_lag/violin_time_lag_diff.png', dpi=300):
    """Draw a violin plot of ``data`` and save it as an image.

    Parameters
    ----------
    data : array-like
        Dataset passed directly to ``Axes.violinplot``.
    p : float or None
        p-value written in a boxed annotation; None skips the annotation.
    fs : int
        Font size for the y label, tick labels and the annotation.
    save : str
        Output image path; its directory is created if it does not exist.
    dpi : int
        Resolution used when saving.
    """
    fig, ax = plt.subplots(figsize=(6,6))

    # Violin plot showing the mean and the extrema, but not the median.
    ax.violinplot(data, showmeans=True, showmedians=False, showextrema=True)

    if p is not None:
        ax.text(0.93, 0.8, f"p = {p:.3e}",
                transform=ax.transAxes,
                fontsize=fs,
                verticalalignment='top',
                horizontalalignment='right',
                bbox=dict(boxstyle='round,pad=0.4', facecolor='white', edgecolor='black', alpha=0.95),
                zorder=4)

    ax.set_ylabel('absolute time lag difference', fontsize=fs)
    ax.tick_params(labelsize=fs)

    # Lay out this figure explicitly rather than pyplot's current figure.
    fig.tight_layout()

    # Create the target directory first so the function also works when it
    # is called before anything else has created it.
    os.makedirs(os.path.dirname(save) or '.', exist_ok=True)
    fig.savefig(save, dpi=dpi)
    plt.close(fig)


def plot_histogram_lag_diff(lag_diff, p,
                             bins=30,
                             color='skyblue',
                             edgecolor='black',
                             xlabel='time lag difference',
                             ylabel='Count',
                             title='Distribution of time lag difference',
                             savepath='./time_lag/hist_lag_diff.png',
                             dpi=300):
    """
    Draw a histogram of ``lag_diff``, mark its mean, and save the figure.

    Parameters:
     - lag_diff: values to histogram. The original note describes them as
       (t_accs[1] - t_sfrs[1]) - (t_accs[0] - t_sfrs[0]); this function does
       not compute that itself, it plots whatever it is given.
     - p: p-value shown in a boxed annotation; None skips the annotation
     - bins: number of histogram bins
     - color, edgecolor: fill and edge colour of the bars
     - xlabel, ylabel, title: axis labels and title
     - savepath: output path (its directory is created if needed)
     - dpi: resolution of the saved image
    """
    average = np.mean(lag_diff)
    mean_caption = f"Mean = {average:.2f}"

    fig, ax = plt.subplots(figsize=(6,4))
    ax.hist(lag_diff, bins=bins, color=color, edgecolor=edgecolor)

    # Mean marker: dashed vertical line plus a caption in the top-right corner.
    ax.axvline(average, color='red', linestyle='--', linewidth=1.5, label=mean_caption)
    mean_box = dict(boxstyle='round,pad=0.4', facecolor='white', edgecolor='black', alpha=0.9)
    ax.text(0.95, 0.95, mean_caption,
            transform=ax.transAxes,
            ha='right', va='top',
            fontsize=10, color='red',
            bbox=mean_box)

    # Optional p-value caption, placed below the mean caption.
    if p is not None:
        p_box = dict(boxstyle='round,pad=0.4', facecolor='white', edgecolor='black', alpha=0.95)
        ax.text(0.93, 0.8, f"p = {p:.3e}",
                transform=ax.transAxes,
                fontsize=10,
                va='top', ha='right',
                bbox=p_box,
                zorder=4)

    # Labels and title in one call.
    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)

    # Write a single output file, creating its directory when necessary.
    target_dir = os.path.dirname(savepath) or '.'
    os.makedirs(target_dir, exist_ok=True)
    fig.savefig(savepath, dpi=dpi, bbox_inches='tight')
    plt.close(fig)

if __name__ == '__main__':
    # Load the two saturation-time tables; each is indexed by group (0 and 1).
    mat = loadmat('./t1_t2.mat')
    t_accs = mat['t_accs']
    t_sfrs = mat['t_sfrs']

    # Per-group time lag (accuracy minus SFR), then the between-group difference.
    # NOTE(review): if the .mat file stores unsigned integers these
    # subtractions would wrap around - confirm the stored dtype.
    time_lag0 = t_accs[0] - t_sfrs[0]
    time_lag1 = t_accs[1] - t_sfrs[1]
    lag_diff = time_lag1 - time_lag0

    # Paired two-sided Wilcoxon test between the two groups' time lags;
    # only the p-value is used.
    _, p_diff = wilcoxon(time_lag1, time_lag0)
    print(p_diff)

    plot_two_side_by_side_sharedy(t_sfrs, t_accs, save='./time_lag/time-lag.png',fs=18)
    violin(abs(lag_diff), p_diff, 18)
    # Alternative view of the same data (disabled):
    #plot_histogram_lag_diff(abs(lag_diff), p_diff)
