import torch
import random
import os
import numpy as np
from scipy.io import savemat, loadmat
from models import VGG, MLP, CHT_Model
from tqdm import tqdm
import sys
sys.path.append('..')
sys.path.append('../QCFS')
from Models.load_data import load_data_mlp
#from exp_comparison import find_best_sparsity
from custom_load_state_dict import vgg16_Hanming_to_CHT_thre, vgg16_to_thre
from models.spikeLayer import SPIKE_PosNeg_layer, SPIKE_PosNeg_layer_BN
from dst_scheduler import DSTScheduler
from arg_parse import Args
from exp_comparison import grid_best_input
from utils import replace_maxpool2d_by_avgpool2d

def seed_all(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` to the process environment and switches
    cuDNN to deterministic mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed handed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every seeding entry point takes the same single integer argument.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_path(args):
    """Build the checkpoint input directory and the result output directory.

    Args:
        args: Run configuration. Reads ``architecture`` and ``dataset``, plus
            ``linear_sparsity`` for the MLP case or ``conv_sparsity`` otherwise.

    Returns:
        Tuple ``(input_path, res_save)``: the directory the trained model is
        read from, and the directory the conversion results are written to.
    """
    arch, dataset = args.architecture, args.dataset
    input_root = os.path.join("../input", arch, dataset)

    if arch == 'MLP':
        sparse = f"s_{args.linear_sparsity}"
        input_path = os.path.join(input_root, sparse)
    else:
        sparse = f"conv_{args.conv_sparsity}/s_0.0"
        # The lr / batch-size sub-directory is chosen by grid_best_input.
        best_lr, best_bs = grid_best_input(arch, dataset, mat_name='res.matSNM200')
        input_path = os.path.join(
            input_root,
            sparse,
            'd_0.0/onefc_True',
            f'lr_{best_lr}/bs_{best_bs}',
        )

    res_save = os.path.join('..', 'SNM', 'results', arch, dataset, sparse)
    return input_path, res_save

def get_firing_rate(model):
    """Collect ``batch_fire_rate`` from every spiking layer nested in ``model``.

    Children are visited depth-first in ``named_children()`` order. A child
    that is a ``SPIKE_PosNeg_layer`` or ``SPIKE_PosNeg_layer_BN`` contributes
    its ``batch_fire_rate`` and is not descended into; any other child is
    searched recursively.

    Args:
        model: Module (or sub-module) to search.

    Returns:
        List of firing-rate values, one per spiking layer found.
    """
    rates = []
    for _, submodule in model.named_children():
        if isinstance(submodule, (SPIKE_PosNeg_layer, SPIKE_PosNeg_layer_BN)):
            rates.append(submodule.batch_fire_rate)
            continue
        rates.extend(get_firing_rate(submodule))
    return rates

def _ann_accuracy(ann, loader, device, batch_size, init_thresh=False):
    """Return the top-1 accuracy (in percent) of ``ann`` over ``loader``.

    Batches holding fewer than ``batch_size`` samples are skipped.

    Args:
        ann: Network to evaluate; it is put in eval mode for every batch.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to.
        batch_size: Minimum batch length that is still evaluated.
        init_thresh: When True, ``ann.init_thresh`` is called with the inputs
            of the batch at index 0 (only if that batch is not skipped).

    Raises:
        ZeroDivisionError: If no batch of at least ``batch_size`` samples was
            seen, so the accuracy is undefined.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            if len(targets) < batch_size:
                print('dropout last batch')
                continue
            ann.eval()
            inputs, targets = inputs.to(device), targets.to(device)
            if init_thresh and batch_idx == 0:
                ann.init_thresh(inputs)
            outputs = ann(inputs)
            _, predicted = outputs.max(1)
            total += float(targets.size(0))
            correct += float(predicted.eq(targets).sum().item())

    if not total:
        # Same exception type the bare division would raise, with a usable message.
        raise ZeroDivisionError(
            f'no batch with at least {batch_size} samples was evaluated; accuracy is undefined'
        )
    return 100 * correct / total


def main(args):
    """Convert a trained ANN checkpoint into a spiking network and evaluate it.

    Steps:
      1. Load the checkpoint found under the input path from ``get_path`` and
         build the matching ANN (VGG-16 variant or MLP).
      2. Optionally restore the ``DSTScheduler`` pruner state.
      3. Evaluate the ANN on the test set (initialising thresholds on the first
         batch) and then on the training set.
      4. Build the spiking model from ``ann.max_active`` and the ANN, run it for
         ``args.t`` time steps per test batch, and accumulate the accuracy and
         per-layer firing rates at every time step.
      5. Save ``snn_acc`` and ``LASFR`` to ``res.mat`` under the result path.

    Args:
        args: Run configuration. Reads ``device``, ``dataset``, ``bs``, ``dim``,
            ``architecture``, ``conv_sparsity``, ``linear_sparsity``, ``one_fc``
            and ``t`` (``args`` is also forwarded to ``DSTScheduler``).
    """
    input_path, res_save = get_path(args)

    device = torch.device(args.device)

    train_loader, test_loader, indim, outdim, hiddim = load_data_mlp(args.dataset, args.bs, args.dim)

    statedict = torch.load(
        os.path.join(input_path, 'best_model.pthSNM_200'),
        weights_only=True,
        map_location='cpu',
    )

    if args.architecture != 'MLP':
        if args.conv_sparsity == 0.0:
            ann = VGG.VGG16_optimalThres(outdim, args.one_fc)
            vgg16_to_thre(statedict, ann, False)
        else:
            ann = CHT_Model.VGG16_CHT_optimalThres(outdim, args.one_fc)
            vgg16_Hanming_to_CHT_thre(statedict, ann, False)
        num_activations = 16 if not args.one_fc else 14  # for VGG-16
    else:
        ann = MLP.MLP_optimalThres(indim, hiddim, outdim)
        ann.load_state_dict(statedict)
        num_activations = 4

    if args.linear_sparsity > 0.0 and not args.one_fc:
        pruner_dict = torch.load(
            os.path.join(input_path, 'prunerSNM200.pth'),
            weights_only=True,
            map_location='cpu',
        )
        # Original note: load_state_dict has been overridden and adapts to the device automatically.
        torch.cuda.empty_cache()
        # Original note: the backward hooks are registered here, so the optimizer
        # does not need to be wrapped. The instance is kept bound for the rest of
        # the function even though it is not referenced again.
        pruner = DSTScheduler(
            ann,
            None,
            alpha=0.3,
            delta=len(train_loader),
            sparsity_distribution='uniform',
            static_topo=False,
            T_end=100,
            ignore_linear_layers=False,
            grad_accumulation_n=1,
            args=args,
            state_dict=pruner_dict,
        )
        del pruner_dict
    else:
        pruner = None

    torch.cuda.empty_cache()
    del statedict
    ann = replace_maxpool2d_by_avgpool2d(ann)
    ann = ann.to(device)

    # Validate the ANN on the test set; thresholds are initialised on the first batch.
    test_acc = _ann_accuracy(ann, test_loader, device, args.bs, init_thresh=True)
    print('Test Accuracy on test set: %.3f' % test_acc)

    # Forward pass over the training set. Per the original note this is the pass
    # that finds the maximum / minimum activations; the recording itself is
    # presumably done inside the model -- confirm in the model definition.
    train_acc = _ann_accuracy(ann, train_loader, device, args.bs)
    print('Test Accuracy on train set: %.3f' % train_acc)

    max_activate = ann.max_active
    print(f'max_activate: length:{len(max_activate)}, element size:{max_activate[0].size()}')

    if args.architecture != 'MLP':
        snn = VGG.VGG16_BN_PosNeg_spiking(max_activate, ann)
    else:
        snn = MLP.MLP_PosNeg_spiking(max_activate, ann)

    snn = replace_maxpool2d_by_avgpool2d(snn)
    snn = snn.to(device)

    # Test the SNN.
    total = 0.
    # testCorr[t]: number of correct predictions when decoding after step t.
    testCorr = [0. for _ in range(args.t)]
    # sfr_layer_list[j][t]: firing rate of spiking layer j at step t, summed over batches.
    sfr_layer_list = [[0. for _ in range(args.t)] for _ in range(num_activations)]
    with torch.no_grad():
        snn.eval()
        snn.weight_bias_norm()
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            if len(targets) < args.bs:
                print('dropout last batch')
                continue
            snn.init_layer()
            inputs, targets = inputs.to(device), targets.to(device)
            rate_decoding = 0.
            total += targets.size(0)  # size(dim) returns an int

            for t in range(args.t):
                spk = snn(inputs, t)  # the time dimension is handled inside the model
                # Rate decoding: the outputs are summed over time rather than
                # averaged; the argmax used for classification is the same either way.
                rate_decoding += spk
                # max(1) reduces over dim 1 (the class dimension, not the batch)
                # and returns (values, indices); [1] selects the indices.
                testCorr[t] += ((targets == rate_decoding.max(1)[1]).sum()).item()

                fr = get_firing_rate(snn)
                for j, sfr in enumerate(fr):
                    sfr_layer_list[j][t] += sfr

        # Replace each per-step value by the running mean over steps 0..i.
        for j in range(num_activations):
            per_step = sfr_layer_list[j]
            sfr_layer_list[j] = [sum(per_step[:(i + 1)]) / (i + 1) for i in range(len(per_step))]

    corr = np.array(testCorr) / total
    LASFR = np.array(sfr_layer_list) / total
    print(f'Accuracy: {corr*100}%')

    out_dir = os.path.join(res_save, f'bs_{args.bs}')
    os.makedirs(out_dir, exist_ok=True)
    savemat(os.path.join(out_dir, 'res.mat'), {'snn_acc': corr, 'LASFR': LASFR})
    print('Conversion finished')

if __name__=='__main__':
    # Script entry point: parse the arguments, seed the RNGs, reconcile the
    # sparsity options for VGG-16, then run the conversion.
    args=Args()
    seed_all(args.seed)

    if args.architecture == 'VGG-16':
        if args.conv_sparsity == 0.0:
            assert args.linear_sparsity == 0.0, \
                'linear_sparsity must be 0.0 when conv_sparsity is 0.0'
        elif args.one_fc:
            # This was a bare comparison (`==`) whose result was discarded, so
            # the branch did nothing; assign so the value is forced to 0.0.
            args.linear_sparsity = 0.0

    print(args)

    main(args)

    
