import numpy as np
import torch


def _orthonormal_images(rng, num_images, image_shape):
    """Draw ``num_images`` Gaussian images and orthonormalize them.

    Runs (modified) Gram-Schmidt over the flattened images, so every returned
    image has unit L2 norm and every pair has ~zero inner product.  For two
    images this reproduces the earlier two-vector procedure step for step.
    Exact orthonormality is only possible when ``num_images`` does not exceed
    the number of elements per image.
    """
    images = rng.normal(size=[num_images] + list(image_shape)).astype(np.float32)
    # Row view over the same contiguous memory: writes to ``flat`` update
    # ``images`` in place, and dot products work for any image_shape rank.
    flat = images.reshape(num_images, -1)
    # Normalize first: keeps the projection arithmetic identical to the
    # original normalize -> project -> renormalize sequence.
    for i in range(num_images):
        flat[i] = flat[i] / np.linalg.norm(flat[i])
    for i in range(num_images):
        for j in range(i):
            flat[i] -= np.dot(flat[j], flat[i]) * flat[j]
        flat[i] = flat[i] / np.linalg.norm(flat[i])
    return images


def generateData(config, images=None, test=False, stim=None):
    """Build one batch of network inputs, one-hot targets and a loss mask.

    Args:
        config: Mapping providing ``'rng'`` (needs ``normal`` and ``randint``,
            e.g. ``np.random.RandomState``), ``'num_rnn_out'``,
            ``'image_shape'``, ``'batch_size'``, ``'fixationInput'``,
            ``'fixationPeriod'``, ``'tdim'`` and ``'dt'``.
        images: Array of shape ``(num_rnn_out - 1, *image_shape)``, one image
            per stimulus class.  If ``None``, new orthonormal images are drawn
            from ``config['rng']`` (not allowed in test mode).
        test: If True, build a single trial for stimulus ``stim``; otherwise
            draw ``config['batch_size']`` stimulus indices at random.
        stim: Stimulus index to use in test mode.

    Returns:
        ``(imgStim, labels, label_mask, images, stims)``:
        ``imgStim`` is a ``(datasetSize, pixels + 1)`` tensor (flattened image
        followed by the fixation input); ``labels`` is a float64 tensor of
        shape ``(tdim, datasetSize, num_rnn_out)``; ``label_mask`` is a flat
        float64 tensor of length ``tdim * datasetSize``; ``images`` are the
        stimulus images used; ``stims`` holds the stimulus index per trial.

    Raises:
        ValueError: in test mode when ``images`` or ``stim`` is missing.
    """
    num_out = config['num_rnn_out']
    fixation_channel = num_out - 1  # last output unit is the fixation target
    num_stimuli = num_out - 1

    if images is None:
        if test:
            raise ValueError('No images provided during testing')
        images = _orthonormal_images(config['rng'], num_stimuli, config['image_shape'])

    if test:
        if stim is None:
            raise ValueError('No stimulus index provided during testing')
        stims = [stim]
        datasetSize = 1
    else:
        stims = config['rng'].randint(num_stimuli, size=[config['batch_size']])
        datasetSize = config['batch_size']

    # Inputs: each trial's image (flattened) plus a constant fixation input.
    stim_images = images[stims].reshape(datasetSize, -1)
    fixation = np.full((datasetSize, 1), config['fixationInput'], dtype=np.float32)
    imgStim = torch.from_numpy(np.concatenate((stim_images, fixation), axis=1))

    # Targets are built on a flat, time-major grid: row t * datasetSize + b is
    # (time step t, trial b), matching the reshape to (tdim, datasetSize, ...).
    total_rows = datasetSize * config['tdim']
    # presumably fixationPeriod[1] is the last fixation time step -- confirm.
    fixationOffset = int(datasetSize * config['fixationPeriod'][1])
    labels = np.zeros((total_rows, num_out))
    labels[:fixationOffset, fixation_channel] = 1.0
    for trial, trial_stim in enumerate(stims):
        # Every post-fixation time step of this trial targets its stimulus.
        labels[np.arange(fixationOffset + trial, total_rows, datasetSize), trial_stim] = 1.0
    labels = labels.reshape(config['tdim'], datasetSize, num_out)

    # Ignore the loss for int(100 / dt) steps right after fixation ends
    # (looks like a 100-unit response grace period in dt's units -- confirm).
    # The slice clips at the end of the trial instead of raising IndexError.
    grace_steps = int(100 / config['dt'])
    label_mask = np.ones(total_rows)
    label_mask[fixationOffset:fixationOffset + datasetSize * grace_steps] = 0.0

    labels = torch.from_numpy(labels)
    label_mask = torch.from_numpy(label_mask)
    return imgStim, labels, label_mask, images, stims
