import os

from omegaconf import DictConfig
from torch.utils.data import DataLoader, Dataset


def interpret_n_workers(n_workers: int, world_size: int) -> int:
    """Resolve the per-rank number of DataLoader worker processes.

    A negative ``n_workers`` selects auto-detection: the number of CPUs this
    process is allowed to run on (``os.sched_getaffinity``, where the
    platform provides it), falling back to ``os.cpu_count()``. The resulting
    count -- auto-detected or explicit -- is then divided (floor) across
    ``world_size`` ranks.

    Args:
        n_workers: Requested total worker count; negative means auto-detect.
        world_size: Number of ranks sharing the workers; must be >= 1.

    Returns:
        A non-negative per-rank worker count.

    Raises:
        ValueError: If ``world_size`` is less than 1.
    """
    if world_size < 1:
        # Previously surfaced as an opaque ZeroDivisionError for 0.
        raise ValueError(f"world_size must be >= 1, got {world_size}")

    if n_workers < 0 and hasattr(os, "sched_getaffinity"):
        try:
            n_workers = len(os.sched_getaffinity(0))
        except OSError:
            # Best effort: if the affinity query fails, fall back to
            # os.cpu_count() below instead of aborting.
            pass

    if n_workers < 0:
        # os.cpu_count() returns None when the CPU count cannot be determined.
        n_workers = os.cpu_count() or 0

    return max(0, n_workers // world_size)


def build(ds: Dataset, batch_size: int, n_workers: int, world_size: int) -> DataLoader:
    """Wrap ``ds`` in a sequential DataLoader that batches with ``ds.collate_fn``.

    Args:
        ds: Dataset to load. Must expose a ``collate_fn`` attribute and
            support ``len()``.
        batch_size: Requested batch size. It is capped at ``len(ds)``, with an
            empty dataset treated as length 1.
        n_workers: Requested worker count, resolved per rank by
            ``interpret_n_workers`` (negative means auto-detect).
        world_size: Number of ranks the workers are split across.

    Returns:
        A DataLoader with shuffling disabled, pinned memory, the final partial
        batch kept, and non-persistent workers.
    """
    per_rank_workers = interpret_n_workers(n_workers, world_size=world_size)
    collate = ds.collate_fn
    effective_batch_size = min(batch_size, len(ds) or 1)
    return DataLoader(
        ds,
        batch_size=effective_batch_size,
        shuffle=False,
        num_workers=per_rank_workers,
        collate_fn=collate,
        pin_memory=True,
        drop_last=False,
        # Persistent workers (formerly `n_workers > 0`) are disabled because
        # they are incompatible with utils.training.handle_oom():
        # "OSError: Too many open files".
        persistent_workers=False,
    )


def build_from_config(fabric, ds: Dataset, cfg: DictConfig) -> DataLoader:
    """Build the DataLoader for ``ds`` from configuration settings.

    Reads ``cfg.batch_size`` and ``cfg.n_workers``, takes the rank count from
    ``fabric.world_size``, and delegates the actual construction to ``build``.

    Args:
        fabric: Object exposing a ``world_size`` attribute (presumably a
            Lightning ``Fabric`` instance -- confirm with callers).
        ds: Dataset to wrap.
        cfg: Config node providing ``batch_size`` and ``n_workers``.

    Returns:
        The DataLoader produced by ``build``.
    """
    batch_size = cfg.batch_size
    requested_workers = cfg.n_workers
    ranks = fabric.world_size
    return build(
        ds=ds,
        batch_size=batch_size,
        n_workers=requested_workers,
        world_size=ranks,
    )
