
from utils.get_Mnist import getdata_train, getdata_test
import brainpy as bp
import numpy as np
from tqdm import tqdm
import os
import matplotlib.pyplot as plt

def moveup(st):
    """Return a copy of ``st`` with columns 0, 4, 7, 56 and 63 set to zero.

    ``st`` must support ``.copy()`` and ``[:, k]`` item assignment (e.g. a
    NumPy array with at least two dimensions). The input itself is left
    unmodified.
    """
    cleared = st.copy()
    cleared[:, [0, 4, 7, 56, 63]] = 0
    return cleared

def training_sim(BNNnet, inputdata, trainer, epoch = 1):
    """Run ``trainer`` on ``inputdata`` ``epoch`` times and return spikes.

    Before every run the network is switched with ``set_para(mod=True)`` and
    the trainer state is reset. The returned value is ``moveup`` applied to
    the ``"n.spike"`` monitor of the last run, restricted to batch element 0.
    """
    for _ in range(epoch):
        BNNnet.set_para(mod = True)
        trainer.run(inputs = inputdata, reset_state=True)

    spike_record = trainer.mon["n.spike"]
    first_sample = spike_record[:, 0, :]
    return moveup(first_sample)

def repeat(X, during):
    """Stack ``during`` copies of ``X`` along a new leading axis.

    Returns a float64 array of shape ``(during,) + X.shape`` in which every
    slice ``[k]`` equals ``X`` (converted to float64).
    """
    stacked = np.zeros((during, ) + X.shape)
    # Broadcasting X across the leading axis fills every slice at once.
    stacked[...] = X
    return stacked

def training_int(BNNnet, during, trainer, path, batch_size, I_max, save, upto=2):
    """Train the network over the training loader and checkpoint its state.

    For every batch yielded by ``getdata_train`` the input is repeated for
    ``during`` steps, the network is put in ``set_para(mod=True)`` mode, the
    trainer is run with a state reset, and ``BNNnet.homeostatic`` is called
    with the batch size. Afterwards ``BNNnet.state_dict()`` is saved under
    ``./<save>`` with ``bp.checkpoints.save_pytree``.

    Args:
        BNNnet: network providing ``set_para``, ``homeostatic`` and ``state_dict``.
        during: number of time steps each input is presented (see ``repeat``).
        trainer: runner providing ``run(inputs=..., reset_state=True)``.
        path: dataset location forwarded to ``getdata_train``.
        batch_size: batch size forwarded to ``getdata_train``.
        I_max: forwarded to ``getdata_train``.
        save: checkpoint file name, joined onto ``"./"``.
        upto: forwarded to ``getdata_train``.
    """
    # 12665 is the sample count the progress bar assumes for the training
    # loader (presumably the digit-0/1 MNIST subset used with upto=2 — confirm).
    # The total must use the actual batch size, not a hard-coded 256.
    num = int(np.ceil(12665 / batch_size))
    with tqdm(total=num) as pbar:
        pbar.set_description('trainning:')

        for X, _ in getdata_train(path, batch_size, I_max=I_max, upto=upto):
            X = repeat(X, during)
            BNNnet.set_para(mod = True)
            trainer.run(inputs = X, reset_state=True)
            # After repeat, X is (during, batch, ...), so axis 1 is the batch size.
            BNNnet.homeostatic(X.shape[1])
            pbar.update(1)

    states = {'net': BNNnet.state_dict()}
    bp.checkpoints.save_pytree(os.path.join("./", save), states)


def label_int(BNNnet, during, trainer, path, I_max, save, upto=2):
    """Label each read-out neuron with the class it fires for most on average.

    The training loader (``getdata_train``, batches of 256) is run through the
    network in ``set_para(mod=False)`` mode. For each of the 28*28 read-out
    neurons the spike counts are accumulated per class label. Each neuron is
    then labelled with the class whose mean spike count per presented sample
    is highest.

    Args:
        BNNnet: network providing ``set_para``.
        during: number of time steps each input is presented (see ``repeat``).
        trainer: runner providing ``run`` and a ``mon["n.spike"]`` record whose
            axis 1 is the batch and axis 2 the neuron index.
        path: dataset location forwarded to ``getdata_train``.
        I_max: forwarded to ``getdata_train``.
        save: file name passed to ``np.save`` for the label array.
        upto: number of classes; labels are assumed to be ints in ``[0, upto)``.

    Returns:
        Float64 array of length 28*28 with each neuron's class index. The same
        array is written to ``save``.
    """
    # Flat indices of the read-out neurons: odd rows/columns of a grid with
    # row stride 57, giving one neuron per 28x28 pixel position.
    idx = [(i * 2 + 1) * 57 + (j * 2 + 1) for i in range(28) for j in range(28)]
    neuron_records = np.zeros((len(idx), upto))  # spike totals per (neuron, class)
    label_counts = np.zeros(upto)                # samples seen per class

    num = int(np.ceil(12665 / 256))
    with tqdm(total=num) as pbar:
        pbar.set_description('label:')

        for X, Y in getdata_train(path, 256, I_max=I_max, upto=upto):
            X = repeat(X, during)
            BNNnet.set_para(mod = False)
            trainer.run(inputs = X, reset_state=True)
            # (batch, neuron) spike counts, summed over axis 0 (presumably time).
            spike_sum = np.asarray(np.sum(trainer.mon["n.spike"][:, :, idx], axis=0))
            labels = np.asarray(Y)
            # np.add.at accumulates repeated labels. A fancy-indexed `+=` would
            # keep only one sample per label per batch.
            np.add.at(neuron_records, (slice(None), labels), spike_sum.T)
            np.add.at(label_counts, labels, 1)
            pbar.update(1)

    # Mean spike count per presented sample of each class. Classes never seen
    # get -inf instead of a 0/0 NaN, which argmax would otherwise always pick.
    rates = np.divide(neuron_records, label_counts,
                      out=np.full_like(neuron_records, -np.inf),
                      where=label_counts > 0)
    neuron_label = rates.argmax(axis=1).astype(np.float64)
    np.save(save, neuron_label)

    return neuron_label


def testing_int(BNNnet, during, trainer, path, I_max, neuron_label, upto=2, func = "test"):
    """Return the fraction of samples the labelled read-out classifies correctly.

    Each sample is presented for ``during`` steps in ``set_para(mod=False)``
    mode. For every sample the spike counts of the 28*28 read-out neurons are
    averaged over the neurons assigned to each class by ``neuron_label``. The
    class with the largest mean is the prediction.

    Args:
        BNNnet: network providing ``set_para``.
        during: number of time steps each input is presented (see ``repeat``).
        trainer: runner providing ``run`` and a ``mon["n.spike"]`` record whose
            axis 1 is the batch and axis 2 the neuron index.
        path: dataset location forwarded to the loader.
        I_max: forwarded positionally to the loader.
        neuron_label: class index per read-out neuron (e.g. the output of
            ``label_int``); values are truncated to ``int``.
        upto: number of classes.
        func: ``"test"`` reads ``getdata_test``; anything else reads
            ``getdata_train``.

    Returns:
        Accuracy as ``correct / total``.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    # Flat indices of the read-out neurons: odd rows/columns of a grid with
    # row stride 57, giving one neuron per 28x28 pixel position.
    idx = [(i * 2 + 1) * 57 + (j * 2 + 1) for i in range(28) for j in range(28)]

    # One-hot (neuron, class) assignment matrix and per-class neuron counts.
    assigned = np.asarray(neuron_label)[:len(idx)].astype(int)
    onehot = np.eye(upto)[assigned]
    neurons_per_class = onehot.sum(axis=0)

    if func == "test":
        getdata = getdata_test
        num_samples = 2115   # sample count assumed for the test loader
    else:
        getdata = getdata_train
        num_samples = 12665  # training-set count assumed elsewhere in this module

    acc_num = 0
    total_num = 0
    num = int(np.ceil(num_samples / 256))
    with tqdm(total=num) as pbar:
        pbar.set_description('tesing:')
        for X, Y in getdata(path, 256, I_max, upto=upto):
            X = repeat(X, during)
            BNNnet.set_para(mod = False)
            trainer.run(inputs = X, reset_state=True)

            n_batch = X.shape[1]
            # (batch, neuron) spike counts, summed over axis 0 (presumably time).
            spike_sum = np.asarray(np.sum(trainer.mon["n.spike"][:, :, idx], axis=0))[:n_batch]
            # Mean spike count of the neurons assigned to each class. Classes
            # with no neurons get -inf so argmax never picks them. A 0/0 NaN
            # would otherwise win every argmax.
            class_totals = spike_sum @ onehot
            rates = np.divide(class_totals, neurons_per_class,
                              out=np.full_like(class_totals, -np.inf),
                              where=neurons_per_class > 0)
            predictions = rates.argmax(axis=1)
            acc_num += int(np.sum(predictions == np.asarray(Y)[:n_batch]))
            total_num += n_batch
            pbar.update(1)
    return acc_num / total_num
