import numpy as np
from godl_task2.cfgs.process_cfg import set_cfgs
from godl_task2.utils.set_seed import setup_seed


def img_gen(config):
    """Draw one problem's set of random unit-norm images.

    ``num_rnn_out - 1`` images are sampled from a standard normal
    distribution using ``config['rng']``, cast to float32, and scaled to unit
    L2 norm (taken over all elements of each image). Image 1 is then made
    orthogonal to image 0 with a single Gram-Schmidt step, and every image is
    renormalized. Images with index >= 2 are not orthogonalized.

    Args:
        config: mapping providing
            ``'rng'``: a NumPy random generator exposing ``normal(size=...)``;
            ``'num_rnn_out'``: int, one more than the number of images drawn;
            ``'image_shape'``: list or tuple of ints, the shape of one image.

    Returns:
        float32 array of shape ``(num_rnn_out - 1, *image_shape)``.
    """
    num_images = config['num_rnn_out'] - 1
    images = config['rng'].normal(
        size=(num_images, *config['image_shape'])
    ).astype(np.float32)

    for stim in range(num_images):
        images[stim] = images[stim] / np.linalg.norm(images[stim])

    # Orthogonalize image 1 against image 0. Flatten before the dot so this is
    # the inner product over all pixels: np.dot on 2-D images would be a
    # matrix product (wrong for square images, a shape error otherwise).
    # With fewer than two images there is nothing to orthogonalize.
    if num_images >= 2:
        proj = np.dot(images[0].ravel(), images[1].ravel())
        images[1] -= proj * images[0]
        for stim in range(num_images):
            images[stim] = images[stim] / np.linalg.norm(images[stim])
    return images


if __name__ == '__main__':
    # Load the run configuration and echo it so the log shows what was used.
    cfgs = set_cfgs()
    print("---cfgs---")
    print(cfgs)

    # Seed via the project helper before any images are drawn.
    # NOTE(review): img_gen draws from cfgs['rng']; confirm setup_seed (or
    # set_cfgs) is what makes that generator reproducible.
    setup_seed(cfgs['seed'])

    # One image set per problem, stacked along a new leading axis.
    problem_images = np.stack(
        [img_gen(cfgs) for _ in range(cfgs['problems'])],
        axis=0,
    )
    np.save('all_imgs.npy', problem_images)
