import os
from scipy.io import loadmat
import numpy as np
import sys
sys.path.append('../')
from exp_comparison import grid_best_input, grid_best_QCFS, grid_best_SNM, grid_best_AEC

def get_baseline(arch, ds, sparsity_range, method):
    """Collect the best baseline validation accuracy for each sparsity level.

    For every sparsity the maximum of the ``'acc_val'`` variable stored in the
    corresponding result file under ``../input/<arch>/<ds>/...`` is returned.

    Args:
        arch: Architecture name. ``'MLP'`` results live in ``s_<spar>``; every
            other architecture uses
            ``conv_<spar>/s_0.0/d_0.0/onefc_True/lr_<lr>/bs_<bs>`` where
            ``lr``/``bs`` come from ``grid_best_input(arch, ds)``.
        ds: Dataset name (used as a directory component).
        sparsity_range: Iterable of sparsity values; each one becomes a key of
            the returned dict.
        method: Only used to pick the file name: ``'res.matSNM200'`` when
            ``method == 'SNM'`` and ``arch == 'VGG-16'``, ``'res.mat'``
            otherwise.

    Returns:
        dict mapping each sparsity to a float accuracy, or ``np.nan`` when the
        file is missing, cannot be read, or has no ``'acc_val'`` entry.
    """
    sparsities = list(sparsity_range)

    # Everything below is independent of the sparsity, so compute it once.
    basic_input = os.path.join("../input", arch, ds)
    name = 'res.matSNM200' if (method == 'SNM' and arch == 'VGG-16') else 'res.mat'

    config_str = None
    if arch != 'MLP' and sparsities:
        # The grid lookup does not depend on the sparsity; calling it once
        # instead of once per loop iteration avoids repeating the search.
        input_lr, input_bs = grid_best_input(arch, ds)
        config_str = f"lr_{input_lr}/bs_{input_bs}"

    baseline = {}
    for spar in sparsities:
        # Build the result directory for this sparsity.
        if arch == 'MLP':
            input_path = os.path.join(basic_input, f"s_{spar}")
        else:
            input_path = os.path.join(
                basic_input, f"conv_{spar}/s_0.0", "d_0.0/onefc_True", config_str
            )
        resmat_path = os.path.join(input_path, name)

        # NaN marks "no baseline available" unless the load below succeeds.
        baseline[spar] = np.nan
        if not os.path.exists(resmat_path):
            print(f"[WARN] baseline file not found: {resmat_path}")
            continue

        try:
            # The accuracy curve is expected under the key 'acc_val'.
            acc_vals = loadmat(resmat_path).get('acc_val')
            if acc_vals is not None:
                baseline[spar] = float(np.max(acc_vals))
        except Exception as e:
            # Deliberately broad: loading is best-effort and a single unreadable
            # file must not abort the remaining sparsities.
            print(f"[WARN] baseline loadmat failed for {resmat_path}: {e}")

    return baseline

def get_snn_data_QCFS(arch, ds, sparsity_range, key, metric='snn_max'):
    """Load one QCFS result variable per sparsity, using the grid-search best config.

    Args:
        arch: Architecture name. ``'MLP'`` uses ``s_<spar>`` directories, other
            architectures use ``conv_<spar>/s_0.0``; ``'VGG-16'`` results are
            read from an additional ``ft`` sub-directory.
        ds: Dataset name (used as a directory component).
        sparsity_range: Indexable sequence of sparsity values. Element ``[1]``
            selects the grid-search entry whose hyper-parameters are used.
        key: Variable to read from each ``res.mat``; must be ``'snn_acc'`` or
            ``'LASFR'``.
        metric: Key into the per-sparsity entry returned by ``grid_best_QCFS``.

    Returns:
        Tuple ``(steps, snn_data)``. ``snn_data`` maps each sparsity to the
        loaded array, or to an empty array when the file is missing, unreadable
        or lacks ``key``. ``steps`` is ``np.arange(1, T + 1)`` where ``T`` is
        the element count (``'snn_acc'``) or the second-axis length
        (``'LASFR'``) of the sparsity-0.0 array when usable, otherwise of the
        first usable array; ``T`` is 0 when nothing could be loaded.
    """
    assert key in ('snn_acc', 'LASFR')

    # NOTE(review): the configuration found for sparsity_range[1] is reused for
    # every sparsity in the range - confirm this is intended.
    best = grid_best_QCFS(arch, ds, verbose=False, compute_energy=False)[sparsity_range[1]][metric]
    # best is (sparsity, lr, bs, l); it does not change inside the loop.
    _, lr, bs, l = best

    res_path = '../QCFS/results'
    snn_data = {}
    for spar in sparsity_range:
        sparse_str = f"s_{spar}" if arch == 'MLP' else f"conv_{spar}/s_0.0"
        tr_save_name = os.path.join(f'{arch}/{ds}', sparse_str, f'lr_{lr}/bs_{bs}/L_{l}')
        if arch == 'VGG-16':
            final_model_path = os.path.join(res_path, tr_save_name, 'ft', 'res.mat')
        else:
            final_model_path = os.path.join(res_path, tr_save_name, 'res.mat')

        # Empty array marks "nothing loaded" unless the read below succeeds.
        snn_data[spar] = np.array([])
        if not os.path.exists(final_model_path):
            print(f"[WARN] QCFS file not found: {final_model_path}")
            continue
        try:
            data = loadmat(final_model_path).get(key)
        except Exception as e:
            # Deliberately broad: loading is best-effort per sparsity.
            print(f"[WARN] QCFS loadmat failed for {final_model_path}: {e}")
            continue
        if data is not None:
            snn_data[spar] = np.asarray(data)

    def _num_steps(arr):
        # 'snn_acc' is treated as a flat curve; 'LASFR' keeps steps on axis 1.
        if key == 'snn_acc':
            return arr.size
        return arr.shape[1] if arr.ndim >= 2 else 0

    # Prefer the sparsity-0.0 entry (the original reference) but fall back to
    # any other loaded array, so a missing 0.0 file or a range without 0.0
    # no longer raises IndexError / KeyError after the warning was printed.
    candidates = ([snn_data[0.0]] if 0.0 in snn_data else []) + list(snn_data.values())
    T = next((n for n in map(_num_steps, candidates) if n), 0)

    return np.arange(1, T + 1), snn_data

def get_snn_data_SNM(arch, ds, sparsity_range, key):
    """Load one SNM result variable per sparsity, using the grid-search best batch size.

    Args:
        arch: Architecture name. ``'MLP'`` uses ``s_<spar>`` directories, other
            architectures use ``conv_<spar>/s_0.0``.
        ds: Dataset name (used as a directory component).
        sparsity_range: Iterable of sparsity values; each becomes a dict key.
        key: Variable to read from each ``res.mat``; must be ``'snn_acc'`` or
            ``'LASFR'``.

    Returns:
        Tuple ``(steps, snn_data)``. ``snn_data`` maps each sparsity to the
        loaded array, or to an empty array when the file is missing, unreadable
        or lacks ``key``. ``steps`` is ``np.arange(1, T + 1)`` where ``T`` is
        the element count (``'snn_acc'``) or the second-axis length
        (``'LASFR'``) of the sparsity-0.0 array when usable, otherwise of the
        first usable array; ``T`` is 0 when nothing could be loaded.
    """
    assert key in ('snn_acc', 'LASFR')

    # Only the first element of the grid-search result (the batch size) is used.
    bs, _ = grid_best_SNM(arch, ds, False)

    snn_data = {}
    for spar in sparsity_range:
        sparse_str = f"s_{spar}" if arch == 'MLP' else f"conv_{spar}/s_0.0"
        res_save = os.path.join('../', 'SNM', 'results', arch, ds, sparse_str)
        fn = os.path.join(res_save, f'bs_{bs}', 'res.mat')

        # Empty array marks "nothing loaded" unless the read below succeeds.
        snn_data[spar] = np.array([])
        if not os.path.exists(fn):
            print(f"[WARN] SNM file not found: {fn}")
            continue
        try:
            data = loadmat(fn).get(key)
        except Exception as e:
            # Deliberately broad: loading is best-effort per sparsity.
            print(f"[WARN] SNM loadmat failed for {fn}: {e}")
            continue
        if data is not None:
            snn_data[spar] = np.asarray(data)

    def _num_steps(arr):
        # 'snn_acc' is treated as a flat curve; 'LASFR' keeps steps on axis 1.
        if key == 'snn_acc':
            return arr.size
        return arr.shape[1] if arr.ndim >= 2 else 0

    # Prefer the sparsity-0.0 entry (the original reference) but fall back to
    # any other loaded array, so a missing 0.0 file or a range without 0.0
    # no longer raises IndexError / KeyError after the warning was printed.
    candidates = ([snn_data[0.0]] if 0.0 in snn_data else []) + list(snn_data.values())
    T = next((n for n in map(_num_steps, candidates) if n), 0)

    return np.arange(1, T + 1), snn_data

def get_snn_data_AEC(arch, ds, sparsity_range, key, metric='snn_max'):
    """Load one AEC result variable per sparsity, using the grid-search best config.

    Args:
        arch: Architecture name. ``'MLP'`` uses ``s_<spar>`` directories, other
            architectures use ``conv_<spar>/s_0.0``.
        ds: Dataset name (used as a directory component).
        sparsity_range: Indexable sequence of sparsity values. Element ``[1]``
            selects the grid-search entry whose ``lr``/``bs`` are used.
        key: Variable to read from each ``res.mat``.
        metric: Key into the per-sparsity entry returned by ``grid_best_AEC``.

    Returns:
        Tuple ``(x, snn_data)``. ``x`` is the fixed array
        ``[2, 4, 8, 16, 32, 64, 128]``. ``snn_data`` maps every sparsity to the
        loaded array, or to an empty array when the file is missing, cannot be
        read, or does not contain ``key``.
    """
    res_path = "../AEC/results"
    # NOTE(review): x is hard-coded rather than derived from the loaded data -
    # presumably the evaluated time-step settings; confirm against the AEC runs.
    x = np.array([2, 4, 8, 16, 32, 64, 128])

    # Look up lr and bs from the grid-search entry of sparsity_range[1]. There
    # is no fallback: a missing sparsity or metric entry raises here.
    # NOTE(review): this single configuration is reused for every sparsity -
    # confirm this is intended.
    best_exp = grid_best_AEC(arch, ds, verbose=False, compute_energy=False).get(sparsity_range[1]).get(metric)
    lr, bs = best_exp[1], best_exp[2]

    snn_data = {}
    for spar in sparsity_range:
        sparse_part = f"s_{spar}" if arch == 'MLP' else f"conv_{spar}/s_0.0"
        fn = os.path.join(res_path, arch, ds, sparse_part, f"lr_{lr}", f"bs_{bs}", "res.mat")

        # Empty array marks "nothing loaded" unless the read below succeeds.
        snn_data[spar] = np.array([])
        if not os.path.exists(fn):
            print(f'not exist file: {fn}')
            continue
        try:
            arr = loadmat(fn).get(key)
        except Exception as e:
            # Best-effort like the other loaders in this module: one unreadable
            # file must not abort the remaining sparsities.
            print(f"[WARN] AEC loadmat failed for {fn}: {e}")
            continue
        if arr is not None:
            # A missing key keeps the empty array rather than becoming
            # np.array(None), which is a 0-d object array.
            snn_data[spar] = np.asarray(arr)

    return x, snn_data

def get_param_data_QCFS(arch, ds, sparsity_range, key, lr, bs, l):
    """Load one QCFS result variable per sparsity for an explicit (lr, bs, l) setting.

    Args:
        arch: Architecture name. ``'MLP'`` uses ``s_<spar>`` directories, other
            architectures use ``conv_<spar>/s_0.0``; ``'VGG-16'`` results are
            read from an additional ``ft`` sub-directory.
        ds: Dataset name (used as a directory component).
        sparsity_range: Iterable of sparsity values; each becomes a dict key.
        key: Variable to read from each ``res.mat``; must be ``'snn_acc'`` or
            ``'LASFR'``.
        lr: Value formatted into the ``lr_<lr>`` directory component.
        bs: Value formatted into the ``bs_<bs>`` directory component.
        l: Value formatted into the ``L_<l>`` directory component.

    Returns:
        Tuple ``(steps, snn_data)``. ``snn_data`` maps each sparsity to the
        loaded array, or to an empty array when the file is missing, unreadable
        or lacks ``key``. ``steps`` is ``np.arange(1, T + 1)`` where ``T`` is
        the element count (``'snn_acc'``) or the second-axis length
        (``'LASFR'``) of the sparsity-0.0 array when usable, otherwise of the
        first usable array; ``T`` is 0 when nothing could be loaded.
    """
    assert key in ('snn_acc', 'LASFR')

    res_path = '../QCFS/results'
    snn_data = {}
    for spar in sparsity_range:
        sparse_str = f"s_{spar}" if arch == 'MLP' else f"conv_{spar}/s_0.0"
        tr_save_name = os.path.join(f'{arch}/{ds}', sparse_str, f'lr_{lr}/bs_{bs}/L_{l}')
        if arch == 'VGG-16':
            final_model_path = os.path.join(res_path, tr_save_name, 'ft', 'res.mat')
        else:
            final_model_path = os.path.join(res_path, tr_save_name, 'res.mat')

        # Empty array marks "nothing loaded" unless the read below succeeds.
        snn_data[spar] = np.array([])
        if not os.path.exists(final_model_path):
            print(f"[WARN] QCFS file not found: {final_model_path}")
            continue
        try:
            data = loadmat(final_model_path).get(key)
        except Exception as e:
            # Deliberately broad: loading is best-effort per sparsity.
            print(f"[WARN] QCFS loadmat failed for {final_model_path}: {e}")
            continue
        if data is not None:
            snn_data[spar] = np.asarray(data)

    def _num_steps(arr):
        # 'snn_acc' is treated as a flat curve; 'LASFR' keeps steps on axis 1.
        if key == 'snn_acc':
            return arr.size
        return arr.shape[1] if arr.ndim >= 2 else 0

    # Prefer the sparsity-0.0 entry (the original reference) but fall back to
    # any other loaded array, so a missing 0.0 file or a range without 0.0
    # no longer raises IndexError / KeyError after the warning was printed.
    candidates = ([snn_data[0.0]] if 0.0 in snn_data else []) + list(snn_data.values())
    T = next((n for n in map(_num_steps, candidates) if n), 0)

    return np.arange(1, T + 1), snn_data

def get_param_data_SNM(arch, ds, sparsity_range, key, bs):
    """Load one SNM result variable per sparsity for an explicit batch size.

    Args:
        arch: Architecture name. ``'MLP'`` uses ``s_<spar>`` directories, other
            architectures use ``conv_<spar>/s_0.0``.
        ds: Dataset name (used as a directory component).
        sparsity_range: Iterable of sparsity values; each becomes a dict key.
        key: Variable to read from each ``res.mat``; must be ``'snn_acc'`` or
            ``'LASFR'``.
        bs: Value formatted into the ``bs_<bs>`` directory component.

    Returns:
        Tuple ``(steps, snn_data)``. ``snn_data`` maps each sparsity to the
        loaded array, or to an empty array when the file is missing, unreadable
        or lacks ``key``. ``steps`` is ``np.arange(1, T + 1)`` where ``T`` is
        the element count (``'snn_acc'``) or the second-axis length
        (``'LASFR'``) of the sparsity-0.0 array when usable, otherwise of the
        first usable array; ``T`` is 0 when nothing could be loaded.
    """
    assert key in ('snn_acc', 'LASFR')

    snn_data = {}
    for spar in sparsity_range:
        sparse_str = f"s_{spar}" if arch == 'MLP' else f"conv_{spar}/s_0.0"
        res_save = os.path.join('../', 'SNM', 'results', arch, ds, sparse_str)
        fn = os.path.join(res_save, f'bs_{bs}', 'res.mat')

        # Empty array marks "nothing loaded" unless the read below succeeds.
        snn_data[spar] = np.array([])
        if not os.path.exists(fn):
            print(f"[WARN] SNM file not found: {fn}")
            continue
        try:
            data = loadmat(fn).get(key)
        except Exception as e:
            # Deliberately broad: loading is best-effort per sparsity.
            print(f"[WARN] SNM loadmat failed for {fn}: {e}")
            continue
        if data is not None:
            snn_data[spar] = np.asarray(data)

    def _num_steps(arr):
        # 'snn_acc' is treated as a flat curve; 'LASFR' keeps steps on axis 1.
        if key == 'snn_acc':
            return arr.size
        return arr.shape[1] if arr.ndim >= 2 else 0

    # Prefer the sparsity-0.0 entry (the original reference) but fall back to
    # any other loaded array, so a missing 0.0 file or a range without 0.0
    # no longer raises IndexError / KeyError after the warning was printed.
    candidates = ([snn_data[0.0]] if 0.0 in snn_data else []) + list(snn_data.values())
    T = next((n for n in map(_num_steps, candidates) if n), 0)

    return np.arange(1, T + 1), snn_data


if __name__ == '__main__':
    # Manual smoke check: print the shape of the AEC 'LASFR' array loaded for
    # sparsity 0.99 on MLP / CIFAR10.
    _, aec_by_sparsity = get_snn_data_AEC('MLP', 'CIFAR10', (0.0, 0.99), 'LASFR')
    print(aec_by_sparsity[0.99].shape)













