import numpy as np
import torch


def generateData(config, images=None, test=False, stim=None):
    # Draw new sample images at random and orthonormalize
    if images is None:
        if test:
            raise Exception('No images provided during testing')
        images = config['rng'].normal(size=[config['data_type']] + config['image_shape']).astype(np.float32)

        for stim in range(config['data_type']):
            images[stim, :] = images[stim, :] / np.linalg.norm(images[stim, :])
        proj = np.dot(images[0, :], images[1, :])
        images[1, :] -= proj * images[0, :]

        for stim in range(config['data_type']):
            images[stim, :] = images[stim, :] / np.linalg.norm(images[stim, :])

    # Create input (x) and target (y_rnn) and output temporal mask (y_rnn_mask) matrices
    if test:
        stims = [stim]
        datasetSize = 1
    else:
        stims = config['rng'].randint(config['data_type'], size=[config['batch_size']])
        datasetSize = config['batch_size']

    # gen input data
    imgStim = np.concatenate((images[stims, :],
                              np.float32(config['fixationInput'] * np.ones([datasetSize, 1]))),
                             axis=1)
    imgStim = torch.from_numpy(imgStim)

    # gen output label
    fixationOffset = int(datasetSize * config['fixationPeriod'][1])
    labels = np.zeros((datasetSize * config['tdim'], config['num_rnn_out']))
    labels[0:fixationOffset, config['num_rnn_out'] - 1] = 1.0  # Fixation
    for stCnt, stim in enumerate(stims):
        if stim == 1:
            labels[np.arange(fixationOffset + stCnt, datasetSize * config['tdim'], datasetSize),
            :(config['num_rnn_out'] - 1)] = np.tile(images[stim], (
            np.arange(fixationOffset + stCnt, datasetSize * config['tdim'], datasetSize).shape[0], 1))
    labels = labels.reshape(config['tdim'], datasetSize, config['num_rnn_out'])
    labels = torch.from_numpy(labels)

    label_mask = np.ones([datasetSize * config['tdim']])
    label_mask[0:fixationOffset] = 1.0
    label_mask[fixationOffset:] = 1.0
    label_mask[np.arange(fixationOffset, fixationOffset + datasetSize * int(100 / config['dt']))] = 0.0
    label_mask = torch.from_numpy(label_mask)
    return imgStim, labels, label_mask, images, stims
