import os
import tensorflow as tf
import itertools


def subset_data(path, size, new_path):
    """Save the first ``size`` elements of a stored dataset to a new location.

    The dataset at ``path`` is read with ``tf.data.experimental.load``; its
    first ``size`` elements are written to ``new_path`` with
    ``tf.data.experimental.save``.  If the dataset holds fewer than ``size``
    elements, a message is printed and nothing is written.

    Args:
        path: Location of the previously saved dataset.
        size: Number of leading elements to keep.
        new_path: Destination for the truncated dataset.

    Returns:
        None.
    """
    d = tf.data.experimental.load(path)
    # NOTE(review): len() needs the loaded dataset to report a known, finite
    # cardinality -- confirm this holds for the TensorFlow version in use.
    available = len(d)

    if available < size:
        print('too small: ', available)
        return

    # A dataset with exactly ``size`` elements satisfies the request too, so
    # it is written out rather than being reported as too small.
    tf.data.experimental.save(d.take(size), new_path)


if __name__ == '__main__':

    # Every combination of the values listed here is processed once.
    grid = {
        "env_id": ['cartpole'],  # other value previously listed: 'catch'
        "env_noise": [0.0],
        "type": ['uni'],  # other values previously listed: 'med', 'exp'
        "seed": [0, 5, 10, 15],
        "data_size": [100000],  # previously listed: 2000, 5000, 20000, 50000
    }

    # Cartesian product of the grid, one dict per configuration.
    grid_setups = [
        dict(zip(grid.keys(), values))
        for values in itertools.product(*grid.values())
    ]

    for g in grid_setups:
        print('------------------')
        print(g)
        print('------------------')

        # Source dataset: /datadrive/data/<env_id>_<env_noise>/<type>_<seed>
        path = '/datadrive/data/{}_{}/{}_{}'.format(
            g['env_id'], g['env_noise'], g['type'], g['seed']
        )
        # Destination: the source path plus a "_<size in thousands>k" suffix.
        new_path = '{}_{}k'.format(path, g['data_size'] // 1000)

        subset_data(path, g['data_size'], new_path)