import os
import sys

import numpy as np
import pandas as pd
import scipy.io

sys.path.append('../')
from get_data import *
from energy import compute_snn_ann_energy_ratio, init_energy
from latency import find_latency

# Experiment grid: the evaluation loop below visits every
# (dataset, architecture, method) combination, once per connectivity type.
ds_list = [
    'CIFAR10',
    'CIFAR100',
]
arch_list = [
    'MLP',
    'VGG-16',
]
method_list = [
    'QCFS',
    'SNM',
    'AEC',
]

# Connectivity types; the loop uses sparsity_range[0] for 'fc' and
# sparsity_range[1] for 's' (presumably fully-connected vs. sparse — confirm).
sparse_types = ['fc', 's']

# Accumulates one dict per output-table row.
results = []

for ds in ds_list:

    for arch in arch_list:

        if arch == 'VGG-16':
            sparsity_range = (0.0, 0.5)
        else:
            sparsity_range = (0.0, 0.99)
    
        for method in method_list:

            model, dummy_input, device = init_energy(method, arch, ds, 'cuda:0')

            args_acc = (arch, ds, sparsity_range, 'snn_acc')
            args_LASFR = (arch, ds, sparsity_range, 'LASFR')
            match method:
                case 'QCFS': 
                    ts, snns = get_snn_data_QCFS(*args_acc)
                    ts, LASFRS = get_snn_data_QCFS(*args_LASFR)
                case 'SNM': 
                    ts, snns = get_snn_data_SNM(*args_acc)
                    ts, LASFRS = get_snn_data_SNM(*args_LASFR)
                case 'AEC': 
                    ts, snns = get_snn_data_AEC(*args_acc)
                    ts, LASFRS = get_snn_data_AEC(*args_LASFR)
            ts = np.array(ts).flatten()

            for sparse in sparse_types:
                sparsity = sparsity_range[int(sparse=='s')]
                spar_label = sparse+f'({sparsity})' if sparse=='s' else sparse

                snn = np.array(snns[sparsity]).reshape(-1)
                LASFR = LASFRS[sparsity]
                eng_t = compute_snn_ann_energy_ratio(ts, LASFR, 1-sparsity ,model, dummy_input, device)['energy_snn']

            
                if method!='AEC':
                    latency, idx_lat = find_latency(ts, snn, 10)
                    eng_lat = eng_t[idx_lat]
                else:
                    idx_max = int(np.argmax(snn))
                    latency = ts[idx_max]
                    if latency>64:
                        latency = 64
                        idx_max = np.where(ts==64)[0]
                    eng_lat = eng_t[idx_max]

            
                results.append({
                    'arch': arch,
                    'ds': ds,
                    'method': method,
                    'connectivity': spar_label,
                    'T': latency,
                    'energy': eng_lat
                })

# Build the summary table. Passing `columns` both fixes the column order and
# keeps the headers when `results` is empty (selecting columns from an
# empty, column-less DataFrame would raise KeyError).
desired_cols = ['ds', 'arch', 'method', 'connectivity', 'T', 'energy']
df = pd.DataFrame(results, columns=desired_cols)

# Write the table to Excel (without the index column).
output_path = "./Table.xlsx"
df.to_excel(output_path, index=False)
print(f"Saved summary to {output_path}")
