import torch
from torch.utils.data import DataLoader
import random

class InputCopyingDataLoader(DataLoader):
    """Yield random batches for a synthetic sequence-copying task.

    Each batch is a dict with three entries::

        input_ids: [bos, x_1 .. x_n, copy, x_1 .. x_(n-1)]  shape (batch_size, 2n + 1)
        labels:    [x_1 .. x_n]                             shape (batch_size, n)
        mask:      None

    ``n`` is drawn once per batch with ``random.randint`` from
    ``[min_seq_length, max_seq_length]`` (both inclusive) and every ``x_i``
    is drawn with ``torch.randint`` from ``[start_id, vocab_size)``.

    ``labels[:, i]`` is the token that comes right after
    ``input_ids[:, n + 1 + i]``, i.e. the labels are the next-token targets
    of the last ``n`` input positions.
    NOTE(review): presumably the consumer scores only those last ``n``
    positions -- confirm against the training loop.

    NOTE: ``DataLoader.__init__`` is deliberately not invoked (there is no
    backing dataset), so only the iteration protocol and ``len()``
    implemented here should be relied on.

    Args:
        batch_size: Number of sequences per batch.
        num_batch: Number of batches produced by one pass over the loader.
        vocab_size: Exclusive upper bound for the random payload token ids.
        bos_id: Token id placed in the first column of ``input_ids``.
        copy_id: Token id separating the payload from its repeated copy.
        start_id: Inclusive lower bound for the random payload token ids.
        max_seq_length: Largest payload length ``n`` that can be drawn.
    """

    def __init__(self, batch_size: int, num_batch: int, vocab_size: int,
                 bos_id: int, copy_id: int, start_id: int,
                 max_seq_length: int) -> None:
        self.batch_size = batch_size
        self.num_batch = num_batch
        self.vocab_size = vocab_size
        self.bos_id = bos_id
        self.copy_id = copy_id
        self.start_id = start_id
        self.max_seq_length = max_seq_length
        # Shortest payload: at most 10 tokens below the maximum, but never
        # less than half of it.
        self.min_seq_length = max(max_seq_length // 2, max_seq_length - 10)

    def __iter__(self):
        """Generate ``num_batch`` freshly sampled batches."""
        # The special-token columns are identical for every batch, so build
        # them once. The dtype is spelled out so that they are int64 on every
        # PyTorch version and match the dtype produced by torch.randint.
        column_shape = (self.batch_size, 1)
        bos_tokens = torch.full(column_shape, self.bos_id, dtype=torch.long)
        copy_tokens = torch.full(column_shape, self.copy_id, dtype=torch.long)

        for _ in range(self.num_batch):
            # One payload length per batch; all rows of a batch share it.
            seq_length = random.randint(self.min_seq_length, self.max_seq_length)

            # Random payload the model is asked to reproduce.
            input_seq = torch.randint(self.start_id, self.vocab_size,
                                      (self.batch_size, seq_length))

            # The repeated copy drops the final payload token: that token is
            # only ever a target, never an input.
            input_ids = torch.cat(
                [bos_tokens, input_seq, copy_tokens, input_seq[:, :-1]], dim=1
            )
            yield {"input_ids": input_ids, "labels": input_seq, "mask": None}

    def __len__(self) -> int:
        """Return the number of batches produced per pass."""
        return self.num_batch
    
if __name__ == "__main__":
    # Manual smoke run: draw a few tiny batches and dump them to stdout so
    # the input/label layout can be inspected by eye.
    demo_loader = InputCopyingDataLoader(
        batch_size=2,
        num_batch=3,
        vocab_size=8,
        bos_id=0,
        copy_id=1,
        start_id=2,
        max_seq_length=20,
    )
    for sample in demo_loader:
        # Heading and tensor go on separate lines, one pair per field.
        print("input_ids:", sample["input_ids"], sep="\n")
        print("labels:", sample["labels"], sep="\n")