"""encoded_ae_h5py_dataset dataset."""

import tensorflow as tf
import tensorflow_datasets as tfds
import h5py
import numpy as np

# Markdown description of the dataset.
# NOTE(review): the text below reads like unedited template placeholder text,
# and nothing in the code visible in this file references `_DESCRIPTION`
# (`_info` does not pass a description) -- confirm whether it should be filled
# in and wired up.
_DESCRIPTION = """
Description is **formatted** as markdown.

It should also contain any processing which has been applied (if any),
(e.g. corrupted example skipped, images cropped,...):
"""

# Citation text for the dataset. Currently blank, and likewise not referenced
# by the code visible in this file.
_CITATION = """
"""

class EncodedAEH5pyConfig(tfds.core.BuilderConfig):
    """BuilderConfig describing one pre-encoded HDF5 file.

    Attributes:
      path: File name of the HDF5 file. The builder resolves it relative to
        the manual download directory.
      min_length: Minimum number of frames a sequence must have to be
        emitted; shorter sequences are skipped. `None` disables the filter.
      latent_shape: Shape of a single frame's latent. The builder uses it as
        the trailing dimensions of the `video` feature.
    """

    def __init__(self, *, path, min_length, latent_shape, **kwargs):
        """Creates the config.

        Args:
          path: File name of the HDF5 file (see class docstring).
          min_length: Minimum sequence length, or `None` for no filtering.
          latent_shape: Per-frame latent shape, e.g. `(16, 16, 8)`.
          **kwargs: Forwarded to `tfds.core.BuilderConfig` (e.g. `name`,
            `description`). `version` defaults to `1.0.0` when not given.
        """
        # Use a default rather than a hard-coded keyword so that a caller
        # passing `version=` no longer triggers a duplicate-keyword TypeError.
        kwargs.setdefault('version', tfds.core.Version('1.0.0'))
        super().__init__(**kwargs)
        self.path = path
        self.min_length = min_length
        self.latent_shape = latent_shape


class EncodedAEH5pyDataset(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the encoded_ae_h5py_dataset dataset.

  Each example is one variable-length sequence read from a manually provided
  HDF5 file: a `video` tensor of per-frame latents and an `actions` vector
  with one entry per frame.

  For every split name `<split>` the HDF5 file is read through these keys:
    * `<split>_data`: all frames of all sequences, concatenated along axis 0.
    * `<split>_idx`: start offset of each sequence within `<split>_data`.
    * `<split>_actions` (optional): per-frame actions aligned with
      `<split>_data`. Zeros are substituted when the key is absent.
  """

  BUILDER_CONFIGS = [
    EncodedAEH5pyConfig(
      name='dl_maze',
      path='encoded_ae_dl_maze.hdf5',
      min_length=None,
      latent_shape=(16, 16, 8)
    ),
    EncodedAEH5pyConfig(
      name='minerl_marsh_v2',
      path='encoded_ae_minerl_marsh_v2.hdf5',
      min_length=None,
      latent_shape=(16, 16, 8)
    ),
    EncodedAEH5pyConfig(
      name='habitat_l300',
      path='encoded_ae_habitat_l300.hdf5',
      min_length=None,
      latent_shape=(16, 16, 8)
    ),
  ]

  MANUAL_DOWNLOAD_INSTRUCTIONS = """
  Place the `*.hdf5` file in the `manual_dir/`.
  """

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata.

    The `video` feature has a variable-length leading (time) dimension
    followed by the config's `latent_shape`; `actions` is a variable-length
    int32 vector.
    """
    latent_shape = self.builder_config.latent_shape
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({
            'video': tfds.features.Tensor(shape=(None, *latent_shape), dtype=tf.float32),
            'actions': tfds.features.Tensor(shape=(None,), dtype=tf.int32),
        }),
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns a mapping from split name to its example generator.

    Args:
      dl_manager: Download manager; only its `manual_dir` is used, to locate
        the HDF5 file named by the builder config.
    """
    path = dl_manager.manual_dir / self.builder_config.path

    return {
        'train': self._generate_examples(path, 'train'),
        'test': self._generate_examples(path, 'test'),
    }

  def _generate_examples(self, path, split):
    """Yields `(key, example)` pairs for one split.

    Args:
      path: Location of the HDF5 file.
      split: Split name used as the prefix of the HDF5 keys
        (`<split>_data`, `<split>_idx`, `<split>_actions`).

    Yields:
      Tuples `(i, {'video': ..., 'actions': ...})` where `i` is the position
      of the sequence in `<split>_idx`. Sequences shorter than the config's
      `min_length` (when it is set) are skipped.
    """
    min_length = self.builder_config.min_length

    # The `with` block spans the whole generator body, so the file stays open
    # while examples are being consumed and is closed once the generator is
    # exhausted, closed early, or raises.
    with h5py.File(path, 'r') as data:
      images = data[f'{split}_data']
      if f'{split}_actions' in data and 'bair' not in self.builder_config.path:
        actions = data[f'{split}_actions'][:]
      else:
        print('Did not find actions... Generating dummy actions')
        actions = np.zeros((images.shape[0],), dtype=np.int32)
      idxs = data[f'{split}_idx'][:]

      # Sequence i spans [idxs[i], idxs[i + 1]); the last one runs to the end
      # of the frame array. A plain list keeps the original index scalars
      # untouched (no dtype promotion from array concatenation).
      ends = [*idxs[1:], len(images)]

      for i, (start, end) in enumerate(zip(idxs, ends)):
        video = images[start:end]
        action = actions[start:end].astype(np.int32)
        if min_length is not None and video.shape[0] < min_length:
          continue

        yield i, {
          'video': video,
          'actions': action,
        }
