"""
Functionalities for pre-processing and post-processing of GLUE datasets.
"""
from typing import Dict, List, Union, Any, Optional
import re
import torch
from datasets.arrow_dataset import Dataset
from promptsource.templates import DatasetTemplates
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

# Maps a task name to the list of string arguments that identify its dataset:
# a dataset name, optionally followed by a configuration/subset name.
# NOTE(review): presumably unpacked into `datasets.load_dataset(*args)` --
# confirm at the call site.
TASK_MAPPING_DATASET_ARGUMENTS = {

    "mrpc": ["glue", "mrpc"],
    "squad": ["squad"],
    "squad_v2": ["squad_v2"],
    "winogrande": ["winogrande", "winogrande_l"],
    "hellaswag": ["hellaswag"],
    "piqa": ["piqa"],
    "mmlu": ["mmlu"],
    "arc_easy": ["ai2_arc", "ARC-Easy"],
    "arc_challenge": ["ai2_arc", "ARC-Challenge"],
}

# Maps a task name to the prompt key (a string) selected for that task.
# NOTE(review): presumably these are promptsource template names looked up via
# `DatasetTemplates` -- confirm against the caller. Commented-out entries are
# alternative keys for the same task.
TASK_MAPPING_PROMPT_KEY = {
    "multirc": "found_this_answer",
    # "openbookqa": "pick_using_id",
    "openbookqa": "pick_answer_with_options",
    "sst2": "positive negative after",
    "mrpc": "paraphrase",
    "rte": "does the claim… follow the fact…",
    "squad": "answer_given_context_and_question",
    "squad_v2": "Questions with Context",
    # "copa": "cause_effect",
    "copa": "best_option",
    "qqp": "answer",
    "cola": "Make sense yes no",
    "stsb": "examples",
    "qnli": "based only on",
    "winogrande": "True or False",
    "wikiqa": "Decide_good_answer",
    "hotpotqa": "generate_answer_affirmative",
    "mnli": "GPT-3 style",
    "hellaswag": "Appropriate continuation - Yes or No",
}

def build_index_for_dataset(
        dataset: Dataset
):
    """Return `dataset` with an 'idx' column, adding one when it is missing.

    When the column has to be created, row ``i`` receives the value ``i``
    (0 .. len(dataset) - 1). A dataset that already has an 'idx' column is
    returned unchanged.
    """
    if "idx" not in dataset.column_names:
        row_numbers = list(range(len(dataset)))
        dataset = dataset.add_column("idx", row_numbers)
    return dataset

class WinoGrandePreProcessor:
    """Convert batched WinoGrande examples into language-model inputs.

    Every ``sentence`` is split at its first ``_`` placeholder: the text before
    it (with an option substituted for the blank) is the context, and the text
    after it is the continuation.

    Args:
        benchmark: Benchmark name; stored on the instance and not used by
            ``__call__``.
        is_train: When true, ``__call__`` builds one training query per example
            from the correct option; otherwise it builds one row per candidate
            option.
    """

    def __init__(self, benchmark: str, is_train=True):
        self.benchmark = benchmark
        self.is_train = is_train

    def __call__(self, examples, batched: Optional[bool] = True) -> Dict[str, List]:
        """Process a column-oriented batch of examples.

        Args:
            examples: Mapping from column name to a list of per-example values.
                Must contain ``idx``, ``sentence``, ``option1``, ``option2``
                and ``answer``.
            batched: Must be truthy unless ``examples["idx"]`` is a list.

        Returns:
            In training mode, a dict with ``query`` and ``idx`` lists holding
            one entry per example; the query is the sentence with the correct
            option filled in and its final character removed. In evaluation
            mode, a dict with ``cont_idx``, ``label_id``, ``idx``,
            ``continuation`` and ``ctx`` lists holding two entries per example
            (one per option).

        Raises:
            AssertionError: If ``examples`` has no ``idx`` column.
            NotImplementedError: If the input is not batched.
        """
        assert "idx" in examples, "Call build_index_for_dataset() first."
        if not (isinstance(examples["idx"], list) or batched):
            raise NotImplementedError("WinoGrandePreProcessor only supports batched=True")

        batch_size = len(examples["label"]) if "label" in examples else len(examples["idx"])
        if self.is_train:
            ret = {'query': [], "idx": []}
        else:
            ret = {"cont_idx": [], "label_id": [], "idx": [], "continuation": [], 'ctx': []}

        for i in range(batch_size):
            sample = {k: v[i] for k, v in examples.items()}
            sentence = sample["sentence"]
            # 'answer' is a string that suffixes the option key ("option" + answer);
            # converting it here yields the zero-based index of the correct option.
            label_id = int(sample["answer"]) - 1
            blank_pos = sentence.index("_")
            prefix = sentence[:blank_pos]
            continuation = " " + sentence[blank_pos + 1:].strip()

            if self.is_train:
                ans = prefix + sample["option" + sample["answer"]]
                # The final character of the continuation is dropped from training queries.
                ret['query'].append(ans + continuation[:-1])
                ret['idx'].append(sample['idx'])
            else:
                # One row per candidate option, all sharing the same continuation.
                for cont_id, option_key in enumerate(("option1", "option2")):
                    ret['cont_idx'].append(cont_id)
                    ret['label_id'].append(label_id)
                    ret['idx'].append(sample['idx'])
                    ret['ctx'].append(prefix + sample[option_key])
                    ret['continuation'].append(continuation)
        return ret

class HellaSwagPreProcessor:
    """Convert batched HellaSwag examples into single training queries.

    Args:
        benchmark: Benchmark name; stored on the instance and not used by
            ``__call__``.
    """

    def __init__(self, benchmark: str):
        self.benchmark = benchmark

    def preprocess(self, text):
        """Strip `text`, remove bracketed markers and collapse double spaces.

        ``" [title]"`` becomes ``". "``, any remaining ``[...]`` span is
        deleted, and each pair of spaces is then replaced by one space in a
        single pass.
        """
        text = text.strip()
        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
        text = text.replace(" [title]", ". ")
        text = re.sub(r"\[.*?\]", "", text)
        text = text.replace("  ", " ")

        return text

    def __call__(self, examples, batched: Optional[bool] = True) -> Dict[str, List]:
        """Build one query per example from its context and labelled ending.

        Args:
            examples: Mapping from column name to a list of per-example values.
                Must contain ``idx``, ``activity_label``, ``ctx_a``, ``ctx_b``,
                ``endings`` and ``label``.
            batched: Must be truthy unless ``examples["idx"]`` is a list.

        Returns:
            A dict with ``query`` and ``idx`` lists, one entry per example.
            Each query is ``"<activity_label>: <ctx_a> <Ctx_b>"`` (cleaned by
            ``preprocess``) followed by a space and the cleaned ending selected
            by ``label``.

        Raises:
            AssertionError: If ``examples`` has no ``idx`` column.
            NotImplementedError: If the input is not batched.
        """
        assert "idx" in examples, "Call build_index_for_dataset() first."
        if not (isinstance(examples["idx"], list) or batched):
            raise NotImplementedError("HellaSwagPreProcessor only supports batched=True")

        batch_size = len(examples["label"]) if "label" in examples else len(examples["idx"])
        ret = {'query': [], "idx": []}

        for i in range(batch_size):
            sample = {k: v[i] for k, v in examples.items()}
            text = self.preprocess(
                sample["activity_label"] + ": " + sample["ctx_a"] + " " + sample["ctx_b"].capitalize()
            )
            # 'label' indexes the correct ending; int() accepts string labels too.
            ending = " " + self.preprocess(sample["endings"][int(sample["label"])])

            ret['query'].append(text + ending)
            ret['idx'].append(sample['idx'])
        return ret
    
class PIQAPreProcessor:
    """Convert batched PIQA examples into "Question/Answer" training queries.

    Args:
        benchmark: Benchmark name; stored on the instance and not used by
            ``__call__``.
    """

    def __init__(self, benchmark: str):
        self.benchmark = benchmark

    def preprocess(self, text):
        """Strip `text` and replace each pair of spaces with one space (single pass)."""
        text = text.strip()
        text = text.replace("  ", " ")

        return text

    def __call__(self, examples, batched: Optional[bool] = True) -> Dict[str, List]:
        """Build one query per example from its goal and labelled solution.

        Args:
            examples: Mapping from column name to a list of per-example values.
                Must contain ``idx``, ``goal``, ``sol1``, ``sol2`` and ``label``.
            batched: Must be truthy unless ``examples["idx"]`` is a list.

        Returns:
            A dict with ``query`` and ``idx`` lists, one entry per example.
            Each query reads ``"Question: <goal>\\nAnswer: <solution>"`` after
            clean-up by ``preprocess``.

        Raises:
            AssertionError: If ``examples`` has no ``idx`` column.
            NotImplementedError: If the input is not batched.
        """
        assert "idx" in examples, "Call build_index_for_dataset() first."
        if not (isinstance(examples["idx"], list) or batched):
            raise NotImplementedError("PIQAPreProcessor only supports batched=True")

        batch_size = len(examples["label"]) if "label" in examples else len(examples["idx"])
        ret = {'query': [], "idx": []}

        for i in range(batch_size):
            sample = {k: v[i] for k, v in examples.items()}
            answer = [" " + sample["sol1"], " " + sample["sol2"]]
            # The doubled space between "Answer: " and the space-prefixed
            # solution is collapsed by preprocess().
            text = self.preprocess(
                "Question: " + sample["goal"] + "\nAnswer: " + answer[int(sample["label"])]
            )

            ret['query'].append(text)
            ret['idx'].append(sample['idx'])
        return ret
     
class ARCPreProcessor:
    """Convert batched ARC examples into "Question/Answer" training queries.

    Args:
        benchmark: Benchmark name; stored on the instance and not used by
            ``__call__``.
    """

    def __init__(self, benchmark: str):
        self.benchmark = benchmark

    def preprocess(self, text):
        """Strip `text` and replace each pair of spaces with one space (single pass).

        Not called by ``__call__``; kept as part of the class interface.
        """
        text = text.strip()
        text = text.replace("  ", " ")

        return text

    def __call__(self, examples, batched: Optional[bool] = True) -> Dict[str, List]:
        """Build one query per example from its question and correct choice.

        Args:
            examples: Mapping from column name to a list of per-example values.
                Must contain ``idx``, ``question``, ``answerKey`` and
                ``choices`` (a mapping with parallel ``label`` and ``text``
                lists).
            batched: Must be truthy unless ``examples["idx"]`` is a list.

        Returns:
            A dict with ``query`` and ``idx`` lists, one entry per example.
            Each query reads ``"Question: <question>\\nAnswer: <choice text>"``.

        Raises:
            AssertionError: If ``examples`` has no ``idx`` column.
            NotImplementedError: If the input is not batched.
            ValueError: If ``answerKey`` is not among the choice labels.
        """
        assert "idx" in examples, "Call build_index_for_dataset() first."
        if not (isinstance(examples["idx"], list) or batched):
            raise NotImplementedError("ARCPreProcessor only supports batched=True")

        batch_size = len(examples["label"]) if "label" in examples else len(examples["idx"])
        ret = {'query': [], "idx": []}

        for i in range(batch_size):
            sample = {k: v[i] for k, v in examples.items()}
            choices = sample["choices"]
            # Locate the correct choice by matching the answer key against the labels.
            answer_idx = choices['label'].index(sample["answerKey"])
            text = "Question: " + sample["question"] + "\nAnswer: " + choices["text"][answer_idx]

            ret['query'].append(text)
            ret['idx'].append(sample['idx'])
        return ret

class COPAPreProcessor:
    """Convert batched COPA examples into single-sentence training queries.

    Args:
        benchmark: Benchmark name; stored on the instance and not used by
            ``__call__``.
    """

    def __init__(self, benchmark: str):
        self.benchmark = benchmark

    def preprocess(self, text):
        """Strip `text` and replace each pair of spaces with one space (single pass)."""
        text = text.strip()
        text = text.replace("  ", " ")

        return text

    def doc_to_text(self, doc):
        """Return the premise without its final character, plus a connector.

        The connector is "because" when ``doc["question"]`` is "cause" and
        "therefore" otherwise.
        """
        connector = "because" if doc["question"] == "cause" else "therefore"

        # Drop the last character of the premise (expected to be the period).
        return doc["premise"].strip()[:-1] + " " + connector

    def doc_to_choice(self, doc):
        """Return both choices, first letter lower-cased and space-prefixed."""
        def convert_choice(choice):
            # Slicing instead of indexing keeps an empty choice from raising IndexError.
            return choice[:1].lower() + choice[1:]

        # Add a space in front of each continuation.
        return [" " + convert_choice(doc["choice1"]), " " + convert_choice(doc["choice2"])]

    def __call__(self, examples, batched: Optional[bool] = True) -> Dict[str, List]:
        """Build one query per example from its premise and labelled choice.

        Args:
            examples: Mapping from column name to a list of per-example values.
                Must contain ``idx``, ``premise``, ``question``, ``choice1``,
                ``choice2`` and ``label``.
            batched: Must be truthy unless ``examples["idx"]`` is a list.

        Returns:
            A dict with ``query`` and ``idx`` lists, one entry per example.
            Each query is ``doc_to_text`` followed by the choice selected by
            ``label``, cleaned by ``preprocess``.

        Raises:
            AssertionError: If ``examples`` has no ``idx`` column.
            NotImplementedError: If the input is not batched.
        """
        assert "idx" in examples, "Call build_index_for_dataset() first."
        if not (isinstance(examples["idx"], list) or batched):
            raise NotImplementedError("COPAPreProcessor only supports batched=True")

        batch_size = len(examples["label"]) if "label" in examples else len(examples["idx"])
        ret = {'query': [], "idx": []}

        for i in range(batch_size):
            sample = {k: v[i] for k, v in examples.items()}
            choices = self.doc_to_choice(sample)
            answer_idx = sample["label"]
            text = self.preprocess(self.doc_to_text(sample) + choices[int(answer_idx)])

            ret['query'].append(text)
            ret['idx'].append(sample['idx'])
        return ret

class SQUADPreProcessor:
    """Convert batched SQuAD-style examples into context/question/answer queries.

    Args:
        benchmark: Benchmark name; stored on the instance and not used by
            ``__call__``.
    """

    def __init__(self, benchmark: str):
        self.benchmark = benchmark

    def __call__(self, examples, batched: Optional[bool] = True) -> Dict[str, List]:
        """Build one query per example from context, question and first answer.

        Args:
            examples: Mapping from column name to a list of per-example values.
                Must contain ``idx``, ``context``, ``question`` and ``answers``
                (a mapping whose ``text`` entry is a list of answer strings).
            batched: Must be truthy unless ``examples["idx"]`` is a list.

        Returns:
            A dict with ``query`` and ``idx`` lists, one entry per example.
            Each query reads
            ``"<context>\\nQuestion: <question>\\nAnswer: <answer>"`` where
            ``<answer>`` is the first answer text, or ``"unanswerable"`` when
            the example has no answer texts.

        Raises:
            AssertionError: If ``examples`` has no ``idx`` column.
            NotImplementedError: If the input is not batched.
        """
        assert "idx" in examples, "Call build_index_for_dataset() first."
        if not (isinstance(examples["idx"], list) or batched):
            raise NotImplementedError("SQUADPreProcessor only supports batched=True")

        batch_size = len(examples["label"]) if "label" in examples else len(examples["idx"])
        ret = {'query': [], "idx": []}

        for i in range(batch_size):
            sample = {k: v[i] for k, v in examples.items()}
            context = sample["context"]
            answer_texts = sample["answers"]["text"]
            # An example may carry an empty answer list (e.g. unanswerable
            # questions); use a fixed target instead of raising IndexError.
            answer = answer_texts[0] if len(answer_texts) > 0 else "unanswerable"
            QA = "\nQuestion: " + sample["question"] + "\nAnswer: " + answer

            ret['query'].append(context + QA)
            ret['idx'].append(sample['idx'])
        return ret

class MRPCPreProcessor:
    """Convert batched MRPC sentence pairs into paraphrase-detection prompts.

    Args:
        benchmark: Benchmark name; stored on the instance and not used by
            ``__call__``.
        is_train: When true, ``__call__`` appends the gold answer to each
            prompt; otherwise it returns the prompt, the answer choices and
            the gold choice index separately.
    """

    def __init__(self, benchmark: str, is_train=True):
        self.benchmark = benchmark
        self.is_train = is_train

    def preprocess(self, string: str) -> str:
        """Undo tokenizer-style spacing around contractions, brackets, quotes and punctuation."""
        string = string.replace(" n't", "n't")
        string = string.replace(" )", ")")
        string = string.replace("( ", "(")
        string = string.replace('" ', '"')
        string = string.replace(' "', '"')

        # Remove the space before an apostrophe, period or comma.
        string = re.sub(r" (['.,])", r"\1", string)

        return string

    def doc_to_text(self, doc):
        """Return the prompt holding both cleaned sentences and the question."""
        return (
            "Sentence 1: "
            + self.preprocess(doc["sentence1"])
            + "\nSentence 2: "
            + self.preprocess(doc["sentence2"])
            + "\nQuestion: Do both sentences mean the same thing?"
        )

    def doc_to_label(self, doc):
        """Return the choice index: 0 ("yes") for a truthy label, 1 ("no") otherwise."""
        return 0 if doc["label"] else 1

    def __call__(self, examples, batched: Optional[bool] = True) -> Dict[str, List]:
        """Process a column-oriented batch of examples.

        Args:
            examples: Mapping from column name to a list of per-example values.
                Must contain ``idx``, ``sentence1``, ``sentence2`` and
                ``label``.
            batched: Must be truthy unless ``examples["idx"]`` is a list.

        Returns:
            In training mode, a dict with ``query`` and ``idx`` lists, where
            each query is the prompt followed by ``"\\nAnswer: yes"`` or
            ``"\\nAnswer: no"``. In evaluation mode, a dict with ``query``
            (prompt only), ``choices`` (``["yes", "no"]`` per example) and
            ``gold`` (index of the correct choice); no ``idx`` is returned.

        Raises:
            AssertionError: If ``examples`` has no ``idx`` column.
            NotImplementedError: If the input is not batched.
        """
        assert "idx" in examples, "Call build_index_for_dataset() first."
        if not (isinstance(examples["idx"], list) or batched):
            raise NotImplementedError("MRPCPreProcessor only supports batched=True")

        batch_size = len(examples["label"]) if "label" in examples else len(examples["idx"])
        if self.is_train:
            ret = {'query': [], "idx": []}
        else:
            ret = {'query': [], 'choices': [], 'gold': []}
        # Training targets carry a leading space so they follow "Answer:" directly.
        choice = [" yes", " no"]

        for i in range(batch_size):
            sample = {k: v[i] for k, v in examples.items()}
            context = self.doc_to_text(sample)
            gold = self.doc_to_label(sample)

            if self.is_train:
                ret['query'].append(context + "\nAnswer:" + choice[gold])
                ret['idx'].append(sample['idx'])
            else:
                ret['query'].append(context)
                ret["choices"].append(["yes", "no"])
                ret["gold"].append(gold)
        return ret

def tokenize_batch(
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        batch: Dict[str, List],
        max_length: int = 256,
) -> Dict[str, torch.Tensor]:
    """Tokenize the ``query`` texts of `batch` into padded PyTorch tensors.

    Args:
        tokenizer: Tokenizer called on the list of query strings.
        batch: Column-oriented batch containing a ``query`` list. The
            ``query`` entry is removed from `batch` as a side effect.
        max_length: Maximum sequence length; longer inputs are truncated.
            Defaults to 256.

    Returns:
        A dict with the tokenizer's ``input_ids`` and ``attention_mask``,
        padded to the longest sequence in the batch.

    Raises:
        KeyError: If `batch` has no ``query`` entry.
    """
    # pop() (not a plain lookup) is kept so that callers which reuse `batch`
    # afterwards still see it without the raw text column.
    texts = batch.pop("query")
    encoded = tokenizer(
        texts,
        padding="longest",
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
    }