from collections import defaultdict
from typing import Dict, List
from pathlib import Path
import numpy as np
import h5py


class DatasetWriter:
    """Accumulate per-key samples in memory and dump them to an HDF5 file.

    Samples are stored in ``self.data``, a mapping from key to a list of
    per-sample values.  ``self.keys`` is the list of keys that are recorded;
    when it is ``None`` it is taken from the first dictionary passed to
    :meth:`append_data` / :meth:`extend_data` (or from the first writer
    passed to :meth:`merge`).
    """

    def __init__(self, keys=None):
        """Create an empty writer.

        Args:
            keys: Optional list of keys to record.  ``None`` defers the
                choice to the first batch of data that is added.
        """
        self.keys = keys
        self.data = self._reset_data()
        self._num_samples = 0

    def _reset_data(self):
        """Zero the sample counter and return a fresh, empty storage mapping.

        The returned mapping is not assigned to ``self.data`` here; the
        caller does that.
        """
        self._num_samples = 0
        return defaultdict(list)

    def __len__(self):
        """Return the number of samples added so far."""
        return self._num_samples

    @staticmethod
    def _dtype_for_key(key):
        """Return the numpy dtype used to store the values of ``key``."""
        if key == "dones":
            return np.bool_
        if key == "infos/goal_id":
            return np.int64
        if "image" in key:
            return np.uint8
        return np.float32

    def append_data(self, np_data_dict: Dict[str, np.ndarray]):
        """Add a single sample.

        Args:
            np_data_dict: Mapping from key to that key's value for the
                sample.  Entries whose key is not in ``self.keys`` are
                ignored.

        Raises:
            KeyError: If a key in ``self.keys`` is missing from
                ``np_data_dict``.  Nothing is stored in that case.
        """
        if self.keys is None:
            self.keys = list(np_data_dict.keys())
        # Resolve every value before mutating anything so that a missing key
        # cannot leave the counter and the per-key lists out of sync.
        values = [np_data_dict[key] for key in self.keys]
        for key, value in zip(self.keys, values):
            self.data[key].append(value)
        self._num_samples += 1

    def extend_data(self, np_data_dict: Dict[str, List[np.ndarray]]):
        """Add several samples at once.

        Args:
            np_data_dict: Mapping from key to a sequence of per-sample
                values.  The number of samples added is taken from the
                sequence stored under the first key of ``self.keys``.

        Raises:
            KeyError: If a key in ``self.keys`` is missing from
                ``np_data_dict``.  Nothing is stored in that case.
        """
        if self.keys is None:
            self.keys = list(np_data_dict.keys())
        # Same two-phase approach as append_data: look everything up first.
        columns = [np_data_dict[key] for key in self.keys]
        added = len(columns[0])
        for key, column in zip(self.keys, columns):
            self.data[key].extend(column)
        self._num_samples += added

    def merge(self, writer):
        """Append all samples held by another writer to this one.

        Args:
            writer: Another ``DatasetWriter``.  It is not modified.
        """
        if self.keys is None:
            if writer.keys is None:
                # Neither writer has a key list yet, so there is no data to
                # copy; only the sample count is carried over.
                self._num_samples += len(writer)
                return
            # A fresh writer adopts the key list of the writer it merges.
            self.keys = list(writer.keys)
        for key in self.keys:
            # .get avoids inserting empty entries into the other writer's
            # defaultdict for keys it never recorded.
            self.data[key].extend(writer.data.get(key, ()))
        self._num_samples += len(writer)

    def write_dataset(self, path, max_size=None, compression="gzip"):
        """Write the accumulated samples to an HDF5 file at ``path``.

        One HDF5 dataset is created per key in ``self.data``.  The dtype is
        chosen by :meth:`_dtype_for_key`.  Missing parent directories of
        ``path`` are created, and an existing file is overwritten.

        Args:
            path: Destination file path.
            max_size: If not ``None``, only the first ``max_size`` samples
                of every key are written.
            compression: Compression filter forwarded to h5py's
                ``create_dataset``; ``None`` disables compression.
        """
        np_data = {}
        for key, samples in self.data.items():
            if max_size is not None:
                # Truncate before conversion so discarded samples are never
                # turned into an array.
                samples = samples[:max_size]
            np_data[key] = np.array(samples, dtype=self._dtype_for_key(key))

        Path(path).parent.mkdir(parents=True, exist_ok=True)
        # The context manager closes the file even if a dataset fails to
        # be created.
        with h5py.File(path, "w") as dataset:
            for key, array in np_data.items():
                # Compression filters need a chunked layout; empty arrays are
                # written without a filter to stay on the safe side.
                dataset.create_dataset(
                    key,
                    data=array,
                    compression=compression if array.size else None,
                )
