"""Utilities for datasets."""

from typing import Dict, List, Tuple

import jax

def get_batch_dims(global_batch_size: int) -> List[int]:
  """Gets the first two axis sizes for data batches.

  The global batch is split evenly across processes (hosts), and each
  process's share is then split evenly across its local devices.

  Args:
    global_batch_size: Integer, the global batch size (across all devices).
      Must be positive.

  Returns:
    List of batch dimensions ``[num_local_devices, per_device_batch_size]``.

  Raises:
    ValueError: if ``global_batch_size`` is not positive, or if it cannot be
      split evenly across the processes and then across the local devices.
  """
  # Query the topology once; process_count() supersedes the deprecated
  # host_count() alias.
  num_hosts = jax.process_count()
  num_local_devices = jax.local_device_count()
  if global_batch_size <= 0:
    # Zero or negative sizes would otherwise yield empty/negative dims.
    raise ValueError(
        f"Global batch size must be positive, got {global_batch_size}.")
  if global_batch_size % num_hosts != 0:
    raise ValueError(f"Global batch size {global_batch_size} not evenly "
                     f"divisible with {num_hosts} hosts.")
  per_host_batch_size = global_batch_size // num_hosts
  if per_host_batch_size % num_local_devices != 0:
    raise ValueError(f"Global batch size {global_batch_size} not evenly "
                     f"divisible with {num_hosts} hosts with a per host "
                     f"batch size of {per_host_batch_size} and "
                     f"{num_local_devices} local devices.")
  return [num_local_devices, per_host_batch_size // num_local_devices]
