import h5py
import numpy as np
import pathlib
from tqdm import tqdm
from collections import defaultdict
import os
import pickle
import multiprocessing as mp

def load_file(filename):
  """Read one ``.npz`` episode archive into a plain dict.

  Args:
    filename: ``pathlib.Path`` pointing at a ``.npz`` archive.

  Returns:
    dict mapping each array name stored in the archive to its array. All
    arrays are read while the file handle is still open, so the returned
    dict is usable after this function returns.
  """
  with filename.open('rb') as handle:
    archive = np.load(handle)
    arrays = dict(archive.items())
  return arrays

def load_episodes(directory, capacity=None, minlen=1):
  """Load ``.npz`` episodes from ``directory`` and stack them per array key.

  Files matching ``*.npz`` are sorted by filename (lexicographically). That
  order is treated as chronological, earliest first, which presumably holds
  because filenames start with a sortable timestamp -- TODO confirm.

  Args:
    directory: ``pathlib.Path`` of the folder containing the episode files.
    capacity: If truthy, keep only the most recent episodes, walking back from
      the latest until their summed lengths reach ``capacity`` steps. Each
      episode's length is parsed from its filename, which must end in
      ``-<length>.npz``.
    minlen: Currently unused; kept for interface compatibility.

  Returns:
    ``(episodes, filenames)``. ``episodes`` is a ``defaultdict`` mapping every
    array key found in the files to ``np.asarray`` of the per-episode arrays,
    in ``filenames`` order. ``filenames`` is the sorted list of paths loaded.
  """
  filenames = sorted(directory.glob('*.npz'))  # earliest first.
  if capacity:
    num_steps = 0
    num_episodes = 0
    for filename in reversed(filenames):  # walk back from the latest episode.
      # Name ends in "-<length>.npz"; parse the stem rather than the full
      # path so the directory part never takes part in the split.
      num_steps += int(filename.stem.split('-')[-1])
      num_episodes += 1
      if num_steps >= capacity:
        break
    filenames = filenames[-num_episodes:] if num_episodes else []

  episodes = defaultdict(list)
  if filenames:  # don't spin up worker processes when there is nothing to do.
    with mp.Pool() as pool:
      # imap (unlike map) yields results as they finish, still in input
      # order, so the progress bar advances while files are being loaded.
      for ep_dict in tqdm(pool.imap(load_file, filenames), 'load_episodes',
                          total=len(filenames)):
        for k, v in ep_dict.items():
          episodes[k].append(v)

  for k, v in episodes.items():
    # NOTE(review): stacking assumes every episode's array for key k has the
    # same shape; ragged episodes cannot form a regular array -- confirm.
    episodes[k] = np.asarray(v)
  return episodes, filenames

DIR_PATH = "/home/anonymous/logdir/stackwalls_p2e_s0/train_episodes"
# Expand "~" here too, matching the expanduser() applied to the input dir.
FILE_NAME_PATH = os.path.expanduser(os.path.join(DIR_PATH, "filenames.pkl"))
HDF5_PATH = os.path.expanduser(os.path.join(DIR_PATH, "data.hdf5"))

# load_episodes starts a multiprocessing.Pool. Under the "spawn" start method
# each worker re-imports this module, so the conversion must only run when
# the file is executed as a script, not on import.
if __name__ == "__main__":
  directory = pathlib.Path(DIR_PATH).expanduser()
  data, filenames = load_episodes(directory, capacity=1e6)

  # Record which episode files (in load order) went into the HDF5 file.
  file_strings = [str(f) for f in filenames]
  with open(FILE_NAME_PATH, "wb") as f:
    pickle.dump(file_strings, f)

  # Context manager guarantees the HDF5 file is flushed and closed.
  with h5py.File(HDF5_PATH, 'w') as dataset:
    for k in data:
      dataset.create_dataset(k, data=data[k], compression='gzip')