import torch.distributed
from . import collator
from . import binarizer
from . import data_utils
from . import dataset

from .binarizer import Writer, Binarizer, DatasetBinarizer
from .collator import DataCollator
from .dataset import LMCorpusMemmapDataset, LMMemmapDataset

import torch
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from CoLM.option import TrainArg


def build_dataloader(
    dataset: LMCorpusMemmapDataset,
    args: TrainArg,
    batch_size: int,
    eos_token: int,
) -> DataLoader:
    """Wrap ``dataset`` in a ``DataLoader`` that batches with ``DataCollator``.

    The sampler depends on whether a ``torch.distributed`` process group is
    available and initialized:

    * if it is, a ``DistributedSampler`` is used, with default arguments;
    * otherwise, a ``RandomSampler`` is used, with default arguments.

    NOTE(review): per the PyTorch docs, a ``DistributedSampler`` only varies its
    shuffle order between epochs if ``sampler.set_epoch(epoch)`` is called.
    This function does not do that; presumably the training loop does. Confirm.

    Args:
        dataset: Dataset to sample from.
        args: Training options. ``args.append_bos`` is forwarded to the
            collator and ``args.num_worker`` becomes ``num_workers``.
        batch_size: Number of samples per batch.
        eos_token: Forwarded to ``DataCollator`` as ``eos_token``. It is
            presumably the end-of-sequence token id; confirm in the collator.

    Returns:
        A ``DataLoader`` over ``dataset`` with ``pin_memory=True``.
    """
    collate_fn = DataCollator(eos_token=eos_token, append_bos=args.append_bos)

    # Use the distributed sampler only when a process group actually exists.
    use_distributed = (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    )
    sampler = DistributedSampler(dataset) if use_distributed else RandomSampler(dataset)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=args.num_worker,
        collate_fn=collate_fn,
        pin_memory=True,
    )

