import matplotlib
matplotlib.use('Agg', force=True)
from matplotlib import pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import os
from adjustText import adjust_text
import pandas as pd
import sys
sys.path.append("../")
from typing import Dict
from latency import find_latency

colors=['black','red']
default_fs = plt.rcParams['font.size']

def snn_t_12(ax, sparsity_range , xs_all, snn_data, baseline, max_T=64, title=None, fs_scaler=3):
    """
    ax: matplotlib Axes 对象，由主函数传入
    points: (np.array, Dict)  -> (xs, snn_data)
    baseline: Dict
    max_T: 最大 T，用于截断或设定 x 范围. -1代表全部
    其余与原来参数相同
    """
    fs = fs_scaler * default_fs

    # 如果 max_T 有效则截断 xs_all 的视图（不改动 xs_all 本身）
    if max_T is not None and max_T > 0:
        xs_view = [x for x in xs_all if x <= max_T]
    else:
        xs_view = list(xs_all)

    texts_to_adjust = []   # 将送去 adjust_text 的 text objects
    add_objects = []       # 将作为障碍物传给 adjust_text（线、点等）

    maxs = []              # 用于计算 y 上限
    for i, sparsity in enumerate(sparsity_range):
        label_plot = 'dense SNN' if sparsity == 0.0 else 'sparse SNN'
        label_hline = 'dense input ANN with ReLU' if sparsity == 0.0 else 'sparse input ANN with ReLU'
        color = colors[i]

        # snn 数据并截断到 xs_view 长度
        snn = (snn_data[sparsity].flatten() * 100)
        snn_plot = snn[:len(xs_view)]

        # 画曲线
        line_curve, = ax.plot(xs_view, snn_plot, color=color, label=label_plot, zorder=5, linewidth=3)
        add_objects.append(line_curve)

        # 画 baseline 的虚线（axhline 返回 Line2D）
        baseline_h = baseline[sparsity] * 100
        line_base = ax.axhline(y=baseline_h, color=color, linestyle='--',
                               linewidth=2, label=label_hline, zorder=4)
        add_objects.append(line_base)

        # --- 修改虚线 dash 为原来两倍长度（dash list * 2） ---
        try:
            offset, dashseq = line_base.get_dashes()  # (offset, [dash, gap, dash, gap...])
            if not dashseq:
                dashseq = [6.0, 6.0]
            dashseq2 = [d * 2.0 for d in dashseq]
            line_base.set_dashes(dashseq2)
        except Exception:
            line_base.set_dashes([12.0, 12.0])

        # --- (3) 在靠近 y 轴处写出虚线高度（保留两位小数），字体大小 fs ---
        # 取一个靠近 y 轴的 x 值（使用 xs_view 的起点 + 小偏移）
        x_left_label = -15
        t_left = ax.text(x_left_label, baseline_h, f'{baseline_h:.2f}',
                         va='center', ha='left', color=color,
                         fontsize=fs, bbox=dict(facecolor='white', edgecolor='none', pad=0),
                         zorder=0)  # 放到底层
        texts_to_adjust.append(t_left)

        # --- (4) 标出曲线的latency（并用彩色 marker 标出） ---
        idx_max = int(np.argmax(snn_plot))
        y_max = float(snn_plot[idx_max])
        latency, idx = find_latency(xs_view, snn_plot, window=10)
        assert idx
        y_lat = float(snn_plot[idx])
        pt_marker, = ax.plot([latency], [y_lat], marker='o', markersize=15, linestyle='',
                            color=color, zorder=6)
        add_objects.append(pt_marker)
    

        # --- (5) 在全图右侧标记两条曲线的最大值（与曲线颜色一致） ---
        x_right = xs_view[-1] + 2
        t_right = ax.text(x_right, y_max, f'{y_max:.2f}',
                          va='center', ha='left', color=color,
                          fontsize=fs, zorder=0, bbox=dict(facecolor='white', edgecolor='none', pad=0))
        texts_to_adjust.append(t_right)

        # 收藏 max 值和 baseline 用于后续 ylim 计算
        maxs.append(np.max(snn_plot))
        maxs.append(baseline_h)

    # 设置 x,y 范围（注意 xs_view 最后一个点）
    x_right_limit = xs_view[-1] + 16
    ax.set_xlim(-15, x_right_limit)
    ax.set_ylim(top=(int(max(maxs) / 10) + 1) * 10)

    # --- 坐标轴刻度 / 格式化（保留你原来的设置） ---
    major_step = 10
    major_ticks = np.arange(0, max_T + 1, major_step)

    if len(major_ticks) >= 2:
        minor_ticks = major_ticks[:-1] + major_step / 2.0
    else:
        minor_ticks = np.array([])

    # 用 FixedLocator 固定主/次刻度的位置（避免在 xlim 之外出现刻度）
    ax.xaxis.set_major_locator(ticker.FixedLocator(major_ticks))
    ax.xaxis.set_minor_locator(ticker.FixedLocator(minor_ticks))

    # 强制刻度标签（主刻度）从 0 开始（负数不会出现）
    ax.set_xticks(major_ticks)           # 可选，因为 FixedLocator 已经生效
    ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
    ax.xaxis.set_minor_formatter(ticker.NullFormatter())

    # --- Y 轴：保留你原来的设定（如果你仍然希望用 MultipleLocator / AutoMinorLocator）---
    ax.yaxis.set_major_locator(ticker.MultipleLocator(10))
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))
    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f'))
    ax.yaxis.set_minor_formatter(ticker.NullFormatter())

    # Tick 外观（保留你原来的设置）
    ax.tick_params(axis='y', which='major', length=12, labelsize=fs)
    ax.tick_params(axis='y', which='minor', length=6, labelsize=0)
    ax.tick_params(axis='x', which='major', length=12, labelsize=fs)
    ax.tick_params(axis='x', which='minor', length=6)

    ax.grid(False)

    ax.set_xlabel("Inference time steps", fontsize=fs)
    ax.set_title(title, fontsize=fs)

    # --- 使用 adjustText 只在 y 方向移动文本，并避免与已有对象碰撞 ---
    # 将所有需要被考虑的对象（线、点）传入 add_objects
    # only_move 尽量把所有可能的 key 都限制为 'y'
    only_move = {'text': 'y', 'points': 'y', 'objects': 'y',
                    'static': 'y', 'explode': 'y', 'pull': 'y'}
    adjust_text(texts_to_adjust,
                add_objects=add_objects,
                only_move=only_move,
                autoalign='y',
                expand_text=(1.02, 1.02),
                expand_points=(1.02, 1.02),
                ax=ax)
    # 确保所有文本在最下图层（zorder 非常小）
    for tt in texts_to_adjust:
        tt.set_zorder(0)


#如果是cnn，会很显杂乱，或许可以放进附录，然后展示MLP的. cnn展示bar？
def LASFR_t(
    ax: plt.Axes,
    title: str,
    points: Dict[float, np.ndarray],
    sparsity_range,
    max_T=64):

    """
    在给定的 ax 上绘制 LASFR 曲线（不读取文件）。
    参数:
      ax: matplotlib Axes（由主调创建并传入）
      title: 标题字符串（由主调提供）
      points: (np.array, dict) dict: key=sparsity (float) -> value = np.ndarray shape (n_rows, n_cols)
                  每个 value 的列数为时间步长度（不一定是64）
      sparsity_range: 列表/可迭代稀疏度顺序（用于配色顺序）
      get_color_fn: 函数 (i, j, n) -> color
      default_fs: 基准字体大小（例如 default_fs）
    返回:
      texts: list of matplotlib.text.Text 对象（主调用于传入 adjust_text）
    """
    fs = 1.5 * default_fs
    texts = []
    maxs = []
    xs, lasfr_data = points

    # 绘每个 sparsity 对应的 LASFR 矩阵
    for i, sparsity in enumerate(sparsity_range):
        arr = lasfr_data.get(sparsity, None)
        if arr is None:
            # 如果没有数据，就跳过
            continue
        LASFR = np.asarray(arr)
        if LASFR.ndim != 2:
            raise ValueError(f"LASFR for sparsity {sparsity} must be 2D array (n_rows, n_cols)")

        n_rows, n_cols = LASFR.shape

        for j in range(n_rows):
            color = get_color(i, j, n_rows)
            y = LASFR[j, :max_T]

            ax.plot(xs, y, color=color, label=None)
            #last = float(y[-1])

            # 在最右侧（最后一个 x）处写值
            #x_text = xs[-1]
            #txt = ax.text(x_text, last, f'{last:.4f}',
                          #color=color, ha='left', va='center', fontsize=fs)
            #texts.append(txt)
            #maxs.append(last)


    ax.yaxis.set_major_locator(ticker.AutoLocator())
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))
    ax.tick_params(axis='y', which='major', length=6, labelsize=fs)
    ax.tick_params(axis='y', which='minor', length=3)
    ax.yaxis.set_minor_formatter(ticker.NullFormatter())

    # x 轴：只标到最后（但留白）
    ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=8, integer=True))
    ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(2))
    ax.tick_params(axis='x', which='major', length=6, labelsize=fs)
    ax.tick_params(axis='x', which='minor', length=3)
    ax.xaxis.set_minor_formatter(ticker.NullFormatter())

    # xlim 保留一定留白（右侧多 20% 或固定 +20）
    ax.set_xlim(xmax=len(y)+16)
    ax.set_ylim(top=(int(max(maxs) * 100) + 1) / 100)

    ax.set_xlabel("Time steps", fontsize=fs)
    ax.set_title(title, fontsize=fs)
    ax.legend(fontsize=fs)  # 如果你不希望 legend，可删除或由主调控制
    ax.grid(True)

    return texts


def LASFR_bar(
    ax: plt.Axes,
    title: str,
    lasfr_data: dict,
    sparsity_range: list,
    time_index: int = 32,
    layer_labels: list = None
):
    """
    在给定 ax 上画组柱状图，横轴为 j（layer index 等），纵轴为 LASFR 在指定 time_index 的值。

    参数:
      ax: matplotlib Axes 对象，由主调创建并传入
      title: 图表标题
      lasfr_data: dict，键是稀疏度值（例如 0.0, 0.5 等），值是 LASFR 数组，形状 (n_rows, n_cols)，n_cols >= time_index+1
      sparsity_range: list 或 可迭代的稀疏度顺序（保持一致用于配色顺序与 legend 标记）
      time_index: int，选择 LASFR 的哪一列（用零基 indexing），例如原来 time=63 意味着第 64 列的值 => time_index=63
      layer_labels: 可选，用于在 x 轴上显示每个 j 的 label；如果 None，就用数字 “1,2,3,...”

    返回:
      None（你也可以让它返回绘图的一些艺术家 list，如果需要以后调节）
    """
    fs = 1.5 * default_fs

    # 准备 ys: shape (m, n_rows)，m = len(sparsity_range)
    ys_list = []  # list of arrays
    n_rows_list = []
    for sparsity in sparsity_range:
        arr = lasfr_data.get(sparsity)
        if arr is None:
            raise ValueError(f"lasfr_data 中不包含 sparsity={sparsity}")
        LASFR = np.asarray(arr)
        if LASFR.ndim != 2:
            raise ValueError(f"LASFR for sparsity {sparsity} must be a 2D array, got shape {LASFR.shape}")
        n_rows, n_cols = LASFR.shape
        if time_index >= n_cols:
            raise IndexError(f"time_index={time_index} >= number of columns {n_cols} for sparsity {sparsity}")
        # 提取每行在 time_index 的值
        y = LASFR[:, time_index].astype(float)
        ys_list.append(y)
        n_rows_list.append(n_rows)

    # 检查是否所有 sparsity 的 LASFR 行数相同，否则绘图会错位
    if not all(n == n_rows_list[0] for n in n_rows_list):
        raise ValueError(f"不同稀疏度的 LASFR 行数不一致：{n_rows_list}")
    n = n_rows_list[0]  # 行数

    ys = np.stack(ys_list, axis=0)  # shape (m, n)

    m = ys.shape[0]  # number of sparsity variants
    # 绘组柱状图参数
    bar_width = 0.35
    gap = 0.25  # 组间间隔
    group_width = m * bar_width + gap
    group_centers = np.arange(n) * group_width

    # 绘制每一稀疏度的一组柱子
    for i in range(m):
        # i 对应 sparsity_index = i
        sparsity = sparsity_range[i]
        y_vals = ys[i, :]
        x_positions = group_centers + i * bar_width
        # 颜色列表
        colors = [get_color(i, j, n) for j in range(n)]
        # 标签
        label = 'dense SNN' if sparsity == 0.0 else f"sparse SNN"
        ax.bar(x_positions, y_vals, width=bar_width, label=label, color=colors, align='edge')

    # x 轴刻度
    tick_positions = group_centers + ((m - 1) * bar_width) / 2.0
    if layer_labels is None:
        tick_labels = [str(j + 1) for j in range(n)]
    else:
        if len(layer_labels) != n:
            raise ValueError("layer_labels 长度必须 == LASFR 行数 (n)")
        tick_labels = layer_labels

    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, fontsize=fs * 0.8)

    # 标题和轴标签
    ax.set_title(title, fontsize=fs)
    ax.set_xlabel('layer index', fontsize=fs * 0.9)
    ax.set_ylabel(f'LASFR (t={time_index+1})', fontsize=fs * 0.9)
    ax.legend(fontsize=fs * 0.8)
    ax.grid(axis='y', linestyle='--', alpha=0.4)



def snn_t_3(ax, sparsity_range , xs_all, snn_data, baseline, max_T=64, title=None, fs_scaler=3):
    """
    ax: matplotlib Axes 对象，由主函数传入
    points: (np.array, Dict)  -> (xs, snn_data)
    baseline: Dict
    max_T: 最大 T，用于截断或设定 x 范围. -1代表全部
    其余与原来参数相同
    """
    fs = fs_scaler * default_fs

    # 如果 max_T 有效则截断 xs_all 的视图（不改动 xs_all 本身）
    if max_T is not None and max_T > 0:
        xs_view = [x for x in xs_all if x <= max_T]
    else:
        xs_view = list(xs_all)

    texts_to_adjust = []   # 将送去 adjust_text 的 text objects
    add_objects = []       # 将作为障碍物传给 adjust_text（线、点等）

    maxs = []              # 用于计算 y 上限
    for i, sparsity in enumerate(sparsity_range):
        label_plot = 'dense SNN' if sparsity == 0.0 else 'sparse SNN'
        label_hline = 'dense input ANN with ReLU' if sparsity == 0.0 else 'sparse input ANN with ReLU'
        color = colors[i]

        # snn 数据并截断到 xs_view 长度
        snn = (snn_data[sparsity].flatten() * 100)
        snn_plot = snn[:len(xs_view)]

        # 画曲线
        line_curve, = ax.plot(xs_view, snn_plot, color=color, marker='o', markersize=8, label=label_plot, zorder=5, linewidth=3)
        add_objects.append(line_curve)

        # 画 baseline 的虚线（axhline 返回 Line2D）
        baseline_h = baseline[sparsity] * 100
        line_base = ax.axhline(y=baseline_h, color=color, linestyle='--',
                               linewidth=2, label=label_hline, zorder=4)
        add_objects.append(line_base)

        # --- 修改虚线 dash 为原来两倍长度（dash list * 2） ---
        try:
            offset, dashseq = line_base.get_dashes()  # (offset, [dash, gap, dash, gap...])
            if not dashseq:
                dashseq = [6.0, 6.0]
            dashseq2 = [d * 2.0 for d in dashseq]
            line_base.set_dashes(dashseq2)
        except Exception:
            line_base.set_dashes([12.0, 12.0])

        # --- (3) 在靠近 y 轴处写出虚线高度（保留两位小数），字体大小 fs ---
        # 取一个靠近 y 轴的 x 值（使用 xs_view 的起点 + 小偏移）
        x_left_label = -15
        t_left = ax.text(x_left_label, baseline_h, f'{baseline_h:.2f}',
                         va='center', ha='left', color=color,
                         fontsize=fs, bbox=dict(facecolor='white', edgecolor='none', pad=0),
                         zorder=0)  # 放到底层
        texts_to_adjust.append(t_left)

        # --- (4) 标出曲线的最大值（并用彩色 marker 标出） ---
        idx_max = int(np.argmax(snn_plot))
        y_max = float(snn_plot[idx_max])
        x_m = xs_view[idx_max]
        pt_marker, = ax.plot([x_m], [y_max], marker='o', markersize=15, linestyle='',
                            color=color, zorder=6)
        add_objects.append(pt_marker)

        # --- (5) 在全图右侧标记两条曲线的最大值（与曲线颜色一致） ---
        x_right = xs_view[-1] + 2
        t_right = ax.text(x_right, y_max, f'{y_max:.2f}',
                          va='center', ha='left', color=color,
                          fontsize=fs, zorder=0, bbox=dict(facecolor='white', edgecolor='none', pad=0))
        texts_to_adjust.append(t_right)

        # 收藏 max 值和 baseline 用于后续 ylim 计算
        maxs.append(np.max(snn_plot))
        maxs.append(baseline_h)

    # 设置 x,y 范围（注意 xs_view 最后一个点）
    x_right_limit = xs_view[-1] + 16
    ax.set_xlim(-15, x_right_limit)
    ax.set_ylim(top=(int(max(maxs) / 10) + 1) * 10)

    ax.xaxis.set_major_locator(ticker.FixedLocator(xs_view))  
    ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
    ax.xaxis.set_minor_locator(ticker.NullLocator())          
    ax.set_xticks(xs_view)                                   

    # --- Y 轴：保留你原来的设定（如果你仍然希望用 MultipleLocator / AutoMinorLocator）---
    ax.yaxis.set_major_locator(ticker.MultipleLocator(10))
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))
    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f'))
    ax.yaxis.set_minor_formatter(ticker.NullFormatter())

    # Tick 外观（保留你原来的设置）
    ax.tick_params(axis='y', which='major', length=12, labelsize=fs)
    ax.tick_params(axis='y', which='minor', length=6, labelsize=0)
    ax.tick_params(axis='x', which='major', length=12, labelsize=fs)

    ax.grid(False)

    ax.set_xlabel("Inference time steps", fontsize=fs)
    ax.set_title(title, fontsize=fs)

    # --- 使用 adjustText 只在 y 方向移动文本，并避免与已有对象碰撞 ---
    # 将所有需要被考虑的对象（线、点）传入 add_objects
    # only_move 尽量把所有可能的 key 都限制为 'y'
    only_move = {'text': 'y', 'points': 'y', 'objects': 'y',
                    'static': 'y', 'explode': 'y', 'pull': 'y'}
    adjust_text(texts_to_adjust,
                add_objects=add_objects,
                only_move=only_move,
                autoalign='y',
                expand_text=(1.02, 1.02),
                expand_points=(1.02, 1.02),
                ax=ax)
    # 确保所有文本在最下图层（zorder 非常小）
    for tt in texts_to_adjust:
        tt.set_zorder(0)






#--------------------------------------辅助函数: 检查,标签,数据,颜色--------------------------------------#
#除了get_color都废弃了，但是可以供integrate参考。
def get_color(sparse:bool,i:int, n:int, min_gray: float = 0.8,min_orange: tuple=(1.0, 0.5, 0.0)):
    if not sparse:
        t = i / (n - 1)
        gray = t * min_gray  # 从 0（纯黑）到 min_gray（较浅灰）
        return f"{gray:.3f}"
    else:
        t = i / (n - 1)
        r = 1.0 * (1 - t) + min_orange[0] * t #min_orange是典型的橙色
        g = 0.0 * (1 - t) + min_orange[1] * t
        b = 0.0 * (1 - t) + min_orange[2] * t
        return (r, g, b)

def check_qcfs_smallerthan_relu(acc_l,baseline):
    for i in range(len(acc_l)):
        if acc_l[i]>baseline:
            print(f'L={i+1}, accuracy with qcfs exceed relu.')


'''
def ANN_L(): #input
    plt.figure()
    title="ANN accuracy (with CS-QCFS) - L"
    texts=[]
    for i, sparsity in enumerate(sparsity_range):
        label_plot='dense ANN with CS-QCFS' if sparsity==0.0 else f"sparse({sparsity}) ANN with CS-QCFS"
        label_hline='dense input ANN with ReLU' if sparsity==0.0 else f"sparse({sparsity}) input ANN with ReLU"
        path=os.path.join(basic_input_path,f"s_{sparsity}","res.mat")
        acc = loadmat(path)['acc-L'].flatten()*100
        color=colors[i]
        plt.plot(range(1, len(acc) + 1), acc, color=color,label= label_plot)
        
        # 添加水平虚线在右侧标注数值
        baseline_h=baseline[i]
        check_qcfs_smallerthan_relu(acc,baseline_h)
        plt.axhline(y= baseline_h, color=color, linestyle='--', linewidth=1,label=label_hline)
        xmax =len(acc)
        t=plt.text(xmax,baseline_h, f'{baseline_h:.2f}', va='center', ha='left', color=color,
                bbox=dict(facecolor='white', edgecolor='none', pad=0))
        texts.append(t)
    adjust_text(texts, only_move={'text':'y'})
    
    plt.xlabel("L")
    plt.ylabel("ANN Accuracy (%)")
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_path,f'{title}.png'))
    plt.close()

def ft_acc(sparsity):
    labels = ['Best acc', 'Best AAE', 'Best snn_acc']
    colors = ['r', 'b', 'k']
    title= 'Test accuracy during training('+('dense' if sparsity==0.0 else f"sparse({sparsity})") + ')'
    plt.figure()

    for i, key in enumerate(['max_acc', 'AAE', 'snn_max']):
        exp = best_experiments[sparsity][key]
        data=loadmat(filename(*exp))['ft_acc']
        ft = list((data*100).flatten())
        plt.plot(range(len(ft)), ft, color=colors[i], label=labels[i]+' '+label(*(exp[1:])))

    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.title(title)
    plt.legend()
    plt.grid(False)
    plt.tight_layout()
    plt.savefig(os.path.join(save_path,f'{title}.png'))
    plt.close()


def LASFR_bar(args):
    labels=[]
    values=[]
    colors=[]
    for i, sparsity in enumerate(sparsity_range):
        exp = best_experiments[sparsity]['snn_max']
        LASFR = loadmat(filename(*exp))['LASFR']
        n=LASFR.shape[0]
        for j in range(n):
            #labels.append(('dense' if sparsity==0.0 else f"sparse({sparsity})")+" "+f'layer{j+1}')
            labels.append(str(j+1))
            values.append(LASFR[j][-1])
            colors.append(get_color(i,j,n))
    x = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(8, 6), layout='constrained')
    bars = ax.bar(x, values, color=colors)

    # 设置 X 轴
    ax.set_xticks(x)
    ax.set_xticklabels(labels, ha='right', fontsize=fs)
    ax.set_xlabel('layer index', fontsize=fs)
    # Y 轴字体
    ax.tick_params(axis='y', labelsize=fs)
    # 设定 主刻度 + 每两个主刻度之间一个小刻度
    ax.yaxis.set_major_locator(ticker.AutoLocator())
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))  # 每两个 major 间一个 minor 
    ax.tick_params(axis='y', which='major', length=6)
    ax.tick_params(axis='y', which='minor', length=3)
    ax.yaxis.set_minor_formatter(ticker.NullFormatter())

    # 添加 bar 上文字，并自动扩展边距以包容文字
    # matplotlib ≥3.7 推荐使用 bar_label
    ax.bar_label(bars, fmt='%.4f', padding=3, fontsize=fs)  # 会自动调整 axis limits 
    ax.margins(y=0.10)

    title = f'{args.architecture} on {args.dataset}'
    ax.set_title(title, fontsize=fs)
    # 最后再应用 tight_layout 保证整体不裁剪
    fig.tight_layout()
    # 保存
    plt.savefig(os.path.join(save_path, f'{title}.png'))
    plt.close()

def bar_compare():
    # 定义数据
    labels = ['dense input ANN', 'dense trained ANN - CS-QCFS', 'dense trained ANN - ReLU','dense SNN' ,'sparse input ANN', 'sparse trained ANN - CS-QCFS', 'sparse trained ANN - ReLU','sparse SNN']
    snn_max=get_snn_max()
    ann_qcfs=get_ann_qcfs()
    ann_relu=get_ann_relu()
    values = [baseline[0],ann_qcfs[0],ann_relu[0],snn_max[0],baseline[1],ann_qcfs[1],ann_relu[1],snn_max[1]]
    values=np.array(values)
    colors = ['black','black','black','black','red','red','red','red']

    #save_to_xlsx(dataset,labels,values)

    x = np.arange(len(labels))
    plt.figure(figsize=(8, 5))
    bars = plt.bar(x, values, color=colors)

    plt.xticks(x, labels, rotation=30, ha='right')  # 旋转30°避免重叠 :contentReference[oaicite:1]{index=1}

    ylim=int(min(values)/10)*10 
    plt.ylabel('accuracy(%)')
    plt.ylim(ylim, max(values) + 5)

    for bar in bars:
        height = bar.get_height()
        plt.annotate(f'{height:.2f}',
                     xy=(bar.get_x() + bar.get_width() / 2, height),
                     xytext=(0, 3),  # 标签偏移
                     textcoords='offset points',
                     ha='center', va='bottom')

    title='Accuracy Comparison'
    plt.title(title)
    plt.tight_layout()
    plt.savefig(os.path.join(save_path,f'{title}.png'))
    plt.close()

    
#cannot be used 
def ablation_QCFS(Hpara_name:str,rng, log:bool ):
    fs=1.5*default_fs
    plt.figure()
    title=f"Ablation study of {Hpara_name}"
    for i,spar in enumerate(sparsity_range):
        data=[]
        global_best=list(best_experiments[spar]['snn_max'])

        match Hpara_name:
            case 'lr': var,label=1, f"bs_{global_best[2]}/l_{global_best[3]}"
            case 'bs': var,label=2, f"lr_{global_best[1]}/l_{global_best[3]}"
            case 'l' : var,label=3, f"lr_{global_best[1]}/bs_{global_best[2]}"
            case _   : raise ValueError('invalid hyper parameter name')
        for x in rng:
            global_best[var]=x
            try:
                acc = max(loadmat(filename(*global_best))['snn_acc'].flatten())*100
            except FileNotFoundError:
                acc=np.nan
            except KeyError:
                acc=np.nan
            data.append(acc)
        color=colors[i]
        plt.plot(rng, data, color=color, label=f"s_{spar}/"+label)

    plt.xlabel(Hpara_name,fontsize=fs)
    if log:
        plt.xscale('log')
    plt.xticks(rng, [str(x) for x in rng],fontsize=fs)
    plt.yticks(fontsize=fs)
    plt.ylabel("SNN Best Accuracy (%)",fontsize=fs)
    plt.title(title,fontsize=fs)
    plt.legend(fontsize=fs)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_path,f'{title}.png'))
    plt.close()

'''












