


from tqdm import tqdm 
from itertools import chain 

from torch .utils .data import Dataset 


class ConcatDataset(Dataset):
    """Pack the samples of *dataset* into fixed-length chunks.

    Every sample is read as a mapping holding ``"input_ids"``,
    ``"attention_mask"`` and ``"labels"``. Those sequences are concatenated
    across samples (in iteration order) and cut into chunks of exactly
    ``chunk_size`` elements per key, so sample boundaries are not preserved.
    Whatever is left after the last full chunk (fewer than ``chunk_size``
    elements) is discarded.

    All packing happens eagerly in ``__init__``; indexing afterwards is a
    plain list lookup.

    Args:
        dataset: Iterable of sample mappings. Only the three keys above are
            read. Presumably their values have equal length within a sample
            (TODO confirm against the producer); only the ``"input_ids"``
            length decides where chunks are cut.
        chunk_size: Number of elements in every emitted chunk. Must be > 0.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """

    def __init__(self, dataset, chunk_size=4096):
        if chunk_size <= 0:
            # A non-positive size never shrinks the buffer -> infinite loop.
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        self.dataset = dataset
        self.chunk_size = chunk_size

        self.samples = []

        keys = ("input_ids", "attention_mask", "labels")
        buffer = {k: [] for k in keys}

        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
            # Append in place rather than rebuilding every list per sample.
            for k in keys:
                buffer[k].extend(sample[k])

            # Emit every full chunk currently buffered. `>=` (not `>`) so a
            # buffer holding exactly chunk_size elements is emitted instead of
            # being held back and lost if the dataset ends right there.
            available = len(buffer["input_ids"])
            start = 0
            while available - start >= chunk_size:
                end = start + chunk_size
                self.samples.append({k: v[start:end] for k, v in buffer.items()})
                start = end

            # Drop the consumed prefix once, instead of re-slicing per chunk.
            if start:
                for values in buffer.values():
                    del values[:start]
        # Any remainder still in `buffer` is shorter than chunk_size and is
        # intentionally not emitted.

    def __getitem__(self, idx):
        """Return the packed chunk at position *idx* (a dict of lists)."""
        return self.samples[idx]

    def __len__(self):
        """Return the number of full chunks produced."""
        return len(self.samples)
