import logging

import torch
from transformers.data.data_collator import *

logger = logging.getLogger(__name__)

@dataclass
class DataCollator:
    """Collate instruction/label records into left-padded causal-LM batches.

    Each batch element is expected to be a mapping with an ``'Instance'``
    entry holding the ``'instruction'`` (prompt) and ``'label'`` (target)
    strings. The prompt and the target (with the tokenizer's EOS token
    appended) are tokenized separately, optionally truncated, concatenated,
    and left-padded to the longest sequence in the batch.
    """

    tokenizer: PreTrainedTokenizerBase
    # The fields below marked "not read here" are accepted for configuration
    # compatibility but are not used by any method of this class.
    model: Optional[Any] = None  # not read here
    padding: Union[bool, str, PaddingStrategy] = True  # not read here
    # Maximum number of prompt tokens kept; ``None`` disables truncation.
    max_source_length: Optional[int] = None
    # Maximum number of target tokens kept; ``None`` disables truncation.
    max_target_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None  # not read here
    # Value written into ``labels`` at prompt and padding positions.
    label_pad_token_id: int = -100
    return_tensors: str = "pt"  # not read here; torch tensors are always returned
    add_task_name: bool = False  # not read here
    add_dataset_name: bool = False  # not read here
    add_instruction_replay: bool = True  # not read here
    common_dataset_name: Optional[str] = None  # not read here
    text_only: bool = False  # not read here
    num_examples: int = 0  # not read here
    input_record_file: Optional[str] = None  # not read here

    def __call__(self, batch):
        """Collate ``batch``; delegates to :meth:`decoder_call`."""
        return self.decoder_call(batch)

    @staticmethod
    def _truncate(token_ids, limit):
        """Return ``token_ids`` cut to at most ``limit`` items (``None`` = no limit)."""
        if limit is None:
            return token_ids
        return token_ids[:limit]

    def decoder_call(self, batch):
        """Build ``input_ids``, ``attention_mask`` and ``labels`` tensors.

        Args:
            batch: Iterable of records shaped like
                ``{'Instance': {'instruction': str, 'label': str}}``.

        Returns:
            Dict with three ``torch`` tensors, one row per record, all
            left-padded to the longest ``prompt + target`` sequence:
            ``input_ids`` (padded with ``tokenizer.pad_token_id``),
            ``attention_mask`` (0 on padding, 1 elsewhere) and ``labels``
            (``label_pad_token_id`` on padding and prompt positions, target
            token ids elsewhere).

        Raises:
            ValueError: If ``batch`` contains no records.

        Side effects:
            Sets ``self.tokenizer.padding_side`` to ``'left'``.
        """
        self.tokenizer.padding_side = 'left'

        sequences = []
        targets = []
        for instance in batch:
            record = instance['Instance']
            # The EOS token is appended as text before tokenizing the target.
            label_text = record['label'] + self.tokenizer.eos_token
            prompt_text = record['instruction']

            prompt_ids = self._truncate(
                self.tokenizer(prompt_text, add_special_tokens=False)["input_ids"],
                self.max_source_length,
            )
            # NOTE: truncation keeps the head of the target, so a target longer
            # than max_target_length loses its trailing EOS token.
            target_ids = self._truncate(
                self.tokenizer(label_text, add_special_tokens=False)["input_ids"],
                self.max_target_length,
            )

            sequences.append(prompt_ids + target_ids)
            # Prompt positions are masked out of the labels.
            targets.append([self.label_pad_token_id] * len(prompt_ids) + target_ids)

        if not sequences:
            raise ValueError("DataCollator received an empty batch")

        longest = max(len(seq) for seq in sequences)
        pad_id = self.tokenizer.pad_token_id

        input_ids = []
        attention_mask = []
        labels = []
        for seq, tgt in zip(sequences, targets):
            n_pad = longest - len(seq)
            input_ids.append([pad_id] * n_pad + seq)
            labels.append([self.label_pad_token_id] * n_pad + tgt)
            attention_mask.append([0] * n_pad + [1] * len(seq))

        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
            'labels': torch.tensor(labels),
        }