import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
from typing import List, Tuple, Optional

def _snm_res_path(
    architecture: str,
    dataset: str,
    bs: int,
    results_root: str,
) -> str:
    """Return the ``res.mat`` path for one (architecture, dataset, bs) run.

    The sparsity sub-folder is meant to stay consistent with ``find_best_bs``:
    ``s_0.99`` for ``'MLP'`` and ``conv_0.5/s_0.0`` for any other architecture.
    """
    if architecture == 'MLP':
        sparse_parts = ('s_0.99',)
    else:
        sparse_parts = ('conv_0.5', 's_0.0')
    return os.path.join(
        results_root, architecture, dataset, *sparse_parts, f'bs_{bs}', 'res.mat'
    )


def _load_snn_max(file_name: str, verbose: bool) -> Optional[float]:
    """Return ``max(snn_acc)`` stored in *file_name*, or ``None`` if unusable.

    A missing or unreadable file, a missing, empty or non-numeric ``snn_acc``,
    or a non-finite maximum all mean "no data for this batch size". The run is
    skipped, with a message printed when *verbose* is true.
    """
    if not os.path.exists(file_name):
        if verbose:
            print(f"[跳过] 文件不存在: {file_name}")
        return None
    try:
        res = loadmat(file_name)
    except Exception as e:  # best-effort: loadmat can fail in many ways on foreign files
        if verbose:
            print(f"[错误] 加载 {file_name} 时出错: {e}")
        return None

    if 'snn_acc' not in res:
        if verbose:
            print(f"[跳过] 'snn_acc' 不在 {file_name}")
        return None

    snn_acc = np.asarray(res['snn_acc']).ravel()
    try:
        # Empty arrays raise ValueError; non-numeric contents raise ValueError/TypeError.
        max_snn = float(np.max(snn_acc))
    except (ValueError, TypeError) as e:
        if verbose:
            print(f"[错误] 计算 max_snn 失败: {file_name}, {e}")
        return None

    # np.max propagates NaN, and np.argmax would then pick the NaN as "best".
    if not np.isfinite(max_snn):
        if verbose:
            print(f"[跳过] snn_max 不是有限值: {file_name}")
        return None
    return max_snn


def ablation_SNM(
    architecture: str,
    dataset: str,
    bs_list: List[int],
    save_path: Optional[str] = None,
    verbose: bool = False,
    results_root: str = os.path.join('..', 'SNM', 'results'),
) -> Optional[Tuple[int, float]]:
    """Plot ``snn_max`` against batch size and highlight the best batch size.

    For every batch size in *bs_list*, ``res.mat`` is loaded from
    ``<results_root>/<architecture>/<dataset>/<sparsity>/bs_<bs>/``. The
    maximum of its ``snn_acc`` array is taken as that run's ``snn_max``. Runs
    whose file is missing or unusable are skipped. The remaining points are
    drawn in ascending batch-size order on a log-scaled x axis. The batch size
    with the highest ``snn_max`` is marked in red; ties go to the smallest
    batch size.

    Args:
        architecture: Model name, e.g. ``'MLP'`` or ``'VGG-16'``. ``'MLP'``
            selects the ``s_0.99`` sparsity folder; anything else selects
            ``conv_0.5/s_0.0``.
        dataset: Dataset name, e.g. ``'CIFAR10'``.
        bs_list: Batch sizes to look up, in any order.
        save_path: If given, the figure is saved there and closed. Parent
            directories are created when the path has any. If not given, the
            figure is shown with ``plt.show()``.
        verbose: Print per-run values and the reason each run is skipped.
        results_root: Root of the SNM results tree. Defaults to the
            historical ``../SNM/results``, relative to the working directory.

    Returns:
        ``(best_bs, best_snn_max)`` when at least one run was found.
        Otherwise prints a notice and returns ``None``.
    """
    points = []
    for bs in bs_list:
        file_name = _snm_res_path(architecture, dataset, bs, results_root)
        max_snn = _load_snn_max(file_name, verbose)
        if max_snn is None:
            continue
        if verbose:
            print(f"bs={bs}, snn_max={max_snn:.6f}")
        points.append((bs, max_snn))

    if not points:
        print("没有任何有效数据可画图")
        return None

    # Sort by batch size so the line runs left-to-right even for an unordered bs_list.
    points.sort(key=lambda p: p[0])
    bs_vals = [bs for bs, _ in points]
    snn_max_vals = [v for _, v in points]

    idx_best = int(np.argmax(snn_max_vals))
    best_bs = bs_vals[idx_best]
    best_snn_max = snn_max_vals[idx_best]

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(bs_vals, snn_max_vals, marker='o', linestyle='-', label='snn_max')
    ax.scatter([best_bs], [best_snn_max], color='red', s=100, label=f'best bs = {best_bs}')
    ax.set_xscale('log')
    ax.set_xlabel("Batch size (bs)")
    ax.set_ylabel("snn_max")
    ax.set_title(f"snn_max vs bs (arch={architecture}, dataset={dataset})")
    ax.grid(True)
    ax.legend()

    if save_path:
        save_dir = os.path.dirname(save_path)
        # A bare filename has an empty dirname; os.makedirs('') would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path)
        if verbose:
            print(f"图已保存到 {save_path}")
        plt.close(fig)
    else:
        plt.show()

    return best_bs, best_snn_max

if __name__ == '__main__':
    # Sweep one fixed set of batch sizes for every (architecture, dataset)
    # pair and save one plot per pair under ../graphs/<arch>/<dataset>/.
    sweep_bs = [32, 64, 128, 256, 512, 5000]
    jobs = [
        (arch_name, ds_name)
        for arch_name in ['MLP', 'VGG-16']
        for ds_name in ['CIFAR10', 'CIFAR100']
    ]
    for arch_name, ds_name in jobs:
        out_png = os.path.join('../graphs', arch_name, ds_name, 'ab_SNM bs.png')
        ablation_SNM(arch_name, ds_name, sweep_bs, save_path=out_png, verbose=True)