from utils.training import training_sim, training_int, testing_int, label_int 
import os
import brainpy.math as bm
import brainpy as bp
from model.intelligent_eva import M_2D
from utils.get_Mnist import getdata_test, getdata_train
from utils.training import repeat
import numpy as np
import pickle
from tqdm import tqdm


def get_dataset(BNNnet, trainer, neuron_label, type):
    """Run the network over one MNIST split, score it, and dump the spikes.

    For every batch yielded by the data loader the input is passed through
    ``repeat(X, 60)``, the network is run with ``trainer``, and the monitored
    ``"n.spike"`` record is restricted to the 28 x 28 grid of neuron indices
    ``(2*i + 1) * 57 + (2*j + 1)``.  Two things are derived from it:

    * a list of ``[spikes_of_sample, int(label)]`` pairs, pickled to
      ``'../intelligence_expansion/dataset/<type>.npy'`` (note: despite the
      ``.npy`` suffix the file is a pickle, not a NumPy archive);
    * a voting accuracy, where each of the first 400 selected neurons votes
      for ``neuron_label[neuron]`` with its spike count and the class with
      the highest mean vote wins.  The accuracy is printed.

    Args:
        BNNnet: network object; ``set_para(mod=False)`` is called on it before
            every run (presumably switches learning off -- confirm in model).
        trainer: runner exposing ``run(inputs=..., reset_state=True)`` and a
            ``mon["n.spike"]`` record indexed as (time, batch, neuron).
        neuron_label: indexable with at least 400 entries, each convertible
            to an ``int`` class index in ``0..9``.
        type: ``"train"`` selects ``getdata_train``; anything else selects
            ``getdata_test``.  (The name shadows the builtin but is kept for
            backward compatibility with existing callers.)

    Returns:
        None.  Results are printed and written to disk.
    """
    n_classes = 10
    n_voting = 400
    batch_size = 256
    out_dir = '../intelligence_expansion/dataset/'

    # Flat indices of every second row/column of a 57-wide grid.
    idx = [(i * 2 + 1) * 57 + (j * 2 + 1) for i in range(28) for j in range(28)]

    if type == "train":
        function = getdata_train
        num = 60000
    else:
        function = getdata_test
        num = 10000
    num = int(np.ceil(num / batch_size))

    # Label lookup and per-class neuron counts do not depend on the batch, so
    # compute them once.  np.add.at keeps the semantics of ``s[label] += 1``
    # (negative labels wrap, out-of-range labels raise IndexError).
    labels = np.array([int(neuron_label[n]) for n in range(n_voting)])
    counts = np.zeros(n_classes)
    np.add.at(counts, labels, 1)
    has_neurons = counts[:, None] > 0

    test = []
    acc_num = 0
    total_num = 0

    with tqdm(total=num) as pbar:
        pbar.set_description('label:')
        for X, Y in function("./data/", batch_size, upto=10, I_max=64):
            X = repeat(X, 60)
            BNNnet.set_para(mod=False)
            trainer.run(inputs=X, reset_state=True)

            # Select the monitored neurons once per batch instead of three
            # times; the result is indexed (time, batch, selected neuron).
            spikes = trainer.mon["n.spike"][:, :, idx]
            spike_sum = np.asarray(np.sum(spikes, axis=0))

            for i in range(spikes.shape[1]):
                test.append([spikes[:, i], int(Y[i])])

            # NOTE(review): 784 neurons are selected above but only the first
            # 400 take part in the vote -- confirm this is intended.
            votes = np.zeros((n_classes, spike_sum.shape[0]))
            np.add.at(votes, labels, spike_sum[:, :n_voting].T)

            # A class without any labelled neuron used to become 0/0 = NaN,
            # which argmax treats as the maximum; give it -inf so it can
            # never be predicted.
            mean_votes = np.full_like(votes, -np.inf)
            np.divide(votes, counts[:, None], out=mean_votes, where=has_neurons)

            for b in range(X.shape[1]):
                if mean_votes[:, b].argmax() == Y[b]:
                    acc_num += 1
                total_num += 1
            pbar.update(1)

    # Guard against a loader that yields nothing.
    test_acc = acc_num / total_num if total_num else 0.0
    print(type + ':After STDP test_acc', test_acc)

    # Create the target directory so a long run is not lost at the last step.
    os.makedirs(out_dir, exist_ok=True)
    with open(out_dir + type + '.npy', "wb") as f:
        pickle.dump(test, f)



def train_10_class():
    """Train, label and evaluate the 10-class network, then export datasets.

    Steps, in order: build an ``M_2D`` network and a ``bp.DSRunner`` that
    monitors ``'n.spike'``; call ``training_int`` (weights saved to
    ``./weight/class_10.bp``); call ``label_int`` (labels saved to
    ``./weight/class_10.npy``); call ``testing_int`` on the train and test
    splits and print both results; finally call ``get_dataset`` for both
    splits.
    """
    # os.environ['CUDA_VISIBLE_DEVICES'] = "1"
    bm.set_environment(bm.training_mode)

    net = M_2D(tau=5, alpha=0.96, beta=0.6, T_dur=0.01, K=5, delay=12)
    trainer = bp.DSRunner(
        net,
        monitors=['n.spike'],
        data_first_axis='T',
        progress_bar=False,
    )

    # Keyword arguments shared by every training / labelling / testing call.
    common = dict(
        BNNnet=net,
        during=60,
        trainer=trainer,
        path="./data/",
        I_max=64,
        upto=10,
    )

    training_int(batch_size=256, save="./weight/class_10.bp", **common)
    neuron_label = label_int(save="./weight/class_10.npy", **common)

    train_res = testing_int(neuron_label=neuron_label, func="train", **common)
    test_res = testing_int(neuron_label=neuron_label, func="test", **common)
    print("training acc", train_res)
    print("testing acc", test_res)

    # Alternative kept from the original: restore a saved model instead of
    # training from scratch.
    # trainer = bp.DSRunner(net, monitors=['n.spike'], data_first_axis = 'T',  progress_bar=False)
    # state_dict = bp.checkpoints.load_pytree("./weight/class_10.bp")
    # net.load_state_dict(state_dict['net'])
    # state_dict = None
    # neuron_label = np.load("./weight/class_10.bp.npy")

    for split in ("train", "test"):
        get_dataset(net, trainer, neuron_label, split)
    
# Script entry point. The call is not behind an ``if __name__ == "__main__":``
# guard, so the whole pipeline also runs if this module is imported.
train_10_class()