from latency import find_latency
import matplotlib.pyplot as plt
from get_data import *
import sys
sys.path.append('../')
from MASFR import compute_MASFR_from_LASFR
from energy import init_energy
from scipy.stats import wilcoxon
import os
import numpy as np
from scipy.io import savemat, loadmat

# Experiment grid iterated over by main(): architectures, datasets and
# methods (each combination is passed to init_energy / get_param_data_*).
arch_list = ['MLP', 'VGG-16']
ds_list = ['CIFAR10', 'CIFAR100']
method_list = ['QCFS', 'SNM']
# One entry per sparsity level; main() pairs entry i with sparsity_range[i]
# and, in this file, only relies on the length of this list.
sparse_types = ['fc','s']

# Output directory (created at import time).
# NOTE(review): nothing visible in this file writes into ./LASFR -- results
# are saved to ./t1_t2.mat; confirm whether this directory is still needed.
os.makedirs('./LASFR', exist_ok=True)

# colors: i=0 black, i=1 red
# NOTE(review): not referenced by the code visible in this file.
colors = ['black', 'red']

# QCFS hyper-parameter grid: learning rates, batch sizes and `l` values
# (passed to get_param_data_QCFS as lr, bs, l).
lrs_1 = [0.05, 0.01, 0.005, 0.001, 0.0005]
bss_1 = [32, 64, 128]
ls_1  = [2, 4, 8, 16, 32]

# SNM batch sizes (passed to get_param_data_SNM as bs).
bss_2 = [32, 64, 128, 256, 512]

def _hyperparam_grid(arch, ds, method):
    """Yield ``(fetch_fn, hparams)`` for every hyper-parameter setting of *method*.

    ``fetch_fn(*base_args, *hparams)`` is the project data loader for the
    method. Methods other than 'QCFS' and 'SNM' yield nothing, so no
    configurations are evaluated for them.
    """
    if method == 'QCFS':
        for lr in lrs_1:
            for bs in bss_1:
                for l in ls_1:
                    # NOTE(review): `break` leaves the `l` loop, so for
                    # VGG-16/CIFAR10 with lr=0.05, bs=32 *every* l is skipped
                    # (l == 2 is the first value of ls_1), not only l == 2.
                    # If just that single run should be skipped, this should be
                    # `continue` -- confirm before changing.
                    if (arch == 'VGG-16' and ds == 'CIFAR10'
                            and lr == 0.05 and bs == 32 and l == 2):
                        break
                    yield get_param_data_QCFS, (lr, bs, l)
    elif method == 'SNM':
        for bs in bss_2:
            yield get_param_data_SNM, (bs,)


def _append_latencies(t_accs, t_sfrs, sparsity_range, name, ts, snns, LASFRS,
                      model, dummy_input, device):
    """Append one ``(t1, t2)`` latency pair per sparsity level.

    Row ``i`` of ``t_accs`` / ``t_sfrs`` collects results for
    ``sparsity_range[i]`` (one row per entry of ``sparse_types``).
    ``model``, ``dummy_input`` and ``device`` are only used when
    ``name == 'MASFR'``.
    """
    for i, _sparse in enumerate(sparse_types):
        sparsity = sparsity_range[i]
        snn = np.array(snns[sparsity]).reshape(-1)

        LASFR = LASFRS[sparsity]
        if name == 'MASFR':
            MASFR = compute_MASFR_from_LASFR(LASFR, model, dummy_input, device)
        else:
            # LASFR mode: use the last entry of the LASFR data directly.
            MASFR = LASFR[-1]

        t1 = find_latency(ts, snn)[0]
        t2 = find_latency(ts, MASFR)[0]
        print(t1, t2)
        t_accs[i].append(t1)
        t_sfrs[i].append(t2)


def main(name: str, max_latency: int = 64) -> None:
    """Compute (t1, t2) latency pairs over the experiment grid and save them.

    For every architecture / dataset / method / hyper-parameter combination
    and for every entry of ``sparse_types`` (paired with ``sparsity_range[i]``):

    * ``t1 = find_latency(ts, snn)[0]``, where ``snn`` is the flattened data
      fetched with the ``'snn_acc'`` tag;
    * ``t2 = find_latency(ts, MASFR)[0]``, where ``MASFR`` is computed by
      ``compute_MASFR_from_LASFR`` (``name == 'MASFR'``) or is the last entry
      of the ``'LASFR'`` data (``name == 'LASFR'``).

    Configurations (columns) whose ``t1`` exceeds ``max_latency`` for any
    sparsity level are dropped. The remaining arrays, one row per entry of
    ``sparse_types``, are saved to ``./t1_t2.mat`` under the keys
    ``'t_accs'`` and ``'t_sfrs'``.

    Args:
        name: ``'MASFR'`` or ``'LASFR'``; selects how the firing-rate curve
            used for ``t2`` is obtained.
        max_latency: Inclusive upper bound on ``t1`` used to filter
            configurations (default 64).

    Raises:
        ValueError: If ``name`` is neither ``'MASFR'`` nor ``'LASFR'``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if name not in ('MASFR', 'LASFR'):
        raise ValueError(f"name must be 'MASFR' or 'LASFR', got {name!r}")

    # Row i collects latencies for the i-th sparsity level.
    t_accs = [[] for _ in sparse_types]
    t_sfrs = [[] for _ in sparse_types]

    for arch in arch_list:
        # Sparsity levels compared for this architecture (indexed by row i).
        sparsity_range = (0.0, 0.5) if arch == 'VGG-16' else (0.0, 0.99)

        for ds in ds_list:
            for method in method_list:
                print(f'=========={arch} {ds} {method}==========')
                model, dummy_input, device = init_energy(method, arch, ds, 'cuda:0')

                args_acc = (arch, ds, sparsity_range, 'snn_acc')
                args_LASFR = (arch, ds, sparsity_range, 'LASFR')
                for fetch, hparams in _hyperparam_grid(arch, ds, method):
                    # Only the time axis from the LASFR fetch is kept and used
                    # for both latencies.
                    # NOTE(review): assumes both fetches share the same `ts` -- confirm.
                    _, snns = fetch(*args_acc, *hparams)
                    ts, LASFRS = fetch(*args_LASFR, *hparams)
                    _append_latencies(t_accs, t_sfrs, sparsity_range, name,
                                      ts, snns, LASFRS, model, dummy_input, device)

    t_accs, t_sfrs = np.array(t_accs), np.array(t_sfrs)
    # Keep a configuration only if its accuracy latency is within
    # max_latency for every sparsity level; filter both arrays identically.
    mask = (t_accs <= max_latency).all(axis=0)
    t_accs, t_sfrs = t_accs[:, mask], t_sfrs[:, mask]
    print(t_accs.shape[1])
    savemat('./t1_t2.mat', {'t_accs': t_accs, 't_sfrs': t_sfrs})



if __name__ == '__main__':
    # Script entry point: compute latencies using the MASFR firing-rate curve.
    main(name='MASFR')