from anGPT.data.buffered_sequencer import BufferedSequencer
from anGPT.data.mmap_dataset import MMapDataset
from anGPT.data.prefetch_collator_old import PrefetchCollator, BufferedPrefetchCollator
from anGPT.data.tokenizer import get_tokenizer


def get_data_pipeline(cfg, num_processes, process_index, device):
    """Build the train/validation data loaders and the tokenizer.

    Each split is read through an ``MMapDataset`` sharded by process
    (``gpu_world_size=num_processes``, ``gpu_worker=process_index``). It is then
    cut into sequences by a ``BufferedSequencer`` and wrapped in a prefetch
    collator. The train split uses ``BufferedPrefetchCollator`` and the
    validation split uses ``PrefetchCollator``.

    Args:
        cfg: Configuration object. Reads ``cfg.tokenizer`` (unpacked into
            ``get_tokenizer``), ``cfg.train_mmap_dl`` / ``cfg.valid_mmap_dl``
            (unpacked into ``MMapDataset``), and ``cfg.data_base``
            (``sequence_length``, ``min_sequence_length``, ``batch_size``,
            ``ignore_index``).
        num_processes: Total number of processes; forwarded as
            ``gpu_world_size``.
        process_index: Index of this process; forwarded as ``gpu_worker``.
        device: Forwarded unchanged to both collators.

    Returns:
        Tuple ``(train_dl, valid_dl, tokenizer)``.
    """
    tokenizer = get_tokenizer(**cfg.tokenizer)
    data_cfg = cfg.data_base

    def _sequenced_split(split_mmap_cfg, *, min_sequence_length, drop_last):
        # Shard the memory-mapped text for this process, then chunk it into sequences.
        shard = MMapDataset(gpu_world_size=num_processes, gpu_worker=process_index, **split_mmap_cfg)
        # One token longer than the configured length -- presumably so the
        # collator can slice both inputs and shifted targets; TODO confirm.
        return BufferedSequencer(shard,
                                 sequence_length=data_cfg.sequence_length + 1,
                                 min_sequence_length=min_sequence_length,
                                 drop_last=drop_last)

    # Training: honour the configured minimum length and drop the trailing partial piece.
    train_dl = BufferedPrefetchCollator(
        _sequenced_split(cfg.train_mmap_dl,
                         min_sequence_length=data_cfg.min_sequence_length,
                         drop_last=True),
        device,
        batch_size=data_cfg.batch_size,
        buffer_size=2,
        ignore_index=data_cfg.ignore_index,
    )

    # Validation: no minimum length and nothing dropped.
    valid_dl = PrefetchCollator(
        _sequenced_split(cfg.valid_mmap_dl,
                         min_sequence_length=0,
                         drop_last=False),
        device,
        batch_size=data_cfg.batch_size,
        buffer_size=2,
        ignore_index=data_cfg.ignore_index,
    )

    return train_dl, valid_dl, tokenizer
