import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
import sys
sys.path.append('..')
from FigureLib import default_fs, snn_t_12, snn_t_3
from get_data import get_baseline, get_snn_data_QCFS, get_snn_data_SNM, get_snn_data_AEC

# Base font size for scaled text. NOTE(review): this rebinds the `default_fs`
# name imported from FigureLib with matplotlib's current rcParams value —
# confirm the override is intended.
default_fs = plt.rcParams['font.size']

# Grid layout: one row per (architecture, dataset) pair, one column per method.
architectures = ['MLP', 'VGG-16']
datasets = ['CIFAR10', 'CIFAR100']
methods = ['QCFS', 'SNM', 'AEC']

# Nested placeholder keyed as latency_dict[arch][dataset][method] -> {}.
# NOTE(review): not referenced by the plotting code in this file.
latency_dict = {
    arch_name: {
        ds_name: {method_name: {} for method_name in methods}
        for ds_name in datasets
    }
    for arch_name in architectures
}

# The FigureLib helpers snn_t_12 / snn_t_3 are called below as
# f(ax, sparsity_range=..., xs_all=..., snn_data=..., baseline=..., max_T=..., title=..., fs_scaler=...).
def plot_snn_t_grid(fs_scaler, *, max_T=64, save_dir=os.path.join('..', 'graphs'),
                    filename='snn_t_grid_newlat.png'):
    """Plot SNN accuracy-vs-timestep curves on a grid and save it as a PNG.

    Rows enumerate every (architecture, dataset) pair from ``architectures``
    x ``datasets`` (architecture-major order); columns enumerate ``methods``.
    Each cell is drawn by ``snn_t_3`` for the AEC method and by ``snn_t_12``
    for every other method.

    Args:
        fs_scaler: Multiplier applied to ``default_fs`` for the shared legend
            font size; also forwarded to the per-cell plotting helpers.
        max_T: Forwarded unchanged to the plotting helpers as ``max_T``.
        save_dir: Output directory, created if missing. The default is
            resolved relative to the current working directory.
        filename: Name of the PNG written inside ``save_dir``.

    Returns:
        The path of the saved image.

    Raises:
        ValueError: If ``methods`` contains a name with no data loader.
    """
    # Per-method data loaders; each is called as f(arch, ds, range, 'snn_acc')
    # and unpacked into (xs, snn_data).
    loaders = {
        'QCFS': get_snn_data_QCFS,
        'SNM': get_snn_data_SNM,
        'AEC': get_snn_data_AEC,
    }
    # Validate up front: an unknown method would otherwise reuse stale data
    # from the previous cell (or hit an unbound local on the first one).
    unknown = [m for m in methods if m not in loaders]
    if unknown:
        raise ValueError(f"No SNN data loader for method(s): {unknown}")

    nrows = len(architectures) * len(datasets)
    ncols = len(methods)
    # squeeze=False keeps `axs` 2-D even with a single row or column, so the
    # axs[row, col] indexing below always works.
    fig, axs = plt.subplots(nrows, ncols, figsize=(ncols * 10, nrows * 8),
                            squeeze=False)
    try:
        for i_arch, arch in enumerate(architectures):
            # VGG-16 uses a narrower sparsity range than the other architectures.
            sparsity_range = (0.0, 0.5) if arch == 'VGG-16' else (0.0, 0.99)
            for j_ds, ds in enumerate(datasets):
                row_idx = i_arch * len(datasets) + j_ds
                for k_m, method in enumerate(methods):
                    ax = axs[row_idx, k_m]
                    baseline = get_baseline(arch, ds, sparsity_range, method)
                    xs, snn_data = loaders[method](arch, ds, sparsity_range, 'snn_acc')
                    plot_cell = snn_t_3 if method == 'AEC' else snn_t_12
                    plot_cell(ax, sparsity_range=sparsity_range, xs_all=xs,
                              snn_data=snn_data, baseline=baseline, max_T=max_T,
                              title=f"{arch}-{ds}-method{k_m+1}", fs_scaler=fs_scaler)

        # Shared figure legend: one entry per distinct label, in first-seen
        # (row-major) order; the first handle for each label wins.
        legend_entries = {}
        for ax in axs.flat:
            for handle, label in zip(*ax.get_legend_handles_labels()):
                legend_entries.setdefault(label, handle)
        if legend_entries:  # avoid fig.legend(ncol=0) when nothing is labelled
            fig.legend(list(legend_entries.values()), list(legend_entries.keys()),
                       loc='upper center', ncol=len(legend_entries),
                       fontsize=fs_scaler * default_fs)

        # Leave headroom at the top (top=0.93) for the shared legend.
        fig.subplots_adjust(left=0, right=1, bottom=0, top=0.93,
                            wspace=0.2, hspace=0.35)
        os.makedirs(save_dir, exist_ok=True)
        out_path = os.path.join(save_dir, filename)
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if loading or plotting fails.
        plt.close(fig)
    print(f"Saved combined plot to {out_path}")
    return out_path

if __name__ == '__main__':
    # Script entry point. fs_scaler is the font-size multiplier applied to
    # default_fs for the legend and forwarded to the plotting helpers.
    font_scale = 3
    plot_snn_t_grid(fs_scaler=font_scale)
