"""Dataset registry."""

from ebm_obj.lib import registry


class DatasetRegistry(registry.Registry):
  """Registry of callables that construct Dataset objects.

  Each instance passes the same class-level dictionary to the base
  `registry.Registry` constructor.
  """

  # One dict object, created once at class-definition time and handed to every
  # instance. NOTE(review): presumably the base Registry uses this dict as its
  # backing store (Borg pattern), so all instances see the same entries;
  # confirm against ebm_obj.lib.registry.
  _shared_state = {}

  def __init__(self):
    # Looked up through `self` so a subclass that overrides `_shared_state`
    # supplies its own dict.
    super().__init__(self._shared_state)


def register(name: str):
  """Create a class decorator that records a Dataset under `name`.

  Args:
    name: Key under which the decorated Dataset class is registered.

  Returns:
    A decorator that registers its argument with `DatasetRegistry` and returns
    that same object unchanged, so the decorated class keeps its binding.

  Example:

    To define and register a Dataset.

    @dataset_registry.register("awesome")
    class AwesomeDataset(Dataset):
      pass
  """

  def _decorate(dataset_cls):
    # A new DatasetRegistry is built on every registration. It is constructed
    # over the class-level `_shared_state` dict, so entries are expected to
    # persist across instances (assuming the base Registry stores them there).
    registry_instance = DatasetRegistry()
    registry_instance.register(name, dataset_cls)
    return dataset_cls

  return _decorate