import re
import datasets
from collections import defaultdict
from datasets import load_dataset, Dataset
from functools import partial

# Convert multiple choice datasets into common format
# Question: (context)
# Choices: (choice)
# Answer: (label)

# Format: https://github.com/EleutherAI/lm-evaluation-harness/blob/1980a13c9d7bcdc6e2a19228c203f9f7834ac9b8/lm_eval/tasks/super_glue/boolq/default.yaml
# doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
# doc_to_target: label
# doc_to_choice: ["no", "yes"]

def get_boolqa(split='validation'):
    """Convert a SuperGLUE BoolQ split into the common multiple-choice format.

    Args:
        split: Name of the split to read from
            ``load_dataset('super_glue', 'boolq')``.

    Returns:
        A ``Dataset`` with three columns: ``question`` (the passage followed
        by the question prompt), ``choices`` (always ``["no", "yes"]``) and
        ``answer`` (the example's ``label`` value, stored unchanged).
    """
    columns = defaultdict(list)
    for example in load_dataset('super_glue', 'boolq')[split]:
        passage, question = example['passage'], example['question']
        columns['question'].append(f"{passage}\nQuestion: {question}?\nAnswer:")
        # A fresh list per row, so rows never share one mutable object.
        columns['choices'].append(["no", "yes"])
        columns['answer'].append(example['label'])
    return Dataset.from_dict(columns)


# Format: https://github.com/EleutherAI/lm-evaluation-harness/blob/1980a13c9d7bcdc6e2a19228c203f9f7834ac9b8/lm_eval/tasks/glue/rte/default.yaml#L4
def get_rte(split='validation'):
    """Convert a SuperGLUE RTE split into the common multiple-choice format.

    Args:
        split: Name of the split to read from
            ``load_dataset('super_glue', 'rte')``.

    Returns:
        A ``Dataset`` with three columns: ``question`` (the premise followed
        by the hypothesis phrased as a question), ``choices`` (always
        ``["True", "False"]``) and ``answer`` (the example's ``label`` value,
        stored unchanged).
    """
    columns = defaultdict(list)
    for example in load_dataset('super_glue', 'rte')[split]:
        premise, hypothesis = example['premise'], example['hypothesis']
        # NOTE(review): the harness RTE template appears to append
        # " True or False?" after the hypothesis, while this prompt ends with
        # "?" only -- confirm the shorter prompt is intended.
        columns['question'].append(f"{premise}\nQuestion: {hypothesis}?\nAnswer:")
        columns['choices'].append(["True", "False"])
        columns['answer'].append(example['label'])
    return Dataset.from_dict(columns)
   
# Format: https://github.com/EleutherAI/lm-evaluation-harness/blob/1980a13c9d7bcdc6e2a19228c203f9f7834ac9b8/lm_eval/tasks/hellaswag/hellaswag.yaml
def get_hellaswag(split='validation', num=1000):
    """Load HellaSwag and convert it into the common multiple-choice format.

    Args:
        split: Name of the ``Rowan/hellaswag`` split to load. Every example
            must carry a numeric ``label``, because it is converted with
            ``int``; a split without gold labels (presumably ``'test'`` --
            confirm) will raise ``ValueError``.
        num: Maximum number of examples to keep, taken from the start of the
            split. ``None`` keeps the whole split.

    Returns:
        A ``Dataset`` with three columns: ``question`` (activity label plus
        cleaned context), ``choices`` (the cleaned candidate endings) and
        ``answer`` (the gold label as an ``int``).
    """
    def preprocess(text):
        text = text.strip()
        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
        text = text.replace(" [title]", ". ")
        text = re.sub(r"\[.*?\]", "", text)
        text = text.replace("  ", " ")
        return text

    def _process_doc(doc):
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        return {
            "query": preprocess(doc["activity_label"] + ": " + ctx),
            "choices": [preprocess(ending) for ending in doc["endings"]],
            "gold": int(doc["label"]),
        }

    rawdata = load_dataset('Rowan/hellaswag')[split]
    if num is not None:
        # Truncate before mapping so only the retained examples are
        # preprocessed; max(0, ...) turns a negative cap into an empty set.
        keep = max(0, min(num, len(rawdata)))
        rawdata = rawdata.select(range(keep))
    rawdata = rawdata.map(_process_doc)

    converted = defaultdict(list)
    for item in rawdata:
        converted['question'].append(item['query'])
        converted['choices'].append(item['choices'])
        converted['answer'].append(item['gold'])
    return Dataset.from_dict(converted)

# Format: https://github.com/EleutherAI/lm-evaluation-harness/blob/1980a13c9d7bcdc6e2a19228c203f9f7834ac9b8/lm_eval/tasks/arc/arc_easy.yaml
def get_arc(subset='easy'):
    """Convert the ARC validation split into the common multiple-choice format.

    Args:
        subset: ``'easy'`` selects the ``ARC-Easy`` configuration; any other
            value selects ``ARC-Challenge``.

    Returns:
        A ``Dataset`` with three columns: ``question`` (the prompt text),
        ``choices`` (the four option texts) and ``answer`` (the position of
        ``answerKey`` within the concatenated option labels).
    """
    config = 'ARC-Easy' if subset == 'easy' else 'ARC-Challenge'
    columns = defaultdict(list)
    for example in load_dataset('allenai/ai2_arc', config)['validation']:
        options = example['choices']
        option_texts = options['text']
        # Keep only the regular four-option questions; skip everything else.
        if len(option_texts) != 4:
            continue
        stem = example['question']
        prompt = f"Question: {stem}\nAnswer:"
        # Labels are joined into one string and searched for the answer key,
        # so the resulting offset doubles as the index into the choices.
        gold_index = "".join(options['label']).index(example['answerKey'])
        columns['question'].append(prompt)
        columns['choices'].append(option_texts)
        columns['answer'].append(gold_index)
    return Dataset.from_dict(columns)

# Format: https://github.com/EleutherAI/lm-evaluation-harness/blob/1980a13c9d7bcdc6e2a19228c203f9f7834ac9b8/lm_eval/tasks/openbookqa/openbookqa.yaml#L4
def get_openbookqa():
    """Convert the OpenBookQA validation split into the common format.

    Returns:
        A ``Dataset`` with three columns: ``question`` (the bare
        ``question_stem``, with no prompt wrapper), ``choices`` (the option
        texts) and ``answer`` (the position of ``answerKey`` within the
        concatenated option labels).
    """
    columns = defaultdict(list)
    for example in load_dataset('allenai/openbookqa', 'main')['validation']:
        options = example['choices']
        # Labels are joined into one string and searched for the answer key,
        # so the resulting offset doubles as the index into the choices.
        gold_index = "".join(options['label']).index(example['answerKey'])
        columns['question'].append(example['question_stem'])
        columns['choices'].append(options['text'])
        columns['answer'].append(gold_index)
    return Dataset.from_dict(columns)

# Format: https://github.com/EleutherAI/lm-evaluation-harness/blob/1980a13c9d7bcdc6e2a19228c203f9f7834ac9b8/lm_eval/tasks/piqa/piqa.yaml
def get_piqa():
    """Convert the PIQA validation split into the common format.

    Returns:
        A ``Dataset`` with three columns: ``question`` (the goal wrapped in a
        question prompt), ``choices`` (the two candidate solutions, ``sol1``
        then ``sol2``) and ``answer`` (the example's ``label`` value, stored
        unchanged).
    """
    columns = defaultdict(list)
    for example in load_dataset('piqa')['validation']:
        goal = example['goal']
        solutions = [example['sol1'], example['sol2']]
        columns['question'].append(f"Question: {goal}\nAnswer:")
        columns['choices'].append(solutions)
        columns['answer'].append(example['label'])
    return Dataset.from_dict(columns)

# Format: https://github.com/EleutherAI/lm-evaluation-harness/blob/1980a13c9d7bcdc6e2a19228c203f9f7834ac9b8/lm_eval/tasks/truthfulqa/truthfulqa_mc1.yaml
# def get_truthfulqa():
#     converted = defaultdict(list)
#     rawdata = load_dataset('truthful_qa', 'multiple_choice')['validation']

# Format: https://github.com/EleutherAI/lm-evaluation-harness/blob/1980a13c9d7bcdc6e2a19228c203f9f7834ac9b8/lm_eval/tasks/wikitext/wikitext.yaml
def get_wikitext():
    """Load the document-level WikiText-2 test split with detokenized text.

    Returns:
        The ``test`` split of ``EleutherAI/wikitext_document_level``
        (configuration ``wikitext-2-raw-v1``) with one extra column,
        ``newdoc``, holding the detokenized version of each row's ``page``.
    """
    degree_sign = chr(176)

    # Literal substitutions for number separators, then punctuation.
    # Order matters: they are applied top to bottom.
    separator_and_punctuation = (
        (" @-@ ", "-"),
        (" @,@ ", ","),
        (" @.@ ", "."),
        (" : ", ": "),
        (" ; ", "; "),
        (" . ", ". "),
        (" ! ", "! "),
        (" ? ", "? "),
        (" , ", ", "),
    )
    # Regex substitutions that trim whitespace just inside paired delimiters.
    paired_delimiters = (
        (r"\(\s*([^\)]*?)\s*\)", r"(\1)"),
        (r"\[\s*([^\]]*?)\s*\]", r"[\1]"),
        (r"{\s*([^}]*?)\s*}", r"{\1}"),
        (r"\"\s*([^\"]*?)\s*\"", r'"\1"'),
        (r"'\s*([^']*?)\s*'", r"'\1'"),
    )
    # Remaining literal clean-ups; longer "=" runs must come before shorter ones.
    miscellaneous = (
        ("= = = =", "===="),
        ("= = =", "==="),
        ("= =", "=="),
        (" " + degree_sign + " ", degree_sign),
        (" \n", "\n"),
        ("\n ", "\n"),
        (" N ", " 1 "),
        (" 's", "'s"),
    )

    def wikitext_detokenizer(doc):
        # Contractions come first.
        text = doc["page"].replace("s '", "s'")
        # NOTE(review): the slashes in this pattern are literal characters in
        # a Python regex, so it probably never matches ordinary text; it is
        # kept unchanged to preserve the existing output.
        text = re.sub(r"/' [0-9]/", r"/'[0-9]/", text)
        for old, new in separator_and_punctuation:
            text = text.replace(old, new)
        for pattern, replacement in paired_delimiters:
            text = re.sub(pattern, replacement, text)
        for old, new in miscellaneous:
            text = text.replace(old, new)
        return text

    pages = load_dataset("EleutherAI/wikitext_document_level", 'wikitext-2-raw-v1')['test']
    detokenized = [wikitext_detokenizer(row) for row in pages]
    return pages.add_column('newdoc', detokenized)

# Registry of multiple-choice benchmark loaders, keyed by task name. Every
# value can be called with no arguments and returns a Dataset in the common
# question / choices / answer format. get_wikitext is not registered here.
RETAIN_TASKS = {
    'boolqa': get_boolqa,
    'rte': get_rte,
    'hellaswag': get_hellaswag,
    # get_arc serves both ARC variants; partial pins the subset for each key.
    'arc-easy': partial(get_arc, subset='easy'),
    'arc-challenge': partial(get_arc, subset='challenge'),
    'openbookqa': get_openbookqa,
    'piqa': get_piqa,
}
