import numpy as np
import pickle

# Maps each color name to its column in the one-hot color label; the order
# of the names below fixes the column order (red=0 ... yellow=5).
color_index = {
    name: column
    for column, name in enumerate(
        ('red', 'green', 'blue', 'cyan', 'magenta', 'yellow')
    )
}

# Number of independent random draws that are merged into every generated
# sample (images and label vectors alike).
n_object = 10

def one_hot(indices, n_classes):
    """Return a ``(len(indices), n_classes)`` float array with a single 1 per row.

    Row ``i`` has its 1 in column ``indices[i]``; every other entry is 0.
    """
    indices = np.asarray(indices, dtype=np.intp)
    encoded = np.zeros((len(indices), n_classes))
    encoded[np.arange(len(indices)), indices] = 1
    return encoded


def load_split(traintest, colors_to_index, shape_ids=range(3, 8)):
    """Load every per-shape file of one split and build matching label rows.

    For each ``n`` in ``shape_ids`` this reads ``data/shapes/{n}_{split}.npy``
    (the images) and ``data/shapes/colors_{n}_{split}.pkl`` (one color name
    per image).

    Args:
        traintest: Split name used in the file names ('train' or 'test').
        colors_to_index: Mapping from color name to one-hot column.
        shape_ids: The ``n`` values to load; the position of ``n`` in this
            sequence is its one-hot column in the shape part of the label.

    Returns:
        ``(images, labels)`` where ``images`` is all per-shape arrays
        concatenated along axis 0 and ``labels`` has one row per image laid
        out as ``[color one-hot | shape one-hot]``.

    Raises:
        ValueError: If a color file does not hold exactly one entry per image.
        KeyError: If a color name is missing from ``colors_to_index``.
    """
    shape_ids = list(shape_ids)
    images = []
    labels = []

    for column, n in enumerate(shape_ids):
        renders = np.load('data/shapes/{}_{}.npy'.format(n, traintest))

        # NOTE: pickle executes arbitrary code on load; only use this with
        # color files produced by the project's own pipeline.
        with open('data/shapes/colors_{}_{}.pkl'.format(n, traintest), 'rb') as handle:
            colors = pickle.load(handle)

        if len(colors) != len(renders):
            raise ValueError(
                'shape {} ({}): {} images but {} colors'.format(
                    n, traintest, len(renders), len(colors)))

        colorvec = one_hot([colors_to_index[name] for name in colors],
                           len(colors_to_index))
        shapevec = one_hot(np.full(len(renders), column), len(shape_ids))

        images.append(renders)
        labels.append(np.concatenate((colorvec, shapevec), axis=1))

    # Concatenating (instead of np.array + reshape) gives the same result when
    # all shapes have equally many images and still works when they do not.
    return np.concatenate(images, axis=0), np.concatenate(labels, axis=0)


def sample_scenes(variants, labels, n_objects, n_scenes):
    """Build multi-object scenes by placing random draws side by side.

    ``n_objects`` times, ``n_scenes`` row indices are drawn uniformly with
    replacement via ``np.random.randint`` (the global NumPy RNG is not
    seeded here). The same indices are applied to every array in
    ``variants`` and to ``labels`` so they stay aligned.

    Args:
        variants: Sequence of arrays shaped ``(N, H, W, C)`` that share
            axis 0 (different renderings of the same N objects).
        labels: Array shaped ``(N, L)`` with one label row per object.
        n_objects: Number of objects per scene.
        n_scenes: Number of scenes to generate.

    Returns:
        ``(scenes, scene_labels)``. ``scenes`` is a list with one array per
        variant, each shaped ``(n_scenes, H, n_objects * W, C)``; object ``k``
        occupies columns ``k*W:(k+1)*W``. ``scene_labels`` is shaped
        ``(n_scenes, n_objects * L)`` with object ``k`` in columns
        ``k*L:(k+1)*L``.
    """
    drawn = [[] for _ in variants]
    drawn_labels = []

    for _ in range(n_objects):
        idx = np.random.randint(len(labels), size=n_scenes)
        for bucket, variant in zip(drawn, variants):
            bucket.append(variant[idx])
        drawn_labels.append(labels[idx])

    # Join along the width / feature axis. Reshaping a stacked
    # (n_objects, n_scenes, H, W, C) array directly to
    # (n_scenes, H, n_objects * W, C) only reinterprets the buffer and
    # scrambles the pixels instead of tiling the objects horizontally.
    scenes = [np.concatenate(bucket, axis=2) for bucket in drawn]
    scene_labels = np.concatenate(drawn_labels, axis=1)
    return scenes, scene_labels


def main():
    """Generate and save the multi-object image and label files for both splits."""
    for traintest in ['train', 'test']:
        dataset_size = 128000 if traintest == 'train' else 12800

        data, labelvecs = load_split(traintest, color_index)

        # Keep only the first three channels of the last axis
        # (presumably dropping an alpha channel -- TODO confirm).
        data = data[..., :3]

        # Axis 1 holds three renderings of every object: index 0 is the
        # regular image, 1 feeds the "colorblind" output and 2 the
        # "shapeblind" output.
        variants = [data[:, 0], data[:, 1], data[:, 2]]

        (X, X_colorblind, X_shapeblind), data_labelvecs = sample_scenes(
            variants, labelvecs, n_object, dataset_size)

        # Positional savez stores each array under the default key 'arr_0'.
        np.savez('data/imageX/imageX_{}.npz'.format(traintest), X.astype(np.float16)/255.)
        np.savez('data/imageX/imageX_colorblind_{}.npz'.format(traintest), X_colorblind.astype(np.float16)/255.)
        np.savez('data/imageX/imageX_shapeblind_{}.npz'.format(traintest), X_shapeblind.astype(np.float16)/255.)
        np.save('data/labelvecs/labelvecs_{}.npy'.format(traintest), data_labelvecs)


if __name__ == '__main__':
    main()