"""Deterministic input pipeline."""

import abc
import dataclasses
from typing import Tuple

from clu import deterministic_data
from clu import preprocess_spec
import jax
import ml_collections
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from ebm_obj.datasets import clevr
from ebm_obj.datasets import clevr_with_masks
from ebm_obj.datasets import tetrominoes_with_masks
from ebm_obj.datasets import dataset_registry


def create_datasets(
  config: ml_collections.ConfigDict, data_rng
) -> Tuple[int, tf.data.Dataset, tf.data.Dataset]:
  """Build the training and evaluation datasets selected by the config.

  The dataset builder is looked up in the dataset registry under the name
  `config.data.tfds_name` and is then called with `config.data` and
  `data_rng`. Its result is returned unchanged.

  Calling this twice with the same `config` and `data_rng` is intended to
  yield the same datasets, since the datasets are meant to contain only
  stateless operations. See go/deterministic-training for how this helps with
  reproducible training.

  Args:
    config: Configuration to use. Only its `data` sub-config is read here.
    data_rng: PRNGKey for seeding operations in the training dataset.

  Returns:
    A tuple with the total number of training batches, the training dataset
    and the evaluation dataset.
  """
  # Resolve the builder registered under the configured dataset name.
  registry = dataset_registry.DatasetRegistry()
  build_fn = registry.lookup(config.data.tfds_name)

  # The builder receives only the data sub-config, not the full config.
  return build_fn(config.data, data_rng)
