import torch

# Label sentinel value. NOTE(review): -100 equals the default `ignore_index`
# of torch.nn.CrossEntropyLoss, so this presumably marks label positions that
# should be excluded from the loss — confirm against the training code.
IGNORE_INDEX = -100
class PrunedDataCollator:
    """Collate pre-tokenized examples into a batch of stacked tensors.

    Each example is a mapping that provides at least ``input_ids`` and
    ``labels``. Any other keys it carries (for instance ``attention_mask``,
    ``prompt`` or ``example``) are ignored; the returned attention mask is
    recomputed from ``input_ids``.

    No padding or truncation is performed here: sequences are combined with
    ``torch.vstack``, so every example in a batch must already have the same
    sequence length.
    """

    def __init__(
        self,
        tokenizer,
        max_length,
    ):
        """Store the collator configuration.

        Args:
            tokenizer: Tokenizer associated with the data. It is only stored
                on the instance; collation does not read it.
            max_length: Maximum sequence length. It is only stored on the
                instance; collation does not enforce it.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

    @staticmethod
    def _stack(examples, key):
        """Convert ``example[key]`` of every example to int64 and stack them."""
        # torch.as_tensor accepts lists, NumPy arrays and tensors of any
        # integer dtype, whereas the legacy torch.LongTensor(...) constructor
        # is stricter about its input and treats a bare int as a size.
        rows = [torch.as_tensor(example[key], dtype=torch.long)
                for example in examples]
        return torch.vstack(rows)

    def __call__(self, examples):
        """Build one batch from a sequence of examples.

        Args:
            examples: Sequence of mappings, each with ``input_ids`` and
                ``labels`` entries holding integer sequences of equal length
                across the batch.

        Returns:
            A dict with:
                ``input_ids``: int64 tensor of the stacked input ids; for 1-D
                    sequences of length L its shape is (len(examples), L).
                ``labels``: int64 tensor of the stacked labels.
                ``attention_mask``: boolean tensor, ``input_ids >= 0``.

        Raises:
            KeyError: If an example lacks ``input_ids`` or ``labels``.
            RuntimeError: Raised by ``torch.vstack`` when the batch is empty
                or the sequences differ in length.
        """
        input_ids = self._stack(examples, "input_ids")
        labels = self._stack(examples, "labels")

        # NOTE(review): this mask is True for every non-negative id, so
        # padding that uses an ordinary (non-negative) pad token id is NOT
        # masked out. Presumably the inputs are unpadded or mark ignored
        # positions with negative ids — confirm against the data pipeline.
        attention_mask = input_ids.ge(0)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )