from transformers import T5TokenizerFast
from typing import List, Tuple, Union, Optional, Any

class SquadUtils:
    """Turn SQuAD-style batches into tokenized features for a seq2seq model.

    Each example becomes one input string of the form
    ``"question: <question> context: <context>"`` and one target string (the
    first answer text, or ``""`` when the example has no answer). Both are
    tokenized with the configured tokenizer.
    """

    def __init__(self,
        data_args: Any,
        question_column: str,
        context_column: str,
        answer_column: str,
        # Quoted so the annotation is not evaluated when the class is defined.
        tokenizer: "T5TokenizerFast",
        max_seq_length: Optional[int],
        max_answer_length: Optional[int],
        padding: Optional[Union[str, bool]],
    ) -> None:
        """Store the column names, tokenizer and length/padding settings.

        Args:
            data_args: Object exposing ``ignore_pad_token_for_loss``; it is only
                read when label padding has to be masked.
            question_column: Name of the column holding the questions.
            context_column: Name of the column holding the contexts.
            answer_column: Name of the column holding the answers.
            tokenizer: Callable tokenizer used for both inputs and targets.
            max_seq_length: ``max_length`` passed when tokenizing the inputs.
            max_answer_length: ``max_length`` passed when tokenizing the targets.
            padding: ``padding`` argument forwarded to every tokenizer call.
        """
        self.data_args = data_args
        self.question_column = question_column
        self.context_column = context_column
        self.answer_column = answer_column
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.max_answer_length = max_answer_length
        self.padding = padding

    def preprocess_squad_batch(self,
        examples,
        question_column: str,
        context_column: str,
        answer_column: str,
    ) -> Tuple[List[str], List[str]]:
        """Build the model input strings and target strings for a batch.

        Args:
            examples: Batch indexable by column name; each column is a sequence
                with one entry per example. Every answer entry must expose a
                ``"text"`` sequence.
            question_column: Name of the question column in ``examples``.
            context_column: Name of the context column in ``examples``.
            answer_column: Name of the answer column in ``examples``.

        Returns:
            ``(inputs, targets)`` where ``inputs[i]`` is
            ``"question: ... context: ..."`` and ``targets[i]`` is the first
            answer text, or ``""`` if the example has no answer text.
        """
        questions = examples[question_column]
        contexts = examples[context_column]
        answers = examples[answer_column]

        # Leading whitespace is stripped so the fixed prefixes are followed by
        # exactly one space.
        inputs = [
            " ".join(["question:", question.lstrip(), "context:", context.lstrip()])
            for question, context in zip(questions, contexts)
        ]
        # Only the first answer is used as the target; unanswerable examples
        # get an empty target.
        targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
        return inputs, targets

    def _tokenize_targets(self, targets: List[str]) -> List[List[int]]:
        """Tokenize target strings and mask label padding when requested.

        Returns the label token ids, one list per target. When the tokenizer
        call pads (``padding`` is ``True``, ``"longest"`` or ``"max_length"``)
        and ``data_args.ignore_pad_token_for_loss`` is truthy, every
        ``tokenizer.pad_token_id`` in the labels is replaced by -100 so that
        padding is ignored in the loss.
        """
        # Tokenize targets with the `text_target` keyword argument.
        labels = self.tokenizer(
            text_target=targets,
            max_length=self.max_answer_length,
            padding=self.padding,
            truncation=True,
        )
        label_ids = labels["input_ids"]

        # Every padding mode inserts pad tokens into the labels, not only
        # "max_length", so all of them need the -100 replacement.
        pads_here = self.padding is True or self.padding in ("longest", "max_length")
        if pads_here and self.data_args.ignore_pad_token_for_loss:
            pad_id = self.tokenizer.pad_token_id
            label_ids = [
                [-100 if token_id == pad_id else token_id for token_id in ids]
                for ids in label_ids
            ]
        return label_ids

    def preprocess_function(self, examples):
        """Tokenize a training batch.

        Returns the tokenizer output for the inputs with an added ``"labels"``
        entry holding the (optionally padding-masked) target token ids.
        """
        inputs, targets = self.preprocess_squad_batch(
            examples, self.question_column, self.context_column, self.answer_column
        )

        model_inputs = self.tokenizer(
            inputs,
            max_length=self.max_seq_length,
            padding=self.padding,
            truncation=True,
        )
        model_inputs["labels"] = self._tokenize_targets(targets)
        return model_inputs

    # Validation preprocessing
    def preprocess_validation_function(self, examples):
        """Tokenize a validation batch, keeping overflowing spans.

        The inputs are tokenized with ``return_overflowing_tokens=True`` and
        ``return_offsets_mapping=True``, so one example may yield several
        features. The returned mapping has ``"overflow_to_sample_mapping"``
        removed and gains ``"example_id"`` (the ``examples["id"]`` value of the
        originating example) and ``"labels"`` (that example's label ids), one
        entry per feature.
        """
        inputs, targets = self.preprocess_squad_batch(
            examples, self.question_column, self.context_column, self.answer_column
        )

        model_inputs = self.tokenizer(
            inputs,
            max_length=self.max_seq_length,
            padding=self.padding,
            truncation=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
        )
        label_ids = self._tokenize_targets(targets)

        # Since one example might give us several features if it has a long
        # context, we need a map from a feature to its corresponding example.
        sample_mapping = model_inputs.pop("overflow_to_sample_mapping")

        # For evaluation the predictions are matched back to their example, so
        # every feature carries its example id and a copy of that example's
        # labels.
        example_ids = examples["id"]
        model_inputs["example_id"] = [example_ids[sample_index] for sample_index in sample_mapping]
        model_inputs["labels"] = [label_ids[sample_index] for sample_index in sample_mapping]
        return model_inputs
