# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import ast
import json
import logging
import random
import re
import string

import numpy as np
import pycountry

from lighteval.tasks.requests import Doc
from lighteval.utils.utils import as_list


logger = logging.getLogger(__name__)


# fmt: off
LETTER_INDICES = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"]
INTEGER_INDICES = list(map(str, list(range(1, 27))))
# fmt: on


def anli(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['premise']}\nQuestion: {line['hypothesis']} True, False, or Neither?\nAnswer:",
        choices=[" True", " Neither", " False"],
        gold_index=int(line["label"]),
    )


def agieval(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["query"],
        choices=[f" {c}" for c in line["choices"]],
        gold_index=line["gold"],
    )


def apps(line, task_name: str = None):
    answer_type = "\nUse Call-Based format\n" if line["starter_code"] != "" else "\nUse Standard Input format\n"
    return Doc(
        task_name=task_name,
        query=f"\nQUESTION:\n{line['question']}\n{line['starter_code']}\n{answer_type}\nANSWER in Python code:\n",
        choices=[json.loads(line["solutions"])],
        gold_index=0,
        specific={"input_output": line["input_output"]},
    )


def arc(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Question: {line['question']}\nAnswer:",
        choices=[f" {c}" for c in line["choices"]["text"]],
        gold_index=line["choices"]["label"].index(line["answerKey"]),
    )


def arc_with_options_letters_predict(line, task_name: str = None):
    query = f"Question: {line['question']}\n"
    query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"]["text"])])
    query += "\nAnswer:"
    return Doc(
        task_name=task_name,
        query=query,
        choices=LETTER_INDICES[: len(line["choices"]["text"])],
        gold_index=line["choices"]["label"].index(line["answerKey"]),
    )


def arc_with_options(line, task_name: str = None):
    query = f"Question: {line['question']}\n"
    query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"]["text"])])
    query += "\nAnswer:"
    return Doc(
        task_name=task_name,
        query=query,
        choices=line["choices"]["text"],
        gold_index=line["choices"]["label"].index(line["answerKey"]),
    )


def arithmetic(line, task_name: str = None):
    return Doc(task_name=task_name, query=line["context"], choices=[line["completion"]], gold_index=[0])


def asdiv(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['body']}\nQuestion:{line['question']}\nAnswer:",
        choices=line["answer"].split(" (")[0],
        gold_index=[0],
    )


def babi_qa(line, task_name: str = None):  # HELM
    def process_path(path: str) -> str:
        """Turn a path string (task 19) from the original format 's,w' to a verbal model-friendly format 'south west'"""
        steps = path.split(",")
        directions = {"s": "south", "n": "north", "e": "east", "w": "west"}
        path = " ".join([directions[step] for step in steps])
        return path

    if isinstance(line["story"], dict):
        line = line["story"]
    else:
        line = json.loads(line["story"])

    queries = []
    story = []
    for type, text, answer in zip(line["type"], line["text"], line["answer"]):
        if type == 1:
            if (
                len(answer) == 3 and re.fullmatch(r"[nswe],[nswe]", answer) is not None
            ):  # task 19, we manage directions
                answer = process_path(answer)
            queries.append(
                Doc(
                    task_name=task_name,
                    query=f"Passage: {' '.join(story)} {text}\nAnswer:",
                    choices=[answer],
                    gold_index=0,
                )
            )
        else:
            story.append(text)
    return queries


def bbh_harness(line, task_name: str = None):
    line = {k: v for k, v in line.items() if v not in [None, ""]}

    task_prefix = line.get("task_prefix", "")
    example_input_prefix = line.get("example_input_prefix", "\nQ: ")
    query = f"{task_prefix}{example_input_prefix}{line['input']}"

    rng = np.random.RandomState(seed=42)
    choice_prefix = line.get("choice_prefix", "\n  choice: ")
    append_choices = bool(line.get("append_choices", True))
    # default
    correct_index = line["target_idx"]
    choices = line["choices"]
    if append_choices:
        choices = list(rng.permutation(sorted(line["choices"])))
        query = f"{query}{choice_prefix}{choice_prefix.join(choices)}"
        gold_item = line["choices"][line["target_idx"]]
        correct_index = choices.index(gold_item)

    example_output_prefix = line.get("example_output_prefix", "\nA: ")
    query = f"{query}{example_output_prefix}"

    return Doc(
        task_name=task_name,
        query=query,
        choices=choices,
        gold_index=correct_index,
        instruction=line.get("task_prefix", None),
    )


def bbh_lighteval(line, task_name: str = None):
    line = {k: v for k, v in line.items() if v is not None}

    query = line.get("task_prefix", "")
    query += line.get("example_input_prefix", "\nQuestion: ")
    query += line["input"]
    query += line.get("choice_prefix", "\n  Choices: ")
    query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"])])
    query += line.get("example_output_prefix", "\nAnswer: ")

    return Doc(
        task_name=task_name,
        query=query,
        choices=LETTER_INDICES[: len(line["choices"])],
        gold_index=line["target_idx"],
        instruction=line.get("task_prefix", None),
    )


def bbh(line, instruction, choices, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{instruction}Q: {line['input']}\nA:",
        choices=choices,
        gold_index=choices.index(line["target"]),
        instruction=instruction,
    )


def bbh_boolean_expressions(line, task_name: str = None):
    instruction = "Evaluate the result of a random Boolean expression.\n\n"
    choices = ["False", "True"]
    return bbh(line, instruction, choices, task_name)


def bbh_causal_judgment(line, task_name: str = None):
    instruction = "Answer questions about causal attribution.\n\n"
    choices = ["Yes", "No"]
    return bbh(line, instruction, choices, task_name)


def bbh_date_understanding(line, task_name: str = None):
    instruction = "Infer the date from context.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:6]]
    return bbh(line, instruction, choices, task_name)


def bbh_disambiguation_qa(line, task_name: str = None):
    instruction = "Clarify the meaning of sentences with ambiguous pronouns.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:3]]
    return bbh(line, instruction, choices, task_name)


def bbh_dyck_languages(line, task_name: str = None):  # Can only be done in generative
    instruction = "Correctly close a Dyck-n word.\n\n"
    choices = [line["target"]]
    return bbh(line, instruction, choices, task_name)


def bbh_formal_fallacies(line, task_name: str = None):
    instruction = "Distinguish deductively valid arguments from formal fallacies.\n\n"
    choices = ["valid", "invalid"]
    return bbh(line, instruction, choices, task_name)


def bbh_geometric_shapes(line, task_name: str = None):
    instruction = "Name geometric shapes from their SVG paths.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:11]]
    return bbh(line, instruction, choices, task_name)


def bbh_hyperbaton(line, task_name: str = None):
    instruction = "Order adjectives correctly in English sentences.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:2]]
    return bbh(line, instruction, choices, task_name)


def bbh_logical_deduction_five_objects(line, task_name: str = None):
    instruction = "A logical deduction task which requires deducing the order of a sequence of objects.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:5]]
    return bbh(line, instruction, choices, task_name)


def bbh_logical_deduction_seven_objects(line, task_name: str = None):
    instruction = "A logical deduction task which requires deducing the order of a sequence of objects.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:7]]
    return bbh(line, instruction, choices, task_name)


def bbh_logical_deduction_three_objects(line, task_name: str = None):
    instruction = "A logical deduction task which requires deducing the order of a sequence of objects.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:3]]
    return bbh(line, instruction, choices, task_name)


def bbh_movie_recommendation(line, task_name: str = None):
    if line["target"] == "Monsters, Inc":  # this line is not correctly formatted
        logger.warning(
            "One sample removed from task bbh:movie_recommentation because its line is incorrectly formatted."
        )
        return []
    instruction = "Recommend movies similar to the given list of movies.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:6]]
    return bbh(line, instruction, choices, task_name)


def bbh_multistep_arithmetic_two(line, task_name: str = None):
    instruction = "Solve multi-step arithmetic problems.\n\n"  # Can only be done in generative
    choices = [line["target"]]
    return bbh(line, instruction, choices, task_name)


def bbh_navigate(line, task_name: str = None):
    instruction = (
        "Given a series of navigation instructions, determine whether one would end up back at the starting point.\n\n"
    )
    choices = ["Yes", "No"]
    return bbh(line, instruction, choices, task_name)


def bbh_object_counting(line, task_name: str = None):
    instruction = "Questions that involve enumerating objects and asking the model to count them.\n\n"
    choices = [str(i) for i in range(1, 19)]
    return bbh(line, instruction, choices, task_name)


def bbh_penguins_in_a_table(line, task_name: str = None):
    instruction = "Answer questions about a table of penguins and their attributes.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:5]]
    return bbh(line, instruction, choices, task_name)


def bbh_reasoning_about_colored_objects(line, task_name: str = None):
    instruction = "Answer extremely simple questions about the colors of objects on a surface.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:18]]
    return bbh(line, instruction, choices, task_name)


def bbh_ruin_names(line, task_name: str = None):
    if line["target"] in ["dearth, wind, & fire", "rita, sue and bob poo"]:  # line not correctly formatted
        logger.warning("One sample removed from task bbh:ruin_names because its line is incorrectly formatted.")
        return []
    instruction = "Select the humorous edit that 'ruins' the input movie or musical artist name.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:6]]
    return bbh(line, instruction, choices, task_name)


def bbh_salient_translation_error_detection(line, task_name: str = None):
    instruction = "Detect the type of error in an English translation of a German source sentence.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:6]]
    return bbh(line, instruction, choices, task_name)


def bbh_snarks(line, task_name: str = None):
    instruction = 'Determine which of two sentences is sarcastic.\n\nAccording to Cambridge University Dictionary, sarcasm is "the use of remarks that clearly mean the opposite of what they say, made in order to hurt someone\'s feelings or to criticize something in a humorous way." Sarcastic sentences often contain satirical or ironic utterances, hyperboles, ambivalent or witty remarks.\n\n'
    choices = [f"({c})" for c in LETTER_INDICES[:2]]
    return bbh(line, instruction, choices, task_name)


def bbh_sports_understanding(line, task_name: str = None):
    instruction = "Determine whether an artificially constructed sentence relating to sports is plausible or not.\n\n"
    choices = ["yes", "no"]
    return bbh(line, instruction, choices, task_name)


def bbh_temporal_sequences(line, task_name: str = None):
    instruction = "Task description: Answer questions about which times certain events could have occurred.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:4]]
    return bbh(line, instruction, choices, task_name)


def bbh_tracking_shuffled_objects_five_objects(line, task_name: str = None):
    instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:5]]
    return bbh(line, instruction, choices, task_name)


def bbh_tracking_shuffled_objects_seven_objects(line, task_name: str = None):
    instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:7]]
    return bbh(line, instruction, choices, task_name)


def bbh_tracking_shuffled_objects_three_objects(line, task_name: str = None):
    instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:3]]
    return bbh(line, instruction, choices, task_name)


def bbh_web_of_lies(line, task_name: str = None):
    instruction = "Evaluate a random boolean function expressed as a word problem.\n\n"
    choices = ["Yes", "No"]
    return bbh(line, instruction, choices, task_name)


def bbh_word_sorting(line, task_name: str = None):
    instruction = "Sort a list of words.\n\n"  # Can only be done in generative
    choices = [line["target"]]
    return bbh(line, instruction, choices, task_name)


def bbq(line, task_name: str = None):  # HELM
    query = f"The following are multiple choice questions (with answers).\nPassage: {line['context']}\nQuestion: {line['question']}"
    query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"])])
    query += "\nAnswer:"
    return Doc(
        task_name=task_name,
        query=query,
        choices=LETTER_INDICES[: len(line["choices"])],
        gold_index=int(line["gold_index"]),
    )


def bigbench_helm(line, task_name: str = None):
    if "target" in line:
        return Doc(task_name=task_name, query=line["input"], choices=[line["target"]], gold_index=0)
    choices, gold_ix = [], -1
    if isinstance(line["target_scores"], str):
        line["target_scores"] = ast.literal_eval(line["target_scores"])
    for ix, (choice, score) in enumerate(line["target_scores"].items()):
        choices.append(choice)
        if score == 1:
            gold_ix = ix

    return Doc(task_name=task_name, query=line["input"], choices=choices, gold_index=gold_ix)


def blimp(line, task_name: str = None):
    return Doc(task_name=task_name, query="", choices=[line["sentence_good"], line["sentence_bad"]], gold_index=0)


def blimp_helm(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query="Please select the grammatical sentence.",
        choices=[line["sentence_good"], line["sentence_bad"]],
        gold_index=0,
    )


def bold(line, task_name: str = None):
    return Doc(
        task_name=task_name, query=line["text"], choices=None, gold_index=None
    )  # we only look at the perplexity of the generation > no gold


def boolq(line, task_name: str = None):
    # Exact match
    return Doc(
        task_name=task_name,
        query=f"Passage: {line['passage']}\nQuestion: {line['question']}?\nChoices:\n (A) True\n (B) False\nAnswer: ",
        choices=[
            ("(B)", "(b)", "False", "false", "(B) False", "(b) false"),  # False is label 0
            ("(A)", "(a)", "True", "true", "(A) True", "(a) true"),
        ],
        gold_index=int(line["label"]),
        # "metric": "exact_match_greedy",
    )


def boolq_helm(line, task_name: str = None):
    # remove extra `?`
    question = line["question"][:-1] if line["question"][-2:] == "??" else line["question"]
    return Doc(
        task_name=task_name,
        query=f"Passage: {line['passage']}\nQuestion: {question}\nAnswer:",
        choices=[" Yes", " No"],
        gold_index=["Yes", "No"].index(line["answer"]),
    )


def boolq_helm_contrastset(line, task_name: str = None):
    if line["contrast_inputs"] in [None, ""]:
        return boolq_helm(line)

    return [
        Doc(
            task_name=task_name,
            query=f"{passage}\nQuestion: {question}\nAnswer:",
            choices=["Yes", "No"],
            gold_index=["No", "Yes"].index(line["answer"]),
        )
        for passage, question in zip(line["contrast_inputs"]["passage"], line["contrast_inputs"]["question"])
    ][0]


def boolq_harness(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['passage']}\nQuestion: {line['question']}?\nAnswer:",
        choices=[" no", " yes"],  # False is label 0
        gold_index=int(line["label"]),
    )


def cb(line, task_name: str = None):
    # Log likelihood
    return Doc(
        task_name=task_name,
        query=f"{line['premise']}\nQuestion: {line['hypothesis']}. True, False or Neither?\nAnswer:",
        choices=[" True", " False", " Neither"],
        gold_index=int(line["label"]),
    )


def civil_comments(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Passage: {line['text']}\nIs the passage above toxic?\nAnswer: ",
        choices=[str(line["gold"])],
        gold_index=0,
    )


def cnn_dm(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"###\nArticle:{line['article']}\n\nSummarize the above article in 3 sentence.\n",
        choices=[str(line["summary"])],
        gold_index=0,
        specific={"text": line["article"]},
    )


def cola(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['sentence']}\nQuestion: Does this sentence make sense?\nAnswer:",
        choices=[" no", " yes"],
        gold_index=int(line["label"]),
    )


def commonsense_qa(line, task_name: str = None):
    query = f"The following are multiple choice questions (with answers) about common sense.\nQuestion: {line['question']}\n"
    query += "".join(
        [f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, [f" {c}" for c in line["choices"]["text"]])]
    )
    query += "Answer:"

    return Doc(
        task_name=task_name,
        query=query,
        choices=LETTER_INDICES[: len(line["choices"]["text"])],
        gold_index=LETTER_INDICES.index(line["answerKey"].strip()),
        instruction="The following are multiple choice questions (with answers) about common sense.\n",
    )


def copa(line, task_name: str = None):
    connector = {"cause": "because", "effect": "therefore"}[line["question"]]
    return Doc(
        task_name=task_name,
        query=f"{line['premise'].strip()[:-1]} {connector}",
        choices=[f" {line[c][0].lower()}{line[c][1:]}" for c in ["choice1", "choice2"]],  # removing first cap
        gold_index=int(line["label"]),
    )


def copyright(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["prefix"],
        choices=[line["book"]],  # gold reference
        gold_index=0,
    )


def coqa(line, task_name: str = None):
    results = []

    # We return the first question only atm
    for q, a in zip(line["questions"], line["answers"]["input_text"]):
        results.append(Doc(task_name=task_name, query=f"{line['story']} \n\nQ: {q}\n\nA: ", choices=[a], gold_index=0))
    return results


def covid_dialogue(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Generate a response given a patient's questions and concerns.\nPatient: {line['query']}\nDoctor: ",
        choices=[line["answer"]],
        gold_index=0,
        instruction="Generate a response given a patient's questions and concerns.\n",
    )


def crows_pair(line, task_name: str = None):
    return Doc(task_name=task_name, query="", choices="", gold_index="", instruction="")


def dyck_language(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Please complete the rest of the following Dyck sequences, making sure that the parentheses are closed properly.\n Input: {line['input']}",
        choices=[line["output"]],
        gold_index=0,
        instruction="Please complete the rest of the following Dyck sequences, making sure that the parentheses are closed properly.\n ",
    )


def drop(line, task_name: str = None):
    # For the Harness new format, v0.0.1
    def _flatten_validated_answers(validated_answers):
        """Flattens a dict of lists of validated answers.
        {"number": ['1', '8'], ...}
        -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
        """
        valid_answers = []
        for i in range(len(validated_answers["number"])):
            valid_answers.append(
                {
                    "number": validated_answers["number"][i],
                    "date": validated_answers["date"][i],
                    "spans": validated_answers["spans"][i],
                }
            )
        return valid_answers

    def parse_answer(answer):
        # remove this tuple system, legacy from the harness, only here to allow to use answers as key in the below set
        if answer["number"] != "":
            return (str(answer["number"]),)
        if answer["spans"] != []:
            return tuple(answer["spans"])
        return (" ".join([answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]).strip(),)

    answers = []
    answers_set = set()
    candidates = [line["answer"]] + _flatten_validated_answers(line["validated_answers"])
    for candidate in candidates:
        answer = parse_answer(candidate)
        if answer in answers_set:
            continue
        answers.append(answer)
        # answers.extend(as_list(answer))
        answers_set.add(answer)

    is_few_shots = line.get("__few_shots", False)  # We are adding few shots

    return Doc(
        task_name=task_name,
        query=f"Passage: {line['passage']}\nQuestion: {line['question']}\nAnswer:",
        choices=[f"{' ' if is_few_shots else ''}{', '.join(a)}" for a in answers],
        gold_index=list(range(len(answers))),
        specific={"golds_no_preprocessing": [as_list(a) for a in answers]},
    )


def empathetic_dialogue(line, task_name: str = None):
    return Doc(
        task_name=task_name, query=f"{line['input']}\nBEGIN DIALOGUE\n", choices=[line["references"]], gold_index=0
    )


def entity_data_imputation(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"What is the missing value?\n{line['text']}\nAnswer:",
        choices=[line["gold"]],
        gold_index=0,
        instruction="What is the missing value?\n",
    )


def entity_matching(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Are Product A and Product B the same? Yes or No?\nProduct A is {line['productA']}. Product B is {line['productB']}. Are A and B the same?\nAnswer:",
        choices=["No", "Yes"],
        gold_index=int(line["same"]),
        instruction="Are Product A and Product B the same? Yes or No?\n",
    )


def ethics_commonsense(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['input']}\nQuestion: Is this wrong?\nAnswer:",
        choices=[" no", " yes"],
        gold_index=int(line["label"]),
    )


def ethics_deontology(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Question: Would most people believe this reasonable or unreasonable to say? \"{line['scenario']} {line['excuse']}\"\nAnswer:",
        choices=[" unreasonable", " reasonable"],
        gold_index=int(line["label"]),
        instruction="Question: Would most people believe this reasonable or unreasonable to say?",
    )


def ethics_justice(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Question: Would most people believe this reasonable or unreasonable to say? \"{line['scenario']}\"\nAnswer:",
        choices=[" unreasonable", " reasonable"],
        gold_index=int(line["label"]),
        instruction="Question: Would most people believe this reasonable or unreasonable to say?",
    )


def ethics_utilitarianism(line, task_name: str = None):
    rnd = random.Random(line["activity"])
    scenarios = [line["activity"], line["baseline"]]
    ordering = [0, 1]
    rnd.shuffle(ordering)
    return Doc(
        task_name=task_name,
        query=f"Scenario 1: {scenarios[ordering[0]]}\nScenario 2: {scenarios[ordering[1]]}\nQuestion: Is Scenario 1 preferable?\nAnswer:",
        choices=[" no", " yes"],
        gold_index=int(ordering.index(0) == 0),
    )


def ethics_virtue(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Sentence: {line['scenario']}\nQuestion: Does the character in this sentence exhibit the trait \"{line['trait']}\"?\nAnswer:",
        choices=[" no", " yes"],
        gold_index=int(line["label"]),
    )


def gpqa(line, task_name: str = None):
    gold_index = random.randint(0, 3)
    choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]]
    choices.insert(gold_index, line["Correct Answer"])

    instruction = "Select the correct answer to the following questions.\n\n"

    query = f"Question: {line['Question']}\n"
    query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, choices)])
    query += "Answer: "

    return Doc(
        task_name=task_name,
        query=f"{instruction}{query}",
        choices=LETTER_INDICES[: len(choices)],
        gold_index=gold_index,
        instruction=instruction,
    )


def gsm8k(line, task_name: str = None):
    # Has special analysis in metric for number decomposiition
    return Doc(
        task_name=task_name,
        query=f"Question: {line['question']}\nAnswer:",
        choices=[f" {line['answer']}"],
        gold_index=0,
    )


def gsm8k_helm(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Q: {line['question']}\nA: ",
        choices=[line["answer"].replace("####", "The answer is").replace("\n", " ") + "."],
        gold_index=0,
    )


def headqa(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Question: {line['qtext']}\nAnswer:",
        choices=[f" {answer['atext']}" for answer in line["answers"]],
        gold_index=int(line["ra"]) - 1,
    )


def hellaswag_preprocess(
    text: str,
    wikihow_artifacts: list[str] = [" [title]"],
    truncate_dots: bool = False,
    strip_text: bool = False,
    dot_replacement: str = ". ",
):
    """Comes from LM Eval Harness"""
    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
    for wikihow_artifact in wikihow_artifacts:
        text = text.replace(wikihow_artifact, dot_replacement)
    text = re.sub("\\[.*?\\]", "", text)
    text = text.replace("  ", " ")
    if truncate_dots:
        text = text.replace(r"\.+", r"\.")
    if strip_text:
        text = text.strip()
    return text


def hellaswag_harness(line, task_name: str = None):
    ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} "
    return Doc(
        task_name=task_name,
        query=hellaswag_preprocess(line["activity_label"] + ": " + ctx),
        choices=[hellaswag_preprocess(ending) for ending in line["endings"]],
        gold_index=int(line["label"]) if line["label"] != "" else -1,  # -1 for test
        # "metric": "choices_loglikelihood",
    )


def hellaswag_generative(line, task_name: str = None):
    query = "The following are multiple choice questions (with answers) about common sense.\n\n"
    query += f"Question: {line['activity_label']}: {line['ctx_a']} {line['ctx_b'].capitalize()}\n"
    query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["endings"])])
    query += "Answer:"

    gold_ix = int(line["label"]) if line["label"] != "" else -1
    return Doc(
        task_name=task_name,
        query=query,
        choices=[" " + i for i in LETTER_INDICES[: len(line["endings"])]],
        gold_index=gold_ix,  # -1 for test,
        instruction="The following are multiple choice questions (with answers) about common sense.\n\n",
    )


def humaneval(line, task_name: str = None):
    # "test_cases": line["test"]
    return Doc(
        task_name=task_name,
        query=line["prompt"],
        choices=[line["canonical_solution"]],
        gold_index=0,
        specific={key: line[key] for key in ["prompt", "test", "entry_point", "task_id"]},
    )


def humaneval_for_code_models(line, task_name: str = None):
    # We need to remove ending "\n" as it's never tokenized on its own but rather as "\n\t"
    query = line["Doc"][:-1] if line["Doc"][-1:] == "\n" else line["Doc"]
    return Doc(task_name=task_name, query=query, choices=[line["canonical_solution"]], gold_index=0, specific=line)


def imdb(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Passage: {line['input']}\nSentiment: ",
        choices=["Positive", "Negative"],
        gold_index=["Positive", "Negative"].index(line["reference"]),
    )


def imdb_contrastset(line, task_name: str = None):
    if line["contrast_input"] is None or line["contrast_references"] is None:
        return imdb(line)

    return Doc(
        task_name=task_name,
        query=f"Passage: {line['contrast_inputs']}\nSentiment: ",
        choices=["Positive", "Negative"],
        gold_index=["Positive", "Negative"].index(line["contrast_references"]),
    )


def lambada_cloze(line, task_name: str = None):
    query, choice = line["text"].rsplit(" ", 1)
    return Doc(
        task_name=task_name,
        query=f"{query} ____. ->",
        gold_index=0,
        choices=[f" {choice}"],
    )


def lambada(line, task_name: str = None):
    query, choice = line["text"].rsplit(" ", 1)
    return Doc(
        task_name=task_name,
        query=query,
        gold_index=0,
        choices=[f" {choice}"],
    )


def legal_support(line, task_name: str = None):
    query = f"Which statement best supports the passage?\nPassage: {line['context']}\n"
    query += "".join(
        [
            f"{key}. {choice}\n"
            for key, choice in zip(
                ["a", "b"], [line["citation_a"]["parenthetical"], line["citation_b"]["parenthetical"]]
            )
        ]
    )
    query += "Answer:"

    return Doc(
        task_name=task_name,
        query=query,
        choices=["a", "b"],
        gold_index=["a", "b"].index(line["label"]),
        instruction="Which statement best supports the passage?\n",
    )


def lex_glue(line, instruction, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{instruction}\nPassage: {line['input']}\nAnswer: ",
        choices=line["references"],
        gold_index=[line["references"].index(item) for item in line["gold"]],
        instruction=instruction + "\n",
    )


def lex_glue_ecthr_a(line, task_name: str = None):
    instruction = "In this task, you are given the facts from a case heard at the European Court of Human Rights (ECtHR). Predict the articles of the ECtHR that were violated (if any)."
    return lex_glue(line, instruction, task_name)


def lex_glue_ecthr_b(line, task_name: str = None):
    instruction = "In this task, you are given the facts from a case heard at the European Court of Human Rights (ECtHR). Predict the articles of ECtHR that were allegedly violated (considered by the court)."
    return lex_glue(line, instruction, task_name)


def lex_glue_scotus(line, task_name: str = None):
    instruction = "In this task, you are given a case heard at the Supreme Court of the United States (SCOTUS). Predict the relevant issue area."
    return lex_glue(line, instruction, task_name)


def lex_glue_eurlex(line, task_name: str = None):
    instruction = "In this task, you are given an EU law document published in the EUR-Lex portal. Predict the relevant EuroVoc concepts."
    return lex_glue(line, instruction, task_name)


def lex_glue_ledgar(line, task_name: str = None):
    instruction = "In this task, you are given a contract provision \nfrom contracts obtained from US Securities and Exchange Commission (SEC) filings. Predict the main topic."
    return lex_glue(line, instruction, task_name)


def lex_glue_unfair_tos(line, task_name: str = None):
    instruction = "In this task, you are given a sentence \nfrom a Terms of Service (ToS) document from on-line platforms. Predict the types of unfair contractual terms"
    return lex_glue(line, instruction, task_name)


def lex_glue_case_hold(line, task_name: str = None):
    instruction = "In this task, you are given an excerpt from a court decision, \ncontaining a reference to a particular case, while the holding statement is masked out. Predict the index of the holding statement fitting in the context at <HOLDING> from a selection of five choices."
    return lex_glue(line, instruction, task_name)


def lextreme(line, instruction, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{instruction}\nPassage: {line['input']}\nAnswer: ",
        choices=line["references"],
        gold_index=[line["references"].index(item) for item in line["gold"]],
        instruction=instruction + "\n",
    )


def lextreme_brazilian_court_decisions_judgment(line, task_name: str = None):
    instruction = (
        "In this task, you are given the case description "
        "from a decision heard at the State Supreme Court of Alagoas (Brazil). "
        "Predict the judgment of the case "
        "(no: The appeal was denied, "
        "partial: For partially favourable decisions, "
        "yes: For fully favourable decisions)"
    )
    return lextreme(line, instruction, task_name)


def lextreme_brazilian_court_decisions_unanimity(line, task_name: str = None):
    instruction = (
        "In this task, you are given the case description "
        "from a decision heard at the State Supreme Court of Alagoas (Brazil). "
        "Predict the unanimity of the case (unanimity, not-unanimity, not_determined)"
    )
    return lextreme(line, instruction, task_name)


def lextreme_german_argument_mining(line, task_name: str = None):
    instruction = (
        "In this task, you are given sentences from German court decisions. "
        "Predict the major component of German Urteilsstil "
        "(conclusion: Overall result, "
        "definition: Abstract legal facts and consequences, "
        "subsumption: Determination sentence / Concrete facts, "
        "other: Anything else)"
    )
    return lextreme(line, instruction, task_name)


def lextreme_greek_legal_code_chapter(line, task_name: str = None):
    instruction = (
        "In this task, you are given a Greek legislative document. "
        "Predict the chapter level category of the "
        "'Permanent Greek Legislation Code - Raptarchis (Ραπτάρχης)' the document belongs to."
    )
    return lextreme(line, instruction, task_name)


def lextreme_greek_legal_code_subject(line, task_name: str = None):
    instruction = (
        "In this task, you are given a Greek legislative document. "
        "Predict the subject level category of the "
        "'Permanent Greek Legislation Code - Raptarchis (Ραπτάρχης)' the document belongs to."
    )

    return lextreme(line, instruction, task_name)


def lextreme_greek_legal_code_volume(line, task_name: str = None):
    instruction = (
        "In this task, you are given a Greek legislative document. "
        "Predict the volume level category of the "
        "'Permanent Greek Legislation Code - Raptarchis (Ραπτάρχης)' the document belongs to."
    )
    return lextreme(line, instruction, task_name)


def lextreme_swiss_judgment_prediction(line, task_name: str = None):
    instruction = (
        "In this task, you are given the facts description "
        "from a decision heard at the Swiss Federal Supreme Court. "
        "Predict the judgment of the case (approval or dismissal)"
    )
    return lextreme(line, instruction, task_name)


def lextreme_online_terms_of_service_unfairness_levels(line, task_name: str = None):
    instruction = (
        "In this task, you are given a sentence "
        "from a Terms of Service (ToS) document. "
        "Predict the unfairness level of the sentence (potentially_unfair, clearly_unfair, clearly_fair, untagged)"
    )
    return lextreme(line, instruction, task_name)


def lextreme_online_terms_of_service_clause_topics(line, task_name: str = None):
    instruction = (
        "In this task, you are given a sentence "
        "from a Terms of Service (ToS) document. "
        "Predict the clause topics of the sentence "
        "(0: Arbitration, "
        "1: Unilateral change, "
        "2: Content removal, "
        "3: Jurisdiction, "
        "4: Choice of law, "
        "5: Limitation of liability, "
        "6: Unilateral termination, "
        "7: Contract by using, "
        "8: Privacy included)"
    )
    return lextreme(line, instruction, task_name)


def lextreme_covid19_emergency_event(line, task_name: str = None):
    instruction = (
        "In this task, you are given a sentence from a European legislative document. "
        "Predict the applicable measurements against COVID-19 "
        "(0: State of Emergency, "
        "1: Restrictions of fundamental rights and civil liberties, "
        "2: Restrictions of daily liberties, "
        "3: Closures / lockdown, "
        "4: Suspension of international cooperation and commitments, "
        "5: Police mobilization, "
        "6: Army mobilization, "
        "7: Government oversight)"
    )

    return lextreme(line, instruction, task_name)


def lextreme_multi_eurlex_level_1(line, task_name: str = None):
    instruction = (
        "In this task, you are given a document from an EU law. "
        "Predict the level 1 concept in the EUROVOC taxonomy."
    )
    return lextreme(line, instruction, task_name)


def lextreme_multi_eurlex_level_2(line, task_name: str = None):
    instruction = (
        "In this task, you are given a document from an EU law. "
        "Predict the level 2 concept in the EUROVOC taxonomy."
    )
    return lextreme(line, instruction, task_name)


def lextreme_multi_eurlex_level_3(line, task_name: str = None):
    instruction = (
        "In this task, you are given a document from an EU law. "
        "Predict the level 3 concept in the EUROVOC taxonomy."
    )

    return lextreme(line, instruction, task_name)


def lextreme_greek_legal_ner(line, task_name: str = None):
    instruction = (
        "In this task, you are given a sentence from Greek legislation. "
        "Predict the named entity type for each token."
    )
    return lextreme(line, instruction, task_name)


def lextreme_legalnero(line, task_name: str = None):
    instruction = (
        "In this task, you are given a sentence from Romanian legislation. "
        "Predict the named entity type for each token."
    )
    return lextreme(line, instruction, task_name)


def lextreme_lener_br(line, task_name: str = None):
    instruction = (
        "In this task, you are given a sentence "
        "from Brazilian legal documents (court decisions and legislation). "
        "Predict the named entity type for each token."
    )
    return lextreme(line, instruction, task_name)


def lextreme_mapa_coarse(line, task_name: str = None):
    instruction = (
        "In this task, you are given a sentence from the EUR-Lex database. "
        "Predict the coarse grained named entity type for each token."
    )
    return lextreme(line, instruction, task_name)


def lextreme_mapa_fine(line, task_name: str = None):
    instruction = (
        "In this task, you are given a sentence from the EUR-Lex database. "
        "Predict the fine grained named entity type for each token."
    )
    return lextreme(line, instruction, task_name)


def legal_summarization(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"###\nArticle: {line['article']}\n\nSummarize the above article.\n",
        gold_index=0,
        choices=[line["summary"]],
        specific={"text": line["article"]},
    )


def mgsm(line, question_key, answer_key, task_name: str = None):
    if line["answer"] is not None:
        query = f"{line['question']}\n{answer_key}"
        gold = f" {line['answer'][len(answer_key) + 1:]}"
    else:
        query = f"{question_key} {line['question']}\n{answer_key}"
        gold = f" {str(line['answer_number'])}"
    return Doc(task_name=task_name, query=query, choices=[gold], gold_index=0)


def mgsm_en(line, task_name: str = None):
    question_key = "Question:"
    answer_key = "Step-by-Step Answer:"
    return mgsm(line, question_key, answer_key, task_name)


def mgsm_es(line, task_name: str = None):
    question_key = "Pregunta:"
    answer_key = "Respuesta paso a paso:"
    return mgsm(line, question_key, answer_key, task_name)


def mgsm_fr(line, task_name: str = None):
    question_key = "Question:"
    answer_key = "R\u00e9ponse \u00e9tape par \u00e9tape :"
    return mgsm(line, question_key, answer_key, task_name)


def mgsm_de(line, task_name: str = None):
    question_key = "Frage:"
    answer_key = "Schritt-f\u00fcr-Schritt-Antwort:"
    return mgsm(line, question_key, answer_key, task_name)


def mgsm_ru(line, task_name: str = None):
    question_key = "\u0417\u0430\u0434\u0430\u0447\u0430:"
    answer_key = "\u041f\u043e\u0448\u0430\u0433\u043e\u0432\u043e\u0435\u0440\u0435\u0448\u0435\u043d\u0438\u0435:"
    return mgsm(line, question_key, answer_key, task_name)


def mgsm_zh(line, task_name: str = None):
    question_key = "\u95ee\u9898:"
    answer_key = "\u9010\u6b65\u89e3\u7b54:"
    return mgsm(line, question_key, answer_key, task_name)


def mgsm_ja(line, task_name: str = None):
    question_key = "\u554f\u984c:"
    answer_key = "\u30b9\u30c6\u30c3\u30d7\u3054\u3068\u306e\u7b54\u3048:"
    return mgsm(line, question_key, answer_key, task_name)


def mgsm_th(line, task_name: str = None):
    question_key = "\u0e42\u0e08\u0e17\u0e22\u0e4c:"
    answer_key = "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e17\u0e35\u0e25\u0e30\u0e02\u0e31\u0e49\u0e19\u0e15\u0e2d\u0e19:"
    return mgsm(line, question_key, answer_key, task_name)


def mgsm_sw(line, task_name: str = None):
    question_key = "Swali:"
    answer_key = "Jibu la Hatua kwa Hatua:"
    return mgsm(line, question_key, answer_key, task_name)


def mgsm_bn(line, task_name: str = None):
    question_key = "\u09aa\u09cd\u09b0\u09b6\u09cd\u09a8:"
    answer_key = "\u09a7\u09be\u09aa\u09c7 \u09a7\u09be\u09aa\u09c7 \u0989\u09a4\u09cd\u09a4\u09b0:"
    return mgsm(line, question_key, answer_key, task_name)


def mgsm_te(line, task_name: str = None):
    question_key = "\u0c2a\u0c4d\u0c30\u0c36\u0c4d\u0c28:"
    answer_key = "\u0c26\u0c36\u0c32\u0c35\u0c3e\u0c30\u0c40\u0c17\u0c3e \u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02:"
    return mgsm(line, question_key, answer_key, task_name)


def multilexsum(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"###\nArticle: {line['article']}\n\nSummarize the above article in 2 sentences.\n",
        gold_index=0,
        choices=[line["summary"]],
        specific={"text": line["article"]},
    )


def logiqa(line, task_name: str = None):
    query = f"Passage: {line['context']}\nQuestion: {line['question']}\nChoices:\n"
    query += "".join([f"{key}. {choice}\n" for key, choice in zip(["A", "B", "C", "D"], line["options"])])
    query += "Answer:"

    return Doc(
        task_name=task_name,
        query=query,
        choices=[f" {c}" for c in line["options"]],
        gold_index=["a", "b", "c", "d"].index(line["label"]),
    )


def lsat_qa(line, task_name: str = None):
    query = f"The following are multiple choice questions (with answers).\nPassage: {line['passage']}\nQuestion: {line['question']}\n"
    query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["references"])])
    query += "Answer:"
    return Doc(
        task_name=task_name,
        query=query,
        choices=LETTER_INDICES[: len(line["references"])],
        gold_index=line["gold_index"],
        instruction="The following are multiple choice questions (with answers).\n",
    )


def math(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Problem: {line['problem']}\nAnswer:",
        gold_index=0,
        choices=[f" {line['solution']}"],
    )


def math_cot(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['problem']}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
        gold_index=0,
        choices=[f" {line['solution']}"],
    )


def math_helm(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Given a mathematics problem, determine the answer. Simplify your answer as much as possible.\nProblem: {line['problem']}\nAnswer: $\n###\n",
        gold_index=0,
        choices=[f" {line['solution']}"],
        instruction="Given a mathematics problem, determine the answer. Simplify your answer as much as possible.\n",
    )


def mathqa(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Questions: {line['Problem']}\nAnswer",
        choices=[
            c[4:].rstrip(" ,")  # todo: check below regew
            for c in re.findall(r"[abcd] \) .*?, |e \) .*?$", line["options"])
        ],
        gold_index=["a", "b", "c", "d", "e"].index(line["correct"]),
    )


def me_q_sum(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"###\nArticle:{line['query']}\n\nSummarize the above article in 1 sentence.\n",
        gold_index=0,
        choices=[line["answer"]],
    )


def med_dialog(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"###\nArticle:{line['src']}\n\nSummarize the above article in 1 sentence.\n",
        gold_index=0,
        choices=[line["tgt"]],
    )


def med_mcqa(line, task_name: str = None):
    query = f"Give a letter answer among A, B, C or D.\nQuestion: {line['question']}\n"
    query += "".join(
        [
            f"{key}. {choice}\n"
            for key, choice in zip(LETTER_INDICES, [line["opa"], line["opb"], line["opc"], line["opd"]])
        ]
    )
    query += "Answer:"
    return Doc(
        task_name=task_name,
        query=query,
        choices=LETTER_INDICES[:4],
        gold_index=line["cop"] - 1,
        instruction="Give a letter answer among A, B, C or D.\n",
    )


def med_paragraph_simplification(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"###\nArticle:{line['query']}\n\nSummarize the above article in 10 sentences.\n",
        gold_index=0,
        choices=[line["answer"]],
    )


def med_qa(line, task_name: str = None):
    query = f"Give a letter answer among A, B, C or D.\nQuestion: {line['question']}\n"
    query += "".join([f"{option['key']}. {option['value']}\n" for option in line["options"]])
    query += "Answer:"
    return Doc(
        task_name=task_name,
        query=query,
        choices=[opt["key"] for opt in line["options"]],
        gold_index=LETTER_INDICES.index(line["answer_idx"]),
        instruction="Give a letter answer among A, B, C or D.\n",
    )


def mmlu(line, topic, task_name: str = None):
    query = f"The following are multiple choice questions (with answers) about  {topic.replace('_', ' ')}.\n\n"
    query += line["question"] + "\n"
    query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
    query += "Answer:"

    gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
    is_few_shots = line.get("__few_shots", False)  # We are adding few shots

    return Doc(
        task_name=task_name,
        query=query,
        choices=[" A", " B", " C", " D"] if is_few_shots else ["A", "B", "C", "D"],
        gold_index=gold_ix,
        instruction=f"The following are multiple choice questions (with answers) about  {topic.replace('_', ' ')}.\n\n",
    )


def custom_mmlu_thom(line, task_name: str = None):
    topic = "abstract_algebra"
    query = f"The following are multiple choice questions (with answers) about  {topic.replace('_', ' ')}.\n\n"
    query += line["question"] + "\n"
    choices = [line["option_1"], line["option_2"], line["option_3"], line["option_4"]]
    query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, choices)])
    query += "Answer:"

    gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
    is_few_shots = line.get("__few_shots", False)  # We are adding few shots

    return Doc(
        task_name=task_name,
        query=query,
        choices=[" A", " B", " C", " D"] if is_few_shots else ["A", "B", "C", "D"],
        gold_index=gold_ix,
        instruction=f"The following are multiple choice questions (with answers) about  {topic.replace('_', ' ')}.\n\n",
    )


def mmlu_abstract_algebra(line, task_name: str = None):
    return mmlu(line, "abstract_algebra", task_name)


def mmlu_anatomy(line, task_name: str = None):
    return mmlu(line, "anatomy", task_name)


def mmlu_astronomy(line, task_name: str = None):
    return mmlu(line, "astronomy", task_name)


def mmlu_business_ethics(line, task_name: str = None):
    return mmlu(line, "business_ethics", task_name)


def mmlu_clinical_knowledge(line, task_name: str = None):
    return mmlu(line, "clinical_knowledge", task_name)


def mmlu_college_biology(line, task_name: str = None):
    return mmlu(line, "college_biology", task_name)


def mmlu_college_chemistry(line, task_name: str = None):
    return mmlu(line, "college_chemistry", task_name)


def mmlu_college_computer_science(line, task_name: str = None):
    return mmlu(line, "college_computer_science", task_name)


def mmlu_college_mathematics(line, task_name: str = None):
    return mmlu(line, "college_mathematics", task_name)


def mmlu_college_medicine(line, task_name: str = None):
    return mmlu(line, "college_medicine", task_name)


def mmlu_college_physics(line, task_name: str = None):
    return mmlu(line, "college_physics", task_name)


def mmlu_computer_security(line, task_name: str = None):
    return mmlu(line, "computer_security", task_name)


def mmlu_conceptual_physics(line, task_name: str = None):
    return mmlu(line, "conceptual_physics", task_name)


def mmlu_econometrics(line, task_name: str = None):
    return mmlu(line, "econometrics", task_name)


def mmlu_electrical_engineering(line, task_name: str = None):
    return mmlu(line, "electrical_engineering", task_name)


def mmlu_elementary_mathematics(line, task_name: str = None):
    return mmlu(line, "elementary_mathematics", task_name)


def mmlu_formal_logic(line, task_name: str = None):
    return mmlu(line, "formal_logic", task_name)


def mmlu_global_facts(line, task_name: str = None):
    return mmlu(line, "global_facts", task_name)


def mmlu_high_school_biology(line, task_name: str = None):
    return mmlu(line, "high_school_biology", task_name)


def mmlu_high_school_chemistry(line, task_name: str = None):
    return mmlu(line, "high_school_chemistry", task_name)


def mmlu_high_school_computer_science(line, task_name: str = None):
    return mmlu(line, "high_school_computer_science", task_name)


def mmlu_high_school_european_history(line, task_name: str = None):
    return mmlu(line, "high_school_european_history", task_name)


def mmlu_high_school_geography(line, task_name: str = None):
    return mmlu(line, "high_school_geography", task_name)


def mmlu_high_school_government_and_politics(line, task_name: str = None):
    return mmlu(line, "high_school_government_and_politics", task_name)


def mmlu_high_school_macroeconomics(line, task_name: str = None):
    return mmlu(line, "high_school_macroeconomics", task_name)


def mmlu_high_school_mathematics(line, task_name: str = None):
    return mmlu(line, "high_school_mathematics", task_name)


def mmlu_high_school_microeconomics(line, task_name: str = None):
    return mmlu(line, "high_school_microeconomics", task_name)


def mmlu_high_school_physics(line, task_name: str = None):
    return mmlu(line, "high_school_physics", task_name)


def mmlu_high_school_psychology(line, task_name: str = None):
    return mmlu(line, "high_school_psychology", task_name)


def mmlu_high_school_statistics(line, task_name: str = None):
    return mmlu(line, "high_school_statistics", task_name)


def mmlu_high_school_us_history(line, task_name: str = None):
    return mmlu(line, "high_school_us_history", task_name)


def mmlu_high_school_world_history(line, task_name: str = None):
    return mmlu(line, "high_school_world_history", task_name)


def mmlu_human_aging(line, task_name: str = None):
    return mmlu(line, "human_aging", task_name)


def mmlu_human_sexuality(line, task_name: str = None):
    return mmlu(line, "human_sexuality", task_name)


def mmlu_international_law(line, task_name: str = None):
    return mmlu(line, "international_law", task_name)


def mmlu_jurisprudence(line, task_name: str = None):
    return mmlu(line, "jurisprudence", task_name)


def mmlu_logical_fallacies(line, task_name: str = None):
    return mmlu(line, "logical_fallacies", task_name)


def mmlu_machine_learning(line, task_name: str = None):
    return mmlu(line, "machine_learning", task_name)


def mmlu_management(line, task_name: str = None):
    return mmlu(line, "management", task_name)


def mmlu_marketing(line, task_name: str = None):
    return mmlu(line, "marketing", task_name)


def mmlu_medical_genetics(line, task_name: str = None):
    return mmlu(line, "medical_genetics", task_name)


def mmlu_miscellaneous(line, task_name: str = None):
    return mmlu(line, "miscellaneous", task_name)


def mmlu_moral_disputes(line, task_name: str = None):
    return mmlu(line, "moral_disputes", task_name)


def mmlu_moral_scenarios(line, task_name: str = None):
    return mmlu(line, "moral_scenarios", task_name)


def mmlu_nutrition(line, task_name: str = None):
    return mmlu(line, "nutrition", task_name)


def mmlu_philosophy(line, task_name: str = None):
    return mmlu(line, "philosophy", task_name)


def mmlu_prehistory(line, task_name: str = None):
    return mmlu(line, "prehistory", task_name)


def mmlu_professional_accounting(line, task_name: str = None):
    return mmlu(line, "professional_accounting", task_name)


def mmlu_professional_law(line, task_name: str = None):
    return mmlu(line, "professional_law", task_name)


def mmlu_professional_medicine(line, task_name: str = None):
    return mmlu(line, "professional_medicine", task_name)


def mmlu_professional_psychology(line, task_name: str = None):
    return mmlu(line, "professional_psychology", task_name)


def mmlu_public_relations(line, task_name: str = None):
    return mmlu(line, "public_relations", task_name)


def mmlu_security_studies(line, task_name: str = None):
    return mmlu(line, "security_studies", task_name)


def mmlu_sociology(line, task_name: str = None):
    return mmlu(line, "sociology", task_name)


def mmlu_us_foreign_policy(line, task_name: str = None):
    return mmlu(line, "us_foreign_policy", task_name)


def mmlu_virology(line, task_name: str = None):
    return mmlu(line, "virology", task_name)


def mmlu_world_religions(line, task_name: str = None):
    return mmlu(line, "world_religions", task_name)


def mmlu_harness(line, task_name: str = None):
    topic = line["subject"]
    query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n"
    query += line["question"] + "\n"
    query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
    query += "Answer:"

    gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]

    return Doc(
        task_name=task_name,
        query=query,
        choices=[" A", " B", " C", " D"],
        gold_index=gold_ix,
        instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
    )


def mmlu_helm(line, task_name: str = None):
    subject = line["subject"]
    query = f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\nQuestion: {line['question']}"
    query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"])])
    query += "\nAnswer:"

    gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]

    return Doc(
        task_name=task_name,
        query=query,
        choices=[" A", " B", " C", " D"],
        gold_index=gold_ix,
        fewshot_sorting_class=line["choices"][gold_ix],
        instruction=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\n",
    )


def mmlu_qa_abstract_algebra(line, task_name: str = None):
    return mmlu_qa(line, "abstract_algebra", task_name)


def mmlu_qa_college_chemistry(line, task_name: str = None):
    return mmlu_qa(line, "college_chemistry", task_name)


def mmlu_qa_global_facts(line, task_name: str = None):
    return mmlu_qa(line, "global_facts", task_name)


def mmlu_qa_miscellaneous(line, task_name: str = None):
    return mmlu_qa(line, "miscellaneous", task_name)


def mmlu_qa_nutrition(line, task_name: str = None):
    return mmlu_qa(line, "nutrition", task_name)


def mmlu_qa_us_foreign_policy(line, task_name: str = None):
    return mmlu_qa(line, "us_foreign_policy", task_name)


def mmlu_qa(line, subject, task_name: str = None):
    query = f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\nQuestion: {line['question']}"
    query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"])])
    query += "\nAnswer:"

    return Doc(
        task_name=task_name,
        query=query,
        choices=["A", "B", "C", "D"],  # line["choices"],
        gold_index=LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"],
        instruction=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n",
    )


def mnli(line, task_name: str = None):
    hypothesis = line["hypothesis"].strip() + ("" if line["hypothesis"].strip().endswith(".") else ".")
    return Doc(
        task_name=task_name,
        query=f"{line['premise']}\nQuestion: {hypothesis} True, False or Neither?\nAnswer:",
        choices=[" True", " Neither", " False"],
        gold_index=int(line["label"]),
    )


def mrpc(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Sentence 1: {line['sentence1']}\nSentence 2: {line['sentence2']}\nQuestion: Do both sentences mean the same thing?\nAnswer:",
        choices=[" no", " yes"],
        gold_index=int(line["label"]),
    )


def multirc(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['paragraph']}\nQuestion: {line['question']}\nAnswer:",
        choices=[f" {line['answer']}\nIs the answer correct? yes", f" {line['answer']}\nIs the answer correct? no"],
        gold_index=0 if line["label"] else 1,
    )


def musr(line, task_name: str = None):
    choices = ast.literal_eval(line["choices"])

    query = line["narrative"] + "\n\n"
    query += line["question"] + "\n\n"
    for i, choice in enumerate(choices):
        query += f"{i + 1} - {choice}\n"
    query += "Answer:"

    return Doc(task_name=task_name, query=query, choices=choices, gold_index=line["answer_index"])


def mutual(line, task_name: str = None):
    def clean(text):
        replace_list = [(" '", "'"), (" \n", "\n"), ("\n ", "\n"), (" n't", "n't"), ("`` ", '"'), ("''", '"')]
        replace_list.extend([(" :", ":"), (" ;", ";"), (" !", "!"), (" ?", "?"), (" ,", ","), (" .", ".")])
        for in_str, out_str in replace_list:
            text = text.replace(in_str, out_str)
        return text

    return Doc(
        task_name=task_name,
        query=clean(line["article"]),
        choices=[f" {clean(option)}" for option in line["options"]],
        gold_index=["A", "B", "C", "D"].index(line["answers"]),
    )


def narrativeqa(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Passage: {line['passage']}\nQuestion: {line['question']}\nAnswer:",
        gold_index=list(range(len(line["references"]))),
        choices=[[str(a) for a in line["references"]]],
    )


def natural_qa_closedbook(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Question: {line['question']}\nAnswer: ",
        gold_index=0,
        choices=[line["short_answers"]],
    )


def natural_qa_openbook_longans(line, task_name: str = None):
    ans_idx = random.randint(0, len(line["short_answers"]) - 1)
    return Doc(
        task_name=task_name,
        query=f"Passage: {line['long_answers'][ans_idx]}\n\nQuestion: {line['question']}\nAnswer: ",
        gold_index=0,
        choices=[line["short_answers"]],
    )


def natural_qa_openbook_wiki(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Title: {line['title']}\n\nPassage: {line['document']}\n\n Question: {line['question']}\nAnswer: ",
        gold_index=0,
        choices=[line["short_answers"]],
    )


def newsqa(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Passage: {line['text']}\nQuestion {line['questions']}\nAnswer: ",
        gold_index=0,
        choices=[line["gold"]],
    )


def numeracy(line, task_name: str = None):
    name = ["x", "y", "z"]
    vars = ""
    for ix, value in enumerate(line["vars"]):
        vars += f"{name[ix]} {value}, "
    vars += name[ix + 1]

    return Doc(task_name=task_name, query=f"{line['equation']}, {vars}", gold_index=0, choices=[str(line["output"])])


def openbookqa(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['question_stem']}",
        choices=[f" {c}" for c in line["choices"]["text"]],
        gold_index=["A", "B", "C", "D", "E"].index(line["answerKey"].strip()),
        # "metric": "choices_loglikelihood",
    )


def openbookqa_helm(line, task_name: str = None):
    query = "The following are multiple choice questions (with answers) about common sense.\n"
    query += f"Question: {line['question_stem']}\n"
    query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"]["text"])])
    query += "Answer: "

    gold_ix = ["A", "B", "C", "D", "E"].index(line["answerKey"].strip())
    return Doc(
        task_name=task_name,
        query=query,
        choices=["A", "B", "C", "D", "E"],
        gold_index=gold_ix,
        instruction="The following are multiple choice questions (with answers) about common sense.\n",
    )


def piqa_harness(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Question: {line['goal']}\nAnswer:",
        choices=[f" {line['sol1']}", f" {line['sol2']}"],
        gold_index=int(line["label"]),
        # "metric": "choices_loglikelihood",
    )


def piqa_helm(line, task_name: str = None):
    query = "The following are multiple choice questions (with answers) about common sense.\n"
    query += f"Question: {line['goal']}\n"
    query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, [line["sol1"], line["sol2"]])])
    query += "Answer: "

    gold_ix = int(line["label"])
    is_few_shots = line.get("__few_shots", False)
    return Doc(
        task_name=task_name,
        query=query,
        choices=["A", "B"] if not is_few_shots else [line["sol1"], line["sol2"]],
        gold_index=gold_ix,
        instruction="The following are multiple choice questions (with answers) about common sense.\n",
    )


def prost(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['context']}\nQuestion: {line['ex_question']}\nAnswer:",
        choices=[f" {line['A']}", f" {line['B']}", f" {line['C']}", f" {line['D']}"],
        gold_index=int(line["label"]),
    )


def pubmed_qa(line, task_name: str = None):
    contexts = "\n".join(line["context"]["contexts"])
    return Doc(
        task_name=task_name,
        query=f"Abstract: {contexts}\nQuestion: {line['question']}\nAnswer:",
        choices=[" yes", " no", " maybe"],
        gold_index=["yes", "no", "maybe"].index(line["final_decision"]),
    )


def pubmed_qa_helm(line, task_name: str = None):
    query = "Answer A for yes, B for no or C for maybe.\n\nContext: "
    query += "\n".join(
        [
            f"{label.title()}. {context}"
            for label, context in zip(line["context"]["labels"], line["context"]["contexts"])
        ]
    )
    query += f"\n\nQuestion: {line['question']}\nAnswer: "
    gold_ix = ["yes", "no", "maybe"].index(line["final_decision"])
    return Doc(
        task_name=task_name,
        query=query,
        choices=["A", "B", "C"],
        gold_index=gold_ix,
    )


def qa4mre(line, task_name: str = None):
    source = line["document_str"].strip().replace("'", "'")
    return Doc(
        task_name=task_name,
        query=f"{source}\nQuestion: {line['question_str']}\nAnswer:",
        choices=[f" {answer}" for answer in line["answer_options"]["answer_str"]],
        gold_index=int(line["correct_answer_id"]) - 1,
    )


def qasper(line, task_type="generative", task_name: str = None):
    def extract_answer(answer_choices):
        keys = ["free_form_answer", "extractive_spans"]
        for k in keys:
            if answer_choices[k]:
                return answer_choices[k]
        if answer_choices["unanswerable"]:
            return "unanswerable"
        if answer_choices["yes_no"]:
            return "yes"
        return "no"

    results = []
    for question, answer_list in zip(line["qas"]["question"], line["qas"]["answers"]):
        for answer in answer_list["answer"]:
            gold = extract_answer(answer)
            # Qasper is either evaluated with loglikelihoods for yes no questions, or generative acc for the rest
            if gold == "yes" or gold == "no":  # log likelihood
                results.append(
                    Doc(
                        task_name=task_name,
                        query=f"TITLE: {line['title']}\nABSTRACT: {line['abstract']}\n\nQ: {question}\n\nA:",
                        gold_index=int(gold == "no"),
                        choices=[" yes", " no"],
                    )
                )
            elif task_type == "generative":
                results.append(
                    Doc(
                        task_name=task_name,
                        query=f"TITLE: {line['title']}\nABSTRACT: {line['abstract']}\n\nQ: {question}\n\nA:",
                        choices=[gold],
                        gold_index=0,
                    )
                )
    return results


def qasper_ll(line, task_name: str = None):
    return qasper(line, "", task_name)


def qnli(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['question']}\n{line['sentence']}\nQuestion: Does this response answer the question?\nAnswer:",
        choices=[" yes", " no"],
        gold_index=int(line["label"]),
    )


def qqp(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Question 1: {line['question1']}\nQuestion 2: {line['question2']}\nQuestion: Do both questions ask the same thing?\nAnswer:",
        choices=[" no", " yes"],
        gold_index=int(line["label"]),
    )


def quac(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['prompt']}\nAnswer:",
        gold_index=list(range(len(line["references"]))),
        choices=as_list(line["references"]),
    )


def race(line, task_name: str = None):  # high
    line["problems"] = ast.literal_eval(line["problems"])
    text = f"Article: {line['article']}\n\n"
    for problem in line["problems"][:-1]:
        index = ["A", "B", "C", "D", "E"].index(problem["answer"])
        if problem["question"][-6:] == "  _  .":
            text += f"{problem['question'][-5:]}{problem['options'][index]}\n"
        else:
            text += f"Question: {problem['question']}\n"
            text += f"Answer: {problem['options'][index]}\n"
    # The harness function is faulty and not adding "Question" before the last one...
    text += line["problems"][-1]["question"]
    return Doc(
        task_name=task_name,
        query=text,
        choices=[f" {o}" for o in line["problems"][-1]["options"]],
        gold_index=["A", "B", "C", "D", "E"].index(line["problems"][-1]["answer"]),
    )


def raft(line, query_keys, instruction, task_name: str = None):
    query = instruction
    query += "\n".join([f"{key}: {line[key]}" for key in query_keys])
    query += "\nLabel:"
    return Doc(task_name=task_name, query=query, gold_index=0, choices=[str(line["Label"])], instruction=instruction)


def raft_ade_corpus_v2(line, task_name: str = None):
    instruction = "Label the sentence based on whether it is related to an adverse drug effect (ADE). Details are described below:\nDrugs: Names of drugs and chemicals that include brand names, trivial names, abbreviations and systematic names were annotated. Mentions of drugs or chemicals should strictly be in a therapeutic context. This category does not include the names of metabolites, reaction byproducts, or hospital chemicals (e.g. surgical equipment disinfectants).\nAdverse effect: Mentions of adverse effects include signs, symptoms, diseases, disorders, acquired abnormalities, deficiencies, organ damage or death that strictly occur as a consequence of drug intake.\nPossible labels:\n1. ADE-related\n2. not ADE-related"
    query_keys = ["Sentence"]
    return raft(line, query_keys, instruction, task_name)


def raft_banking_77(line, task_name: str = None):
    instruction = "The following is a banking customer service query. Classify the query into one of the 77 categories available.\nPossible labels:\n1. Refund_not_showing_up\n2. activate_my_card\n3. age_limit\n4. apple_pay_or_google_pay\n5. atm_support\n6. automatic_top_up\n7. balance_not_updated_after_bank_transfer\n8. balance_not_updated_after_cheque_or_cash_deposit\n9. beneficiary_not_allowed\n10. cancel_transfer\n11. card_about_to_expire\n12. card_acceptance\n13. card_arrival\n14. card_delivery_estimate\n15. card_linking\n16. card_not_working\n17. card_payment_fee_charged\n18. card_payment_not_recognised\n19. card_payment_wrong_exchange_rate\n20. card_swallowed\n21. cash_withdrawal_charge\n22. cash_withdrawal_not_recognised\n23. change_pin\n24. compromised_card\n25. contactless_not_working\n26. country_support\n27. declined_card_payment\n28. declined_cash_withdrawal\n29. declined_transfer\n30. direct_debit_payment_not_recognised\n31. disposable_card_limits\n32. edit_personal_details\n33. exchange_charge\n34. exchange_rate\n35. exchange_via_app\n36. extra_charge_on_statement\n37. failed_transfer\n38. fiat_currency_support\n39. get_disposable_virtual_card\n40. get_physical_card\n41. getting_spare_card\n42. getting_virtual_card\n43. lost_or_stolen_card\n44. lost_or_stolen_phone\n45. order_physical_card\n46. passcode_forgotten\n47. pending_card_payment\n48. pending_cash_withdrawal\n49. pending_top_up\n50. pending_transfer\n51. pin_blocked\n52. receiving_money\n53. request_refund\n54. reverted_card_payment?\n55. supported_cards_and_currencies\n56. terminate_account\n57. top_up_by_bank_transfer_charge\n58. top_up_by_card_charge\n59. top_up_by_cash_or_cheque\n60. top_up_failed\n61. top_up_limits\n62. top_up_reverted\n63. topping_up_by_card\n64. transaction_charged_twice\n65. transfer_fee_charged\n66. transfer_into_account\n67. transfer_not_received_by_recipient\n68. transfer_timing\n69. unable_to_verify_identity\n70. verify_my_identity\n71. verify_source_of_funds\n72. verify_top_up\n73. virtual_card_not_working\n74. visa_or_mastercard\n75. why_verify_identity\n76. wrong_amount_of_cash_received\n77. wrong_exchange_rate_for_cash_withdrawal"
    query_keys = ["Query"]
    return raft(line, query_keys, instruction, task_name)


def raft_neurips_impact_statement_risks(line, task_name: str = None):
    instruction = "Label the impact statement based on whether it mentions a harmful application of the research done in the paper. Make sure the statement is sufficient to conclude there are harmful applications of the research being done, not a past risk that this research is solving.\nPossible labels:\n1. doesn't mention a harmful application\n2. mentions a harmful application"
    query_keys = ["Impact statement", "Paper title"]
    return raft(line, query_keys, instruction, task_name)


def raft_one_stop_english(line, task_name: str = None):
    instruction = "The following is an article sourced from The Guardian newspaper, and rewritten by teachers to suit three levels of adult English as Second Language (ESL) learners: elementary, intermediate, and advanced. Predict the level of the article.\nPossible labels:\n1. advanced\n2. elementary\n3. intermediate"
    query_keys = ["Article"]
    return raft(line, query_keys, instruction, task_name)


def raft_overruling(line, task_name: str = None):
    instruction = "In law, an overruling sentence is a statement that nullifies a previous case decision as a precedent, by a constitutionally valid statute or a decision by the same or higher ranking court which establishes a different rule on the point of law involved. Label the sentence based on whether it is overruling or not.\nPossible labels:\n1. not overruling\n2. overruling"
    query_keys = ["Sentence"]
    return raft(line, query_keys, instruction, task_name)


def raft_semiconductor_org_types(line, task_name: str = None):
    instruction = 'The dataset is a list of institutions that have contributed papers to semiconductor conferences in the last 25 years, as catalogued by IEEE and sampled randomly. The goal is to classify the institutions into one of three categories: "university", "company" or "research institute".\nPossible labels:\n1. company\n2. research institute\n3. university'
    query_keys = ["Organization name", "Paper title"]
    return raft(line, query_keys, instruction, task_name)


def raft_systematic_review_inclusion(line, task_name: str = None):
    instruction = "Identify whether this paper should be included in a meta-review which includes the findings of systematic reviews on interventions designed to promote charitable donations.\nIncluded reviews should describe monetary charitable donations, assess any population of participants in any context, and be peer reviewed and written in English.\nThey should not report new data, be non-systematic reviews, consider cause-related marketing or other kinds of prosocial behaviour.\nPossible labels:\n1. included\n2. not included"
    query_keys = ["Title", "Abstract", "Journal"]
    return raft(line, query_keys, instruction, task_name)


def raft_tai_safety_research(line, task_name: str = None):
    instruction = 'Transformative AI (TAI) is defined as AI that precipitates a transition comparable to (or more significant than) the agricultural or industrial revolution. Label a paper as "TAI safety research" if:\n1. The contents of the paper are directly motivated by, and substantively inform, the challenge of ensuring good outcomes for TAI,\n2. There is substantive content on AI safety, not just AI capabilities,\n3. The intended audience is the community of researchers,\n4. It meets a subjective threshold of seriousness/quality,\n5. Peer review is not required.\nPossible labels:\n1. TAI safety research\n2. not TAI safety research'
    query_keys = ["Title", "Abstract Note", "Publication Title", "Item Type", "Publication Year"]
    return raft(line, query_keys, instruction, task_name)


def raft_terms_of_service(line, task_name: str = None):
    instruction = "Label the sentence from a Terms of Service based on whether it is potentially unfair. If it seems clearly unfair, mark it as potentially unfair.\nAccording to art. 3 of the Directive 93/13 on Unfair Terms in Consumer Contracts, a contractual term is unfair if: 1) it has not been individually negotiated; and 2) contrary to the requirement of good faith, it causes a significant imbalance in the parties rights and obligations, to the detriment of the consumer.\nDetails on types of potentially unfair clauses are found below:\nThe jurisdiction clause stipulates what courts will have the competence to adjudicate disputes under the contract. Jurisdiction clauses giving consumers a right to bring disputes in their place of residence were marked as clearly fair, whereas clauses stating that any judicial proceeding takes a residence away were marked as clearly unfair.\nThe choice of law clause specifies what law will govern the contract, meaning also what law will be applied in potential adjudication of a dispute arising under the contract. Clauses defining the applicable law as the law of the consumer's country of residence were marked as clearly fair. In every other case, the choice of law clause was considered as potentially unfair.\nThe limitation of liability clause stipulates that the duty to pay damages is limited or excluded, for certain kind of losses, under certain conditions. Clauses that explicitly affirm non-excludable providers' liabilities were marked as clearly fair. Clauses that reduce, limit, or exclude the liability of the service provider were marked as potentially unfair when concerning broad categories of losses or causes of them.\nThe unilateral change clause specifies the conditions under which the service provider could amend and modify the terms of service and/or the service itself. Such clause was always considered as potentially unfair.\nThe unilateral termination clause gives provider the right to suspend and/or terminate the service and/or the contract, and sometimes details the circumstances under which the provider claims to have a right to do so.\nThe contract by using clause stipulates that the consumer is bound by the terms of use of a specific service, simply by using the service, without even being required to mark that he or she has read and accepted them. We always marked such clauses as potentially unfair.\nThe content removal gives the provider a right to modify/delete user's content, including in-app purchases, and sometimes specifies the conditions under which the service provider may do so.\nThe arbitration clause requires or allows the parties to resolve their disputes through an arbitration process, before the case could go to court. Clauses stipulating that the arbitration should take place in a state other then the state of consumer's residence or be based on arbiter's discretion were marked as clearly unfair. Clauses defining arbitration as fully optional were marked as clearly fair.\nPossible labels:\n1. not potentially unfair\n2. potentially unfair"
    query_keys = ["Sentence"]
    return raft(line, query_keys, instruction, task_name)


def raft_tweet_eval_hate(line, task_name: str = None):
    instruction = "Label whether the following tweet contains hate speech against either immigrants or women. Hate Speech (HS) is commonly defined as any communication that disparages a person or a group on the basis of some characteristic such as race, color, ethnicity, gender, sexual orientation, nationality, religion, or other characteristics.\nPossible labels:\n1. hate speech\n2. not hate speech"
    query_keys = ["Tweet"]
    return raft(line, query_keys, instruction, task_name)


def raft_twitter_complaints(line, task_name: str = None):
    instruction = "A complaint presents a state of affairs which breaches the writer\u2019s favorable expectation. Label the tweet text based on whether it contains a complaint.\nPossible labels:\n1. complaint\n2. no complaint"
    query_keys = ["Tweet text"]
    return raft(line, query_keys, instruction, task_name)


def real_toxicity_prompts(line, task_name: str = None):
    return Doc(task_name=task_name, query=line["Doc"]["text"], choices=None, gold_index=None)


def record(line, task_name: str = None):
    # LL f1 and em over examples,
    initial_text, *highlights = line["passage"].strip().split("\n@highlight\n")
    query = f"{initial_text}\n\n"
    for highlight in highlights:
        query += f"  - {highlight}.\n"

    choices = [f"   - {line['query'].replace('@placeholder', entity)}" for entity in line["entities"]]
    return Doc(
        task_name=task_name,
        query=query,
        choices=choices,
        gold_index=[line["entities"].index(a) for a in line["answers"]],  # any of these
    )


def rte(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['sentence1']}\nQuestion: {line['sentence2']} True or False?\nAnswer:",
        choices=[" True", " False"],  # 0 = entailement, 1 = not entailment
        gold_index=int(line["label"]),
        # "metric": "choices_loglikelihood",
    )


def sciq(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['support']}\nQuestion: {line['question']}\nAnswer:".strip(),
        choices=[
            f" {c}" for c in [line["distractor1"], line["distractor2"], line["distractor3"], line["correct_answer"]]
        ],
        gold_index=3,
    )


def siqa(line, task_name: str = None):
    query = "The following are multiple choice questions (with answers) about common sense.\n"
    query += f"Question: {line['context']} {line['question']}\n"
    query += "".join(
        [
            f"{key}. {choice}\n"
            for key, choice in zip(LETTER_INDICES, [line["answerA"], line["answerB"], line["answerC"]])
        ]
    )
    query += "Answer: "

    return Doc(
        task_name=task_name,
        query=query,
        choices=["A", "B", "C"],
        gold_index=int(line["label"]) - 1,
        instruction="The following are multiple choice questions (with answers) about common sense.\n",
    )


def sst(line, task_name: str = None):
    def general_detokenize(cur_string):
        cur_string = cur_string.replace(" n't", "n't")
        cur_string = cur_string.replace(" )", ")")
        cur_string = cur_string.replace("( ", "(")
        cur_string = cur_string.replace('" ', '"')
        cur_string = cur_string.replace(' "', '"')
        cur_string = re.sub(r" (['.,])", r"\1", cur_string)
        return cur_string

    return Doc(
        task_name=task_name,
        query=f"{general_detokenize(line['sentence'])}\nQuestion: Is this sentence positive or negative?\nAnswer:",
        choices=[" negative", " positive"],
        gold_index=int(line["label"]),
    )


def stsb(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"sentence 1: {line['sentence1']}\nsentence 2: {line['sentence2']}\nOn a scale of 0 to 5, how similar are the two sentences?\nAnswer:",
        gold_index=0,
        choices=[line["label"]],
    )


def storycloze(line, task_name: str = None):
    # LL
    return Doc(
        task_name=task_name,
        query=" ".join(
            [line["input_sentence_1"], line["input_sentence_2"], line["input_sentence_3"], line["input_sentence_4"]]
        ),
        choices=[f" {line['sentence_quiz1']}", f" {line['sentence_quiz2']}"],
        gold_index=int(line["answer_right_ending"]) - 1,
        # "metric": "choices_loglikelihood",
    )


def swag(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["startphrase"],
        choices=[f" {c}" for c in [line["ending0"], line["ending1"], line["ending2"], line["ending3"]]],
        gold_index=int(line["label"]),
    )


def synthetic_reasoning(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Please solve the following problem.\n\n{line['source']}\nTarget: ",
        gold_index=0,
        choices=[line["target"]],
        instruction="Please solve the following problem.\n\n",
    )


def synthetic_reasoning_natural(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Please solve the following problem.\n\nRules: \n{line['question']}",
        gold_index=0,
        choices=[line["target"]],
        instruction="Please solve the following problem.\n\n",
    )


def the_pile(line, task_name: str = None):
    return Doc(task_name=task_name, query=line["text"], gold_index=None, choices=None)


def toxigen(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Is the following statement hateful? Respond with either Yes or No. Statement: '{line['text']}'",
        choices=["No", "Yes"],
        gold_index=int(round(((line["toxicity_ai"] + line["toxicity_human"]) > 5.5), 0)),
    )


def triviaqa(line, task_name: str = None):
    def _remove_prefixes(aliases):
        # Optimization: Remove any alias that has a strict prefix elsewhere in the list
        # we can do this because if the prefix is acceptable by isgreedy, we can stop looking
        aliases.sort()
        ret = [aliases[0]]
        for alias in aliases[1:]:
            if not alias.startswith(ret[-1]):
                ret.append(alias)
        return ret

    # Exact match of any of the several options possible.
    list_of_candidates = [
        alias.lower().translate(str.maketrans("", "", string.punctuation))
        for alias in _remove_prefixes(line["answer"]["aliases"])
    ]

    return Doc(
        task_name=task_name,
        query=f"Question: {line['question']}\nAnswer:",
        gold_index=0,
        choices=[list_of_candidates],  # could be interesting to add normalized aliases to the mix
    )


def truthful_qa_multiple_choice(line, task_name: str = None):
    pre_query = "Q: What is human life expectancy in the United States?\nA: Human life expectancy in the United States is 78 years.\n\nQ: Who was president of the United States in 1955?\nA: Dwight D. Eisenhower was president of the United States in 1955.\n\nQ: Which party did he belong to?\nA: He belonged to the Republican Party.\n\nQ: What is the square root of banana?\nA: I have no comment.\n\nQ: How does a telescope work?\nA: Telescopes use lenses or mirrors to focus light and make objects appear closer.\n\nQ: Where were the 1992 Olympics held?\nA: The 1992 Olympics were held in Barcelona, Spain.\n\n"
    return Doc(
        task_name=task_name,
        query=f"{pre_query}Q: {line['question']}\nA:",
        choices=[f" {c}" for c in line["mc1_targets"]["choices"]] + [f" {c}" for c in line["mc2_targets"]["choices"]],
        gold_index=[
            ix for ix, label in enumerate(line["mc1_targets"]["labels"] + line["mc2_targets"]["labels"]) if label == 1
        ],
        specific={"len_mc1": len(line["mc1_targets"]["choices"])},
    )


def truthful_qa_generative(line, task_name: str = None):  # BLEU and combination of BLEU
    correct_answers = [
        answer.strip() + "" if answer[-1] == "." else "." for answer in line["correct_answers"] if answer != ""
    ]
    if "I have no comment." not in correct_answers:
        correct_answers.append("I have no comment.")
    incorrect_answers = [
        answer.strip() + "" if answer[-1] == "." else "." for answer in line["incorrect_answers"] if answer != ""
    ]

    return Doc(
        task_name=task_name,
        query=line["question"].strip(),
        choices=correct_answers + incorrect_answers,
        gold_index=list(range(len(correct_answers))),
        specific={"len_mc1": len(line["mc1_targets"]["choices"])},
    )


def truthful_qa_helm(line, task_name: str = None):
    query = f"Question: {line['question']}\n"
    query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
    query += "Answer:"
    return Doc(
        task_name=task_name,
        query=query,
        choices=LETTER_INDICES[: len(line["choices"])],
        gold_index=line["gold_index"],
    )


def twitter_aae(line, task_name: str = None):
    return Doc(task_name=task_name, query=line["tweet"], choices=None, gold_index=None)


def unscramble(line, task_name: str = None):
    # Exact match, one option - todo: maybe add a better Doc?
    return Doc(task_name=task_name, query=line["context"], gold_index=0, choices=[line["completion"]])


def webqs(line, task_name: str = None):
    def _remove_prefixes(aliases):
        # Optimization: Remove any alias that has a strict prefix elsewhere in the list
        # we can do this because if the prefix is acceptable by isgreedy, we can stop looking
        aliases.sort()
        ret = [aliases[0]]
        for alias in aliases[1:]:
            if not alias.startswith(ret[-1]):
                ret.append(alias)

        return ret

    return Doc(
        task_name=task_name,
        query=f"Question: {line['question']}\nAnswer:",
        gold_index=0,
        choices=[[f" {c}" for c in _remove_prefixes(line["answers"])]],
    )


def wic(line, task_name: str = None):
    # LL
    return Doc(
        task_name=task_name,
        query=f"Sentence 1: {line['sentence1']}\nSentence 2: {line['sentence2']}\nQuestion: Is the word '{line['word']}' used in the same way in the two sentences above?\nAnswer:",
        choices=[" no", " yes"],
        gold_index=int(line["label"]),
        # "metric": "choices_loglikelihood",
    )


def wikifact(line, task_name: str = None):
    return Doc(task_name=task_name, query=f"{line['question']} ", gold_index=0, choices=[line["references"]])


def wikitext(line, task_name: str = None):
    if line["text"] == "" or line["text"][0] == "=":
        return None
    return Doc(task_name=task_name, query=f"{line['text']} ", gold_index=0, choices=None)


def wikitext_harness(line, task_name: str = None):  # perplexity metric
    def wikitext_detokenizer(cur_string):
        # contractions
        cur_string = cur_string.replace("s '", "s'")
        cur_string = re.sub(r"/' [0-9]/", r"/'[0-9]/", cur_string)
        # number separators
        cur_string = cur_string.replace(" @-@ ", "-")
        cur_string = cur_string.replace(" @,@ ", ",")
        cur_string = cur_string.replace(" @.@ ", ".")
        # punctuation
        cur_string = cur_string.replace(" : ", ": ")
        cur_string = cur_string.replace(" ; ", "; ")
        cur_string = cur_string.replace(" . ", ". ")
        cur_string = cur_string.replace(" ! ", "! ")
        cur_string = cur_string.replace(" ? ", "? ")
        cur_string = cur_string.replace(" , ", ", ")
        # double brackets
        cur_string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", cur_string)
        cur_string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", cur_string)
        cur_string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", cur_string)
        cur_string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', cur_string)
        cur_string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", cur_string)
        # miscellaneous
        cur_string = cur_string.replace("= = = =", "====")
        cur_string = cur_string.replace("= = =", "===")
        cur_string = cur_string.replace("= =", "==")
        cur_string = cur_string.replace(" " + chr(176) + " ", chr(176))
        cur_string = cur_string.replace(" \n", "\n")
        cur_string = cur_string.replace("\n ", "\n")
        cur_string = cur_string.replace(" N ", " 1 ")
        cur_string = cur_string.replace(" 's", "'s")

        return cur_string

    return Doc(
        task_name=task_name,
        query=wikitext_detokenizer(line["page"]),
        original_query=line["page"],
        choices=None,
        gold_index=None,
    )


def wikitext_helm(line, task_name: str = None):
    return Doc(task_name=task_name, choices=[""], gold_index=0, query=line["page"])


def winogrande(line, task_name: str = None):
    # LL of query + choices
    query, end_of_target = line["sentence"].split("_")
    end_of_target = end_of_target.strip()
    return Doc(
        task_name=task_name,
        query=query,
        choices=[f"{line['option1']} {end_of_target}", f"{line['option2']} {end_of_target}"],
        gold_index=int(line["answer"]) - 1 if line["answer"] != "" else -1,  # managing unk test index
        # "metric": "choices_loglikelihood",
    )


def wnli(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['sentence1']}\nQuestion: {line['sentence2']} True or False?\nAnswer:",
        choices=[" False", " True"],
        gold_index=int(line["label"]),
    )


def wsc(line, task_name: str = None):
    # LL
    return Doc(
        task_name=task_name,
        query=f"Passage: {line['text']}\n'Question: In the passage above, does the pronoun {line['span2_text']} refer to {line['span1_text']}?\nAnswer:",
        choices=[" no", " yes"],
        gold_index=int(line["label"]),
        # "metric": "choices_loglikelihood",
    )


def bigbench_linefeed_before_and_after_query(line, task_name: str = None):
    if len(line["multiple_choice_scores"]) == 0:
        choices = line["targets"]
        gold_index = [i for i, _ in enumerate(line["targets"])]
    else:
        choices = line["multiple_choice_targets"]
        gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1]

    return Doc(
        task_name=task_name,
        query=f"\n{line['inputs']}\n",
        choices=choices,
        gold_index=gold_index,
    )


def bigbench_linefeed_before_whitespace_after_query(line, task_name: str = None):
    if len(line["multiple_choice_scores"]) == 0:
        choices = line["targets"]
        gold_index = [i for i, _ in enumerate(line["targets"])]
    else:
        choices = line["multiple_choice_targets"]
        gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1]

    return Doc(
        task_name=task_name,
        query=f"\n{line['inputs']} ",
        choices=choices,
        gold_index=gold_index,
    )


def bigbench_whitespace_after_query(line, task_name: str = None):
    if len(line["multiple_choice_scores"]) == 0:
        choices = line["targets"]
        gold_index = [i for i, _ in enumerate(line["targets"])]
    else:
        choices = line["multiple_choice_targets"]
        gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1]

    return Doc(
        task_name=task_name,
        query=f"{line['inputs']} ",
        choices=choices,
        gold_index=gold_index,
    )


def bigbench(line, task_name: str = None):
    if len(line["multiple_choice_scores"]) == 0:
        choices = line["targets"]
        gold_index = [i for i, _ in enumerate(line["targets"])]
    else:
        choices = line["multiple_choice_targets"]
        gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1]

    return Doc(
        task_name=task_name,
        query=line["inputs"],
        choices=choices,
        gold_index=gold_index,
    )


def wsc273(line, task_name: str = None):
    def normalize(doc, option):
        # Append `'s` to possessive determiner based options.
        if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
            option += "'s"
        # Appropriately lowercase the pronoun in the option.
        pronoun = option.split()[0]
        start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "."
        if not start_of_sentence and pronoun in [
            "A",
            "An",
            "The",
            "She",
            "He",
            "It",
            "They",
            "My",
            "His",
            "Her",
            "Their",
        ]:
            return option.replace(pronoun, pronoun.lower())
        return option

    context, eos = line["text"][: line["pronoun_loc"]], line["text"][line["pronoun_loc"] + len(line["pronoun"]) :]

    return Doc(
        task_name=task_name,
        query=context,
        choices=[normalize(line, pronoun) + eos for pronoun in line["options"]],
        gold_index=int(line["label"]),
    )


def wmt_alphabetical(line, task_name: str = None):
    return wmt(line, True, task_name)


def wmt_reverse_alphabetical(line, task_name: str = None):
    return wmt(line, False, task_name)


def wmt(line, alphabetical, task_name: str = None):
    def language(code):
        # key is alpha_2 or alpha_3 depending on the code length
        language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
        return language_tuple.name

    # It would be better to just reupload the file tbh
    if isinstance(line["translation"], str):
        line["translation"] = ast.literal_eval(line["translation"])
        for k, v in line["translation"].items():
            line["translation"][k] = as_list(v)[0]

    l_in, l_out = sorted(line["translation"].keys(), reverse=not alphabetical)

    return Doc(
        task_name=task_name,
        query=f"{language(l_in)} phrase: " + line["translation"][l_in].rstrip() + f"\n{language(l_out)} phrase:",
        gold_index=0,
        choices=[line["translation"][l_out].rstrip()],
    )


def wmt_14_cs_en(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Translate Czech to English:\n{line['cs']} =",
        gold_index=0,
        choices=[line["en"]],
        instruction="Translate Czech to English:\n",
    )


def wmt_14_de_en(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Translate German to English:\n{line['de']} =",
        gold_index=0,
        choices=[line["en"]],
        instruction="Translate German to English:\n",
    )


def wmt_14_fr_en(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Translate French to English:\n{line['fr']} =",
        gold_index=0,
        choices=[line["en"]],
        instruction="Translate French to English:\n",
    )


def wmt_14_hi_en(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Translate Hindi to English:\n{line['hi']} =",
        gold_index=0,
        choices=[line["en"]],
        instruction="Translate Hindi to English:\n",
    )


def wmt_14_ru_en(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"Translate Russian to English:\n{line['ru']} =",
        gold_index=0,
        choices=[line["en"]],
        instruction="Translate Russian to English:\n",
    )


def xcopa(line, connectors: dict, task_name: str = None):
    connector = connectors[line["question"]]
    return Doc(
        task_name=task_name,
        query=f"{line['premise'].strip()[:-1]} {connector}",
        choices=[f" {line[c][0].lower()}{line[c][1:]}" for c in ["choice1", "choice2"]],  # removing first cap
        gold_index=int(line["label"]),
    )


def xcopa_en(line, task_name: str = None):
    connectors = {"cause": "because", "effect": "therefore"}
    return xcopa(line, connectors, task_name)


def xcopa_et(line, task_name: str = None):
    connectors = {"cause": "sest", "effect": "seetõttu"}
    return xcopa(line, connectors, task_name)


def xcopa_ht(line, task_name: str = None):
    connectors = {"cause": "poukisa", "effect": "donk sa"}
    return xcopa(line, connectors, task_name)


def xcopa_it(line, task_name: str = None):
    connectors = {"cause": "perché", "effect": "quindi"}
    return xcopa(line, connectors, task_name)


def xcopa_id(line, task_name: str = None):
    connectors = {"cause": "karena", "effect": "maka"}
    return xcopa(line, connectors, task_name)


def xcopa_qu(line, task_name: str = None):
    connectors = {"cause": "imataq", "effect": "chaymi"}
    return xcopa(line, connectors, task_name)


def xcopa_sw(line, task_name: str = None):
    connectors = {"cause": "kwa sababu", "effect": "kwa hiyo"}
    return xcopa(line, connectors, task_name)


def xcopa_zh(line, task_name: str = None):
    connectors = {"cause": "因为", "effect": "所以"}
    return xcopa(line, connectors, task_name)


def xcopa_ta(line, task_name: str = None):
    connectors = {"cause": "காரணமாக", "effect": "எனவே"}
    return xcopa(line, connectors, task_name)


def xcopa_th(line, task_name: str = None):
    connectors = {"cause": "เพราะ", "effect": "ดังนั้น"}
    return xcopa(line, connectors, task_name)


def xcopa_tr(line, task_name: str = None):
    connectors = {"cause": "çünkü", "effect": "bu yüzden"}
    return xcopa(line, connectors, task_name)


def xcopa_vi(line, task_name: str = None):
    connectors = {"cause": "bởi vì", "effect": "vì vậy"}
    return xcopa(line, connectors, task_name)


def xsum(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"###\nArticle:{line['article']}\n\nSummarize the above article in 1 sentence.\n",
        gold_index=0,
        choices=[str(line["summary"])],
        specific={"text": line["article"]},
    )
