import argparse
import os
import pathlib

import h5py
import numpy as np
from tqdm import tqdm

from bpref_v2.data.qlearning_dataset import qlearning_factorworld_dataset


def load_episodes(directory, h5_file, capacity=int(5e5)):
    """Append every ``*.npz`` episode found in ``directory`` to ``h5_file``.

    Episode files are visited in sorted filename order and their arrays are
    written back to back along the first axis of one dataset per key. The
    datasets are created from the keys of the first episode with ``capacity``
    rows, and every dataset in ``h5_file`` is shrunk to the number of rows
    actually written once all episodes have been copied.

    Args:
        directory: Path-like object providing ``glob`` (e.g. ``pathlib.Path``)
            that contains the ``*.npz`` episode files.
        h5_file: Open, writable h5py file (or group) that receives the data.
        capacity: Number of rows preallocated for each dataset.

    Raises:
        RuntimeError: If an episode file cannot be read or lacks the
            ``actions`` entry (or ``rewards`` when ``terminals`` is missing).
    """
    # Sorted filename order fixes the order of the rows in the output.
    filenames = sorted(directory.glob("*.npz"))
    total_timestep = 0
    for i, filename in tqdm(enumerate(filenames), total=len(filenames)):
        try:
            with filename.open("rb") as f:
                # NOTE(security): allow_pickle=True can execute arbitrary code
                # while loading; only use this on trusted episode files.
                episode = np.load(f, allow_pickle=True)
                # Materialize all arrays while the file handle is still open.
                episode = {k: episode[k] for k in episode.keys()}
                len_episode = np.asarray(episode["actions"]).shape[0]
            if "terminals" not in episode:
                # NOTE(review): steps with zero reward are marked terminal;
                # confirm this matches the reward convention of the data.
                episode["terminals"] = np.asarray(episode["rewards"] == 0.0, dtype=np.int32)
        except Exception as e:
            # Raising a plain string is itself a TypeError and would mask the
            # real failure, so wrap the cause in a proper exception instead.
            raise RuntimeError(f"Could not load episode {filename}: {e}") from e
        if i == 0:
            # The first episode defines the per-step shape and dtype of each key.
            for key, val in episode.items():
                try:
                    h5_file.create_dataset(
                        key, (capacity, *val[0].shape), dtype=val[0].dtype, chunks=(16, *val[0].shape)
                    )
                except Exception as e:
                    print(f"{key} already exists: {e}")
                    continue
        for key, val in episode.items():
            h5_file[key][total_timestep : total_timestep + len_episode] = val
        total_timestep += len_episode

    # Drop the unused tail of the preallocated rows.
    for key in h5_file.keys():
        v_shape = h5_file[key].shape[1:]
        h5_file[key].resize((total_timestep, *v_shape))


def ds_to_hdf5(ds: dict, h5_file):
    """Write every array of ``ds`` to ``h5_file`` as a chunked dataset.

    Args:
        ds: Mapping from dataset name to an array-like whose first axis
            indexes rows. It must contain a ``"rewards"`` entry, whose length
            determines the number of rows of every dataset.
        h5_file: Open, writable h5py file (or group) that receives the data.
    """
    data_length = ds["rewards"].shape[0]
    # h5py rejects a chunk shape larger than the dataset shape, so short
    # datasets need a chunk length clamped to the number of rows.
    chunk_length = min(16, data_length)
    for key, val in ds.items():
        # Take the per-row shape from the array shape rather than val[0], so
        # an empty dataset does not fail with an IndexError.
        item_shape = np.asarray(val).shape[1:]
        # A chunk length of zero is not valid; let h5py pick the layout then.
        chunks = (chunk_length, *item_shape) if chunk_length else None
        h5_file.create_dataset(key, (data_length, *item_shape), data=val, chunks=chunks)


def main():
    """Convert the dataset found in ``--input_dir`` to ``data.hdf5``.

    The output file is written into the input directory itself and replaces
    any ``data.hdf5`` left over from a previous run. ``--split`` is accepted
    on the command line but is not used by the conversion.
    """
    parser = argparse.ArgumentParser(description="Convert npz files to hdf5.")
    parser.add_argument("--split", type=str, choices=["train", "val"], default="train")
    parser.add_argument("--input_dir", type=str, required=True, help="Path to input files")
    args = parser.parse_args()

    out_dir = pathlib.Path(args.input_dir).expanduser()
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "data.hdf5"

    # Start from a clean file so stale datasets never leak into the output.
    try:
        os.remove(out_path)
    except FileNotFoundError:
        # Nothing to clean up on the first run; this is not an error.
        pass
    except OSError as e:
        print(f"error occurred. : {e}")

    # The context manager flushes and closes the file even if conversion fails.
    with h5py.File(out_path, "a") as shard_file:
        # Alternative input path, reading raw episode files instead:
        # load_episodes(pathlib.Path(args.input_dir), shard_file)
        ds = qlearning_factorworld_dataset(args.input_dir)
        ds_to_hdf5(ds, shard_file)


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
