# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from tqdm import tqdm
from itertools import chain

from torch.utils.data import Dataset


class NoPartialSentenceConcatDataset(Dataset):
    """Pack tokenized samples into fixed-size chunks without splitting a sample.

    Consecutive samples are concatenated into a buffer until adding the next
    one would make the buffer longer than ``chunk_size``. The buffer is then
    right-padded to ``chunk_size`` and stored as one chunk, and the incoming
    sample starts the next buffer. A sample that is longer than ``chunk_size``
    on its own is truncated to ``chunk_size`` and stored as its own chunk (the
    current buffer is kept). Whatever remains in the buffer after the last
    sample is padded and stored as well.

    Every stored chunk has exactly the keys ``"input_ids"``,
    ``"attention_mask"`` and ``"labels"``; any other keys in the input samples
    are ignored.

    Args:
        dataset: Iterable of dicts whose ``"input_ids"``, ``"attention_mask"``
            and ``"labels"`` values are sequences. They are presumably of equal
            length per sample; this is not checked, and only the length of
            ``"input_ids"`` drives the packing decisions.
        chunk_size: Length of every stored chunk.
        pad_token_id: Fill value for ``"input_ids"`` padding.
            ``"attention_mask"`` is padded with 0 and ``"labels"`` with -100.
    """

    def __init__(self, dataset, chunk_size=4096, pad_token_id=0):
        self.dataset = dataset
        self.chunk_size = chunk_size

        self.samples = []

        # Fill value used for the padded tail of each key in a chunk.
        pad_values = {
            "input_ids": pad_token_id,
            "attention_mask": 0,
            "labels": -100,
        }
        keys = tuple(pad_values)

        buffer = {k: [] for k in keys}

        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
            sample_len = len(sample["input_ids"])

            # A sample that cannot fit in a chunk by itself is truncated and
            # stored alone; the current buffer carries over to the next sample.
            if sample_len > self.chunk_size:
                self.samples.append({k: list(sample[k][: self.chunk_size]) for k in keys})
                continue

            if len(buffer["input_ids"]) + sample_len > self.chunk_size:
                # Appending this sample would split it across chunks, so store
                # the buffer as-is (padded) and start a fresh one with it.
                if buffer["input_ids"]:
                    self._append_padded(buffer, pad_values)
                buffer = {k: [] for k in keys}

            # Extend in place instead of rebuilding the buffer every iteration.
            for k in keys:
                buffer[k].extend(sample[k])

        # Store the trailing buffer instead of silently dropping its samples.
        if buffer["input_ids"]:
            self._append_padded(buffer, pad_values)

    def _append_padded(self, buffer, pad_values):
        """Right-pad each list in ``buffer`` to ``chunk_size`` and store it as a chunk."""
        size = self.chunk_size
        self.samples.append(
            {
                # Slice guards against a key longer than "input_ids";
                # a non-positive repeat count yields an empty list.
                k: v[:size] + [pad_values[k]] * (size - len(v))
                for k, v in buffer.items()
            }
        )

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)
