from . import dataset
from . import collator
from . import binarizer

from .dataset import MultipleCorpusMemmapDataset
from .collator import DataCollator, SparseDataCollator
from .binarizer import get_tokenizer

from torch.utils.data import (
    DataLoader,
    DistributedSampler,
)
from LLMProxy.option import TrainArg
from typing import List


def build_dataset(splits: List[str], args: TrainArg):
    """Create a ``MultipleCorpusMemmapDataset`` covering the given splits.

    Args:
        splits: Names of the splits to load.
        args: Training options; ``data_dir``, ``tokenizer`` and ``seq_len``
            are read from it.

    Returns:
        The constructed ``MultipleCorpusMemmapDataset``.
    """
    # ``seq_len`` from the training options is passed as the dataset's block size.
    dataset_kwargs = dict(
        path=args.data_dir,
        tokenizer=args.tokenizer,
        splits=splits,
        block_size=args.seq_len,
    )
    return MultipleCorpusMemmapDataset(**dataset_kwargs)


def build_train_and_valid_datasets(args: TrainArg):
    """Build the training and validation datasets from one set of options.

    Args:
        args: Training options; ``train_split`` and ``valid_split`` select the
            splits, and the remaining dataset options are read by ``build_dataset``.

    Returns:
        A ``(train_dataset, valid_dataset)`` tuple.
    """
    return tuple(
        build_dataset(split, args)
        for split in (args.train_split, args.valid_split)
    )


def build_dataloader(
    dataset,
    args: TrainArg,
    batch_size: int,
    eos_token: int,
    shuffle: bool = True,
):
    """Wrap ``dataset`` in a ``DataLoader`` that batches with ``SparseDataCollator``.

    Args:
        dataset: Map-style dataset to sample from.
        args: Training options; ``append_bos``, ``sample_len`` and ``num_worker``
            are read from it.
        batch_size: Number of samples per batch in this process.
        eos_token: Token id forwarded to the collator as its EOS token.
        shuffle: Whether the sampler shuffles sample indices. Defaults to
            ``True``; pass ``False`` for deterministic evaluation order.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``pin_memory=True``.

    Note:
        When a ``torch.distributed`` process group is initialized, a
        ``DistributedSampler`` shards the dataset across ranks. Call
        ``loader.sampler.set_epoch(epoch)`` at the start of each epoch so the
        shuffle order differs between epochs. Without an initialized process
        group (single-process / debug runs), a plain ``RandomSampler`` or
        ``SequentialSampler`` is used instead of failing inside
        ``DistributedSampler``.
    """
    import torch.distributed as dist
    from torch.utils.data import RandomSampler, SequentialSampler

    collator = SparseDataCollator(
        eos_token=eos_token,
        append_bos=args.append_bos,
        sample_len=args.sample_len,
    )

    if dist.is_available() and dist.is_initialized():
        sampler = DistributedSampler(dataset, shuffle=shuffle)
    elif shuffle:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)

    return DataLoader(
        dataset,
        sampler=sampler,
        collate_fn=collator,
        batch_size=batch_size,
        num_workers=args.num_worker,
        pin_memory=True,
    )
