import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import atari_utils
import numpy as np
import tensorflow as tf

# Fetch the raw RL Unplugged shards for a game/run before running, e.g.:
#   gsutil -m cp -r gs://rl_unplugged/atari/{GAME}/run_{RUN}* .

def _print_dataset_stats(data, batch_size=500, log_every=10000):
    """Scan ``data`` once and print mean return, element count and episode count.

    Args:
        data: A ``tf.data.Dataset`` whose elements expose
            ``x.data.extras['return']`` and ``x.data.extras['episode_id']``.
        batch_size: Number of elements pulled per batch while scanning.
        log_every: Print a progress line each time roughly this many more
            elements have been seen.
    """
    batches = data.batch(batch_size)
    batches = batches.prefetch(tf.data.experimental.AUTOTUNE)
    print('Filtering...')
    returns = []
    ids = []
    next_log = 0
    for batch in batches.as_numpy_iterator():
        # Boundary-based progress so the cadence does not depend on
        # batch_size dividing log_every evenly.
        if len(ids) >= next_log:
            print(len(ids))
            next_log = (len(ids) // log_every + 1) * log_every
        extras = batch.data.extras
        returns.extend(extras['return'])
        ids.extend(extras['episode_id'])

    # np.mean([]) warns and returns nan; keep the nan output without the warning.
    mean_return = np.mean(returns) if returns else float('nan')
    print('mean return: ', mean_return)
    print('data length: ', len(returns))
    print('num episodes: ', len(set(ids)))


def generate_atari(env_id, run, verbose, full, data_dir='/datadrive/data',
                   max_episode_id=int(1.9e17), num_samples=500000):
    """Build and save a small subset of one RL Unplugged Atari run.

    Loads the run via ``atari_utils.atari_dataset``, keeps only elements whose
    ``episode_id`` is below ``max_episode_id``, takes the first
    ``num_samples`` of those and writes them with
    ``tf.data.experimental.save`` (GZIP) to
    ``<data_dir>/<env_id>/run_<run>_1percent``.

    Args:
        env_id: Game name, e.g. ``'Qbert'``; also a sub-directory of the output.
        run: Run number to load.
        verbose: If true, make an extra full pass over the filtered data and
            print return / length / episode statistics before saving.
        full: If true load shards 0..99, otherwise only shard 0.
        data_dir: Root directory passed to ``atari_dataset`` and used as the
            output root.
        max_episode_id: Exclusive upper bound on ``episode_id`` for kept
            elements.
        num_samples: Number of filtered elements to save.
    """
    shards = list(range(100)) if full else [0]

    print(env_id, run)

    data = atari_utils.atari_dataset(data_dir, env_id, run, shards,
                                     repeat=False, include_idx=False)

    # Subsample by episode id. Per the original note the ids are hashes, so
    # this threshold picks a deterministic subset of episodes (intended to
    # be ~1% of them -- NOTE(review): confirm the id range).
    data = data.filter(
        lambda x: x.data.extras['episode_id'] < max_episode_id)

    if verbose:
        _print_dataset_stats(data)

    # Default 500,000 transitions -- presumably 1% of a full run; confirm
    # against the dataset size.
    data = data.take(num_samples)
    data = data.prefetch(tf.data.experimental.AUTOTUNE)

    out_path = os.path.join(data_dir, env_id, 'run_' + str(run) + '_1percent')
    print('Saving...')
    tf.data.experimental.save(data, out_path, compression='GZIP')

if __name__ == '__main__':
    # Script configuration: which game and which run(s) to subsample.
    game = 'Qbert'
    runs_to_process = [3]
    # Skip the statistics pass; load every shard of each run.
    show_stats = False
    all_shards = True

    for run_number in runs_to_process:
        generate_atari(game, run_number, show_stats, all_shards)

