import numpy as np
import torch.nn as nn
from transformers import T5Tokenizer

class GlueUtils:
    """Helpers for casting GLUE tasks into T5-style text-to-text examples.

    Inputs are rendered as ``"<task> <key1>: <text1> [<key2>: <text2>]"`` and
    targets as label words (classification) or a one-decimal score string
    (regression), both tokenized with the supplied tokenizer.
    """

    # Model-input column names. Not referenced inside this class; presumably
    # used by callers to select/format dataset columns — confirm at call sites.
    columns = [
        'input_ids',
        'attention_mask',
        'token_type_ids',
        'position_ids',
        'head_mask',
        'inputs_embeds',
        'output_attentions',
        'output_hidden_states',
        'return_dict',
        'labels',
    ]

    # Task name -> (first text column, second text column or None for
    # single-sentence tasks).
    glue_task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
    }

    # Task name -> {integer label: target word}. None marks a regression task
    # whose numeric label is emitted as a formatted number string instead.
    glue_task_to_labels = {
        "cola": {0: "unacceptable", 1: "acceptable"},
        "mnli": {0: "entailment", 1: "neutral", 2: "contradiction"},
        "mrpc": {0: "not_equivalent", 1: "equivalent"},
        "qnli": {0: "entailment", 1: "not_entailment"},
        "qqp":  {0: "not_duplicate", 1: "duplicate"},
        "rte":  {0: "entailment", 1: "not_entailment"},
        "sst2": {0: "negative", 1: "positive"},
        "stsb": None,
        "wnli": {0: "not_entailment", 1: "entailment"},
    }

    def init_glue_preprocess_function(self, args, tokenizer: "T5Tokenizer"):
        """Build a batched preprocessing function for one GLUE task.

        Args:
            args: Object with a ``task_name`` attribute naming a GLUE task
                (a key of ``glue_task_to_keys``). It is read once, here;
                later changes to ``args`` do not affect the returned function.
            tokenizer: Callable that maps a list of strings to an encoding
                supporting item assignment and an ``input_ids`` attribute
                (e.g. a Hugging Face T5 tokenizer).

        Returns:
            ``preprocess_function(examples)``, which takes a mapping from
            column name to a list of values (one entry per example). It
            returns the tokenizer encoding of the prompt texts with an added
            ``"labels"`` entry holding the tokenized target strings.
            Presumably intended for a batched ``Dataset.map`` — confirm at
            call sites.

        Raises:
            KeyError: If ``args.task_name`` is not a supported GLUE task.
        """
        task_name = args.task_name
        if task_name not in self.glue_task_to_keys:
            raise KeyError(
                f"Unsupported GLUE task {task_name!r}; "
                f"expected one of {sorted(self.glue_task_to_keys)}"
            )

        # GLUE task: the task name itself is used as the T5 input prefix.
        task_prefix = task_name
        sentence1_key, sentence2_key = self.glue_task_to_keys[task_name]
        # Resolve everything at init so the text format and the label mapping
        # always refer to the same task.
        label_dict = self.glue_task_to_labels[task_name]

        def preprocess_function(examples):
            # Render each example as "<task> <key1>: <text1> [<key2>: <text2>]".
            if sentence2_key is None:
                texts = [
                    " ".join((task_prefix, sentence1_key + ":", x))
                    for x in examples[sentence1_key]
                ]
            else:
                texts = [
                    " ".join((task_prefix, sentence1_key + ":", x,
                              sentence2_key + ":", y))
                    for x, y in zip(examples[sentence1_key],
                                    examples[sentence2_key])
                ]
            result = tokenizer(texts)

            if label_dict is None:
                # Regression target: score rendered with one decimal place.
                targets = ["{:.1f}".format(float(x)) for x in examples["label"]]
            else:
                # NOTE(review): a label value missing from label_dict (e.g. a
                # sentinel on unlabeled splits) raises KeyError here — confirm
                # callers only map labeled splits.
                targets = [label_dict[x] for x in examples["label"]]

            result["labels"] = tokenizer(targets).input_ids
            return result

        return preprocess_function
