import os
import numpy as np
import h5py

# Directory holding the D4RL-format HDF5 files (machine-specific; edit as needed).
DATA_DIR = r'C:\Users\a50029551\Documents\Code\mbrl_dp_paper\Datasets D4RL'
INPUT_NAME = 'walker_mixed.hdf5'
OUTPUT_NAME = 'walker_mixed_padded_low_reward.hdf5'
# Size of the fixed slot each episode occupies in the pad mask.
MAX_EPISODE_LEN = 1000
# Set to True to write the pad mask next to the input as '<input stem>_padmask.npy'.
SAVE_PAD_MASK = False


def episode_lengths(timeouts: np.ndarray, terminals: np.ndarray) -> np.ndarray:
    """Return the length of each episode delimited by timeout/terminal flags.

    An episode ends at every index where ``timeouts`` or ``terminals`` is
    truthy. The first episode starts at index 0. Any steps after the last
    flagged index are not counted as an episode.

    Args:
        timeouts: 1-D array of per-step timeout flags.
        terminals: 1-D array of per-step terminal flags.

    Returns:
        Integer array with one entry per episode.

    Raises:
        ValueError: if no step is flagged as a timeout or terminal.
    """
    # union1d sorts and de-duplicates. A step that is both a timeout and a
    # terminal must end exactly one episode, not produce a 0-length one.
    ends = np.union1d(np.flatnonzero(timeouts), np.flatnonzero(terminals))
    if ends.size == 0:
        raise ValueError("no timeout or terminal flags found; cannot split episodes")
    # prepend=-1 makes the first length ends[0] + 1 (episode starts at index 0).
    return np.diff(ends, prepend=-1)


def build_pad_mask(ep_len: np.ndarray, max_len: int = MAX_EPISODE_LEN) -> np.ndarray:
    """Build a flat boolean padding mask with one ``max_len`` slot per episode.

    Within slot ``i``, the first ``ep_len[i]`` entries are False (real steps)
    and the remaining entries are True (padding).

    Args:
        ep_len: 1-D array of episode lengths.
        max_len: Number of steps reserved per episode.

    Returns:
        Boolean array of shape ``(len(ep_len) * max_len,)``.

    Raises:
        ValueError: if any episode is longer than ``max_len``. Such an episode
            would not fit in its slot and would corrupt its neighbour.
    """
    ep_len = np.asarray(ep_len)
    if np.any(ep_len > max_len):
        raise ValueError(
            f"episode of length {int(ep_len.max())} exceeds max_len={max_len}"
        )
    steps = np.arange(max_len)
    return (steps[None, :] >= ep_len[:, None]).ravel()


if __name__ == '__main__':
    # Read-only: this script never modifies the source dataset.
    with h5py.File(os.path.join(DATA_DIR, INPUT_NAME), 'r') as src:
        timeouts = src['timeouts'][()]
        terminals = src['terminals'][()]
        obs = src['observations'][()]
        actions = src['actions'][()]
        rewards = src['rewards'][()]

    ep_len = episode_lengths(timeouts, terminals)
    pad_mask = build_pad_mask(ep_len)

    if SAVE_PAD_MASK:
        # np.save appends the '.npy' extension.
        mask_name = os.path.splitext(INPUT_NAME)[0] + '_padmask'
        np.save(os.path.join(DATA_DIR, mask_name), pad_mask)

    # NOTE(review): despite its name, the output copies the data unmodified.
    # 'timeouts' and 'next_observations' are not written. Confirm this is intended.
    with h5py.File(os.path.join(DATA_DIR, OUTPUT_NAME), 'w') as dst:
        dst.create_dataset('observations', data=obs)
        dst.create_dataset('actions', data=actions)
        dst.create_dataset('rewards', data=rewards)
        dst.create_dataset('terminals', data=terminals)
