import numpy as np
import pickle
from time import time
from scripts.vae import *
import os

# Build per-agent feature matrices: for every split ('train'/'test') and every
# "blind" variant ('color'/'shape'), encode the images with the matching VAE,
# draw one independent latent sample per agent, append the vector features and
# save the result under data/X/<blind>blind/.
#
# NOTE: torch, VAE, load_model, flatten and to_var are expected to come from
# the star import of scripts.vae.

# Number of agents; each agent gets its own random latent draw.
n_agent = 128
for traintest in ['train', 'test']:
    vecX = np.load('data/vecX/vecX_{}.npy'.format(traintest))
    for blind in ['color', 'shape']:
        # np.save does not create missing directories, so make sure the
        # output directory exists before doing any expensive work.
        out_dir = 'data/X/{}blind'.format(blind)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        # Resume support: only agents whose main output file is missing are
        # processed. If nothing is missing, skip loading images and model.
        pending = [
            agent_id for agent_id in range(n_agent)
            if not os.path.exists(
                'data/X/{}blind/X_{}_{}.npy'.format(blind, agent_id, traintest))
        ]
        if not pending:
            continue

        imageX = np.load('data/imageX/imageX_{}blind_{}.npz'.format(blind, traintest))['arr_0']
        imageX = torch.from_numpy(imageX)

        imageX = flatten(imageX)
        imageX = imageX.float()

        vae = VAE()
        # 990 is presumably the checkpoint/epoch index to restore -- TODO confirm.
        load_model(vae, blind, 990)
        vae.zero_grad()
        vae.eval()

        # The encoder output does not depend on the agent, so run it once per
        # (split, blind) pair instead of once per agent.
        with torch.no_grad():
            encoded = vae.encoder(imageX)
            # The encoder output is split in two halves along dim 1:
            # mean and log-variance of the latent distribution.
            mu, logvar = torch.chunk(encoded, 2, dim=1)
            std = logvar.mul(0.5).exp_()

        for agent_id in pending:
            # Reparameterised sample: mu + std * eps, with fresh noise per agent.
            with torch.no_grad():
                esp = to_var(torch.randn(*mu.size()))
                latent = mu + std * esp.cpu()

            # Explicit tensor -> ndarray conversion before mixing with numpy data.
            X = np.concatenate((latent.numpy(), vecX), axis=1)
            np.save('data/X/{}blind/X_{}_{}.npy'.format(blind, agent_id, traintest), X)
            if traintest == 'train':
                # The last 10% of the training rows are additionally written
                # out as a validation file.
                # NOTE(review): the train file above still contains these rows,
                # so train and val overlap -- confirm this is intended.
                n_train = int(len(X) * 0.9)
                np.save('data/X/{}blind/X_{}_val.npy'.format(blind, agent_id), X[n_train:])

            print(traintest, blind, agent_id)