import matplotlib
matplotlib.use('Agg', force=True)
from matplotlib import pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from scipy.io import loadmat
import os
from adjustText import adjust_text
import pandas as pd
import sys
sys.path.append("..")
from exp_comparison import grid_best_AEC

# Module-level plotting defaults.
colors = ['black', 'red']  # default line colours
conv_sparsity_range = 0.0, 0.5  # default conv sparsity levels (0.0 = dense)

default_fs = plt.rcParams["font.size"]  # matplotlib's configured base font size


def ablation_AEC(ds, param_name, param_range, log=False, arch='VGG-16', verbose=False):
    """
    Plot an AEC hyper-parameter ablation using the max SNN accuracy ('snn_max').

    One curve is drawn per conv-sparsity setting: dense (0.0) and the
    architecture's sparse level. Along each curve, ``param_name`` is swept
    over ``param_range``. The other hyper-parameter is held at the best
    configuration that ``grid_best_AEC`` reports for the sparse setting.
    That configuration uses the 'snn_8' entry, falling back to 'snn_max'
    when 'snn_8' is absent.

    Args:
        ds: dataset name (string).
        param_name: 'lr' or 'bs', the hyper-parameter to sweep.
        param_range: values to sweep, e.g. [1e-3, 1e-4, 1e-5] or
            [32, 64, 128]. Any iterable is accepted; it is materialised
            once as a list.
        log: use a logarithmic x axis (typically only for
            ``param_name == 'lr'``).
        arch: 'VGG-16' or 'MLP'. Any value other than 'VGG-16' uses the
            MLP sparsity level (0.99).
        verbose: print details about missing or unreadable result files,
            tick problems, and the saved figure path.

    Raises:
        ValueError: if ``param_name`` is neither 'lr' nor 'bs'.

    Side effects:
        Writes the figure to ``../graphs/<arch>/<ds>/ab_AEC <param_name>.png``.
    """
    # Validate before the (potentially slow) grid search.
    # The best config is indexed as [lr, bs].
    if param_name == 'lr':
        var = 0
    elif param_name == 'bs':
        var = 1
    else:
        raise ValueError("invalid hyper parameter name (must be 'lr' or 'bs')")
    fixed = 1 - var  # index of the hyper-parameter held constant
    fixed_name = 'bs' if var == 0 else 'lr'

    # The sweep values are iterated once per curve and again for the ticks,
    # so a generator must not be consumed after the first pass.
    param_range = list(param_range)

    fs = 1.5 * default_fs
    title = f"ab_AEC {param_name}"
    fig, ax = plt.subplots(figsize=(8, 6))
    colors = ('black', 'red')

    def mat_path(arch, ds, lr, bs, cur_sparsity):
        """Build the AEC result path (same layout as grid_best_AEC's mat_path)."""
        res_path = "../AEC/results"
        if arch == 'MLP':
            sparse = f"s_{cur_sparsity}"
        else:
            sparse = f"conv_{cur_sparsity}/s_0.0"
        config = f'lr_{lr}/bs_{bs}'
        return os.path.join(res_path, arch, ds, sparse, config, 'res.mat')

    def load_max_snn_acc(filename):
        """Return max('snn_acc') * 100 from ``filename``, or NaN if it cannot be read."""
        if not os.path.exists(filename):
            if verbose:
                print(f'file {filename} not found. appending nan.')
            return float('nan')
        try:
            mat = loadmat(filename)
            return float(np.max(np.array(mat['snn_acc']).flatten())) * 100.0
        except Exception as e:
            # Best effort: one bad result file should not abort the whole plot.
            if verbose:
                print(f"Error loading {filename}: {e}")
            return float('nan')

    # Sparse level per architecture; the dense curve always uses 0.0.
    sparsity = 0.5 if arch == 'VGG-16' else 0.99
    conv_sparsity_range = [0.0, sparsity]

    # Best [lr, bs] for the sparse setting. Prefer 'snn_8', else 'snn_max'.
    best_exps = grid_best_AEC(arch, ds, verbose=False, compute_energy=False)
    best_entry = best_exps[sparsity]
    metric = 'snn_8' if 'snn_8' in best_entry else 'snn_max'
    best_config = list(best_entry[metric])

    for color, spar in zip(colors, conv_sparsity_range):
        # Label by this curve's own sparsity (0.0 is the dense model).
        spar_label = 'sparse' if spar > 0.0 else 'dense'
        label = f"{spar_label} {fixed_name}_{best_config[fixed]}"

        data = []
        for x in param_range:
            # Fresh copy per point so the shared best config is never mutated.
            config = list(best_config)
            config[var] = x
            lr, bs = config[0], config[1]
            data.append(load_max_snn_acc(mat_path(arch, ds, lr, bs, spar)))

        ax.plot(param_range, data, color=color, marker='o', label=label)

    ax.set_xlabel(param_name, fontsize=fs)
    if log:
        ax.set_xscale('log')
    try:
        ax.set_xticks(param_range)
        ax.set_xticklabels([str(x) for x in param_range], fontsize=fs)
    except Exception as e:
        # Best effort: keep matplotlib's default ticks if custom ones fail.
        if verbose:
            print(f"Could not set custom x ticks: {e}")
    ax.tick_params(axis='y', labelsize=fs)
    ax.set_ylabel("AEC SNN accuracy (%)", fontsize=fs)
    ax.set_title(title, fontsize=fs)
    ax.legend(fontsize=fs)
    ax.grid(True)

    save_path = os.path.join(f'../graphs/{arch}', ds)
    os.makedirs(save_path, exist_ok=True)
    fig.tight_layout()
    save_file = os.path.join(save_path, f'{title}.png')
    fig.savefig(save_file)
    if verbose:
        print(f"Saved ablation figure to {save_file}")
    plt.close(fig)
