
import brainpy_datasets as bd
import brainpy as bp
import brainpy.math as bm
import numpy as np


def _sparse_grid_indices(side=28):
    """Return flat destination indices for embedding a ``side x side`` image.

    Source pixel ``(i, j)`` is placed at ``(2*i + 1, 2*j + 1)`` of a
    ``(2*side + 1) x (2*side + 1)`` grid. Every original pixel is therefore
    separated from its neighbours by one all-zero row and one all-zero column.

    Args:
        side: Edge length of the square source image.

    Returns:
        A 1-D integer numpy array of length ``side * side``. Entry ``k`` is the
        row-major flat index in the destination grid that receives row-major
        source pixel ``k``.
    """
    dst_side = 2 * side + 1
    # indexing='ij' + C-order ravel -> rows vary slowest, i.e. source pixel
    # (i, j) ends up at position i*side + j, matching row-major flattening.
    rows, cols = np.meshgrid(np.arange(side), np.arange(side), indexing='ij')
    return ((2 * rows + 1) * dst_side + (2 * cols + 1)).ravel()


def _load_mnist_subset(path, split, upto):
    """Load one MNIST split, flatten the images and keep labels ``< upto``.

    Args:
        path: Dataset root passed to ``bd.vision.MNIST`` (called with
            ``download=True``).
        split: Split name forwarded to ``bd.vision.MNIST`` ('train' / 'test').
        upto: Only samples whose label is strictly less than this are kept.

    Returns:
        ``(x, y)``. ``x`` is a ``bm`` float array of shape ``(n, 784)`` holding
        the pixels divided by 255. ``y`` holds the matching ``bm`` int labels.
    """
    data = bd.vision.MNIST(path, split=split, download=True)
    # Dividing by 255 assumes 8-bit pixel data, giving values in [0, 1].
    x = bm.asarray(data.data / 255, dtype=bm.float_).reshape(-1, 28 * 28)
    y = bm.asarray(data.targets, dtype=bm.int_)
    keep = bm.where(y < upto)
    return x[keep], y[keep]


def _embed_batches(x, y, batch_size, I_max, side=28):
    """Yield ``(posX, Y)`` mini-batches with images embedded in the sparse grid.

    Args:
        x: 2-D array of row-major flattened images, shape ``(n, side * side)``.
        y: Labels aligned with ``x`` along the first axis.
        batch_size: Positive number of samples per batch. The final batch may
            be smaller.
        I_max: Scale factor applied to every pixel value.
        side: Edge length of the square source images.

    Yields:
        ``(posX, Y)``. ``posX`` is a float64 numpy array of shape
        ``(b, (2*side + 1)**2)``. It is zero everywhere except the
        odd-row/odd-column positions, which hold ``x * I_max``. ``Y`` is the
        matching slice of ``y``.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error is raised on the first ``next()``.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    dst_idx = _sparse_grid_indices(side)
    dst_size = (2 * side + 1) ** 2
    for start in range(0, x.shape[0], batch_size):
        X = x[start:start + batch_size]
        Y = y[start:start + batch_size]
        posX = np.zeros((X.shape[0], dst_size))
        # Row-major flattening already puts source pixel (i, j) in column
        # i*side + j, the same order as dst_idx, so no gather is needed.
        posX[:, dst_idx] = X * I_max
        yield posX, Y


def getdata_train(path, batch_size, I_max=128, upto=2):
    """Yield training mini-batches of MNIST digits whose label is ``< upto``.

    Each 28x28 image is scaled by ``I_max`` and written into the odd
    rows/columns of an all-zero 57x57 grid, then flattened to 3249 values.
    The split is loaded via ``bd.vision.MNIST(..., download=True)`` on the
    first ``next()`` call, not when the generator is created.

    Args:
        path: Dataset root directory.
        batch_size: Positive number of samples per batch. The last batch may
            be smaller.
        I_max: Scale factor applied to the normalised pixel values.
        upto: Keep only samples whose label is strictly less than this.

    Yields:
        ``(posX, Y)``. ``posX`` is a float64 numpy array of shape
        ``(b, 3249)``. ``Y`` holds the matching ``bm`` int labels.
    """
    x_train, y_train = _load_mnist_subset(path, 'train', upto)
    yield from _embed_batches(x_train, y_train, batch_size, I_max)


def getdata_test(path, batch_size, I_max = 128, upto=2):
    """Yield test-split mini-batches of MNIST digits whose label is ``< upto``.

    Same layout as :func:`getdata_train`: each 28x28 image is scaled by
    ``I_max`` and placed at the odd rows/columns of an all-zero 57x57 grid,
    flattened to 3249 values. Data is loaded on the first ``next()`` call.

    Args:
        path: Dataset root directory.
        batch_size: Positive number of samples per batch. The last batch may
            be smaller.
        I_max: Scale factor applied to the normalised pixel values.
        upto: Keep only samples whose label is strictly less than this.

    Yields:
        ``(posX, Y)``. ``posX`` is a float64 numpy array of shape
        ``(b, 3249)``. ``Y`` holds the matching ``bm`` int labels.
    """
    x_test, y_test = _load_mnist_subset(path, 'test', upto)
    yield from _embed_batches(x_test, y_test, batch_size, I_max)