import numpy as np
import torch


def generateData(config, images=None, test=False, stim=None, contexts=None):
    """Build a batch (or a single test trial) of context-dependent stimulus/response data.

    Each trial shows one stimulus image together with a constant fixation input
    and a context flag. During the fixation period the target is the fixation
    output unit (the last of ``config['num_rnn_out']`` units). Afterwards the
    target depends on the context:
    - In context 0 the target is the stimulus index itself.
    - In context 1 the target is the next stimulus index, cyclically.

    Args:
        config: dict-like configuration. Keys read here: ``'rng'`` (a NumPy
            random generator providing ``normal``/``choice``/``randint``),
            ``'num_rnn_out'``, ``'image_shape'``, ``'batch_size'``,
            ``'fixationInput'``, ``'fixationPeriod'``, ``'tdim'``, ``'dt'``.
        images: optional stimulus images, one per stimulus (row ``i`` is
            stimulus ``i``). When None, fresh random unit-norm images are
            drawn. That is not allowed when ``test`` is True.
        test: if True, build a single trial for stimulus ``stim``. Otherwise
            draw ``config['batch_size']`` random stimuli.
        stim: stimulus index used in test mode (required when ``test`` is True).
        contexts: optional per-trial context labels (0 or 1), one per trial.
            When None they are drawn uniformly from {0, 1}.

    Returns:
        imgStim: tensor of shape ``(datasetSize, image_len + 2)``. Each row is
            the image, then the fixation input, then the context flag. Its
            dtype follows the image/fixation dtype (float32 for generated
            images).
        labels: float64 tensor of shape ``(tdim, datasetSize, num_rnn_out)``.
        label_mask: float64 tensor of length ``tdim * datasetSize``. It is laid
            out time-major (row ``t * datasetSize + b``), like ``labels``
            before reshaping.
        images: the stimulus images that were used.
        stims: the stimulus index of each trial.

    Raises:
        Exception: ``test`` is True but no ``images`` were given.
        ValueError: ``test`` is True but ``stim`` is None, or the number of
            ``contexts`` does not match the number of trials.
        NotImplementedError: a context value other than 0 or 1 was seen.
    """
    numOut = config['num_rnn_out']
    numStimuli = numOut - 1      # every output unit except the last one is a stimulus unit
    fixationUnit = numOut - 1    # the last output unit is the fixation unit

    # Draw new sample images at random and normalize them.
    if images is None:
        if test:
            raise Exception('No images provided during testing')
        images = config['rng'].normal(size=[numStimuli] + list(config['image_shape'])).astype(np.float32)
        # Use a separate loop variable so the ``stim`` argument is never clobbered.
        for idx in range(numStimuli):
            images[idx, :] = images[idx, :] / np.linalg.norm(images[idx, :])
        # One Gram-Schmidt step: make image 1 orthogonal to image 0.
        # Only the first two images are orthogonalized; any further images
        # are merely normalized.
        proj = np.dot(images[0, :], images[1, :])
        images[1, :] -= proj * images[0, :]
        for idx in range(numStimuli):
            images[idx, :] = images[idx, :] / np.linalg.norm(images[idx, :])

    if test:
        if stim is None:
            raise ValueError('stim must be provided when test=True')
        datasetSize = 1
    else:
        datasetSize = config['batch_size']

    # Contexts are drawn before the stimuli to keep the original RNG call order.
    # One context per trial, so test mode draws a single context.
    if contexts is None:
        contexts = config['rng'].choice([0, 1], size=datasetSize)
    contexts = np.asarray(contexts).reshape(-1)
    if contexts.shape[0] != datasetSize:
        raise ValueError('expected %d contexts, got %d' % (datasetSize, contexts.shape[0]))

    if test:
        stims = [stim]
    else:
        stims = config['rng'].randint(numStimuli, size=[datasetSize])

    # Input rows: [image | fixation input | context]. images[stims, :] must be
    # 2-D for the axis-1 concatenation, i.e. image_shape is expected to be 1-D.
    imgStim = np.concatenate((images[stims, :],
                              np.float32(config['fixationInput'] * np.ones([datasetSize, 1]))),
                             axis=1)
    # Context as a (datasetSize, 1) column cast to the input dtype, so integer
    # contexts do not silently upcast the network input to float64.
    contextColumn = contexts.reshape(datasetSize, 1).astype(imgStim.dtype)
    imgStim = torch.from_numpy(np.concatenate((imgStim, contextColumn), axis=1))

    # Targets are built on a flat, time-major grid: row t * datasetSize + b
    # holds time step t of trial b (see the reshape below).
    totalRows = datasetSize * config['tdim']
    fixationOffset = int(datasetSize * config['fixationPeriod'][1])

    labels = np.zeros((totalRows, numOut))
    labels[0:fixationOffset, fixationUnit] = 1.0  # Fixation
    for stCnt, (trialStim, context) in enumerate(zip(stims, contexts)):
        if context == 0:
            target = trialStim
        elif context == 1:
            # Context 1 maps each stimulus to the next one, cyclically
            # (a reversal when there are exactly two stimuli).
            target = (trialStim + 1) % numStimuli
        else:
            raise NotImplementedError('More than two contexts need specific implementation')
        labels[np.arange(fixationOffset + stCnt, totalRows, datasetSize), target] = 1.0
    labels = labels.reshape(config['tdim'], datasetSize, numOut)
    labels = torch.from_numpy(labels)

    # Mask out a grace window right after fixation ends. The window lasts
    # int(100 / dt) time steps; "100" is presumably in the same time units as
    # dt (TODO confirm). The slice clamps at the end of the trial instead of
    # raising IndexError when the window runs past it.
    graceSteps = int(100 / config['dt'])
    label_mask = np.ones([totalRows])
    label_mask[fixationOffset:fixationOffset + datasetSize * graceSteps] = 0.0
    label_mask = torch.from_numpy(label_mask)

    return imgStim, labels, label_mask, images, stims
