from typing import Iterator
from jax import numpy as jnp
from jax._src.random import PRNGKey
from omegaconf import DictConfig
from typing import Tuple

from fair_dp_sgd.data.tabular import tabular_data_stream
from fair_dp_sgd.data.mnist import mnist_data_stream
from fair_dp_sgd.data.celeba import celeba_data_stream

# One element of a data stream: three JAX arrays.
# NOTE(review): the meaning of each of the three arrays is not visible in
# this module -- confirm against the individual dataset loaders.
_ArrayTriple = Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

# Iterator that yields array triples.
DataStream = Iterator[_ArrayTriple]


def get_data_stream(
    cfg: DictConfig, rng: PRNGKey, seed
) -> Tuple[DataStream, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
    """Select and invoke the dataset loader named by ``cfg.dataset.name``.

    Args:
        cfg: Configuration object. Only ``cfg.dataset.name`` is read here;
            the whole object is forwarded to the chosen loader.
        rng: Random key, forwarded unchanged to the chosen loader.
        seed: Forwarded to ``tabular_data_stream`` only; the MNIST and
            CelebA loaders are called without it.

    Returns:
        Whatever the chosen loader returns (annotated as a data stream
        paired with a triple of arrays).

    Raises:
        ValueError: If ``cfg.dataset.name`` matches none of the known
            dataset names.
    """
    # Every dataset name that is served by the tabular loader.
    tabular_names = (
        "adult",
        "retired-adult",
        "credit-card",
        "chit-default-small",
        "parkinsons",
        "folktables",
        "utkface-age-encoded",
        "heart",
        "civilcomments_embeddings",
        "colored_mnist",
    )
    dataset_name = cfg.dataset.name

    if dataset_name == "mnist":
        return mnist_data_stream(cfg, rng)

    if dataset_name in tabular_names:
        return tabular_data_stream(cfg, rng, seed)

    if dataset_name == "celeba":
        return celeba_data_stream(cfg, rng)

    raise ValueError(f"Unknown dataset: {dataset_name}")
