import abc
import logging
import re
from typing import Any, Dict, List, Optional, Sequence, Union

import datasets
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torchmetrics import Metric

from olmo.util import load_hf_dataset, load_oe_eval_requests

from ..tokenizer import Tokenizer

log = logging.getLogger(__name__)


# Translation from OE-Eval metric names to the ``metric_type`` names understood by ICLMetric.
METRIC_FROM_OE_EVAL = dict(
    acc_raw="acc",
    acc_per_char="len_norm",
    acc_uncond="pmi_dc",
    logits_per_byte="bpb",
)
# log2(e) == 1 / ln(2): multiplying a natural-log quantity (nats) by this expresses it in bits.
LOG_2_OF_E = 1.44269504089


class ICLMetric(Metric):
    """Accumulates per-continuation scores and reduces them to a task-level metric.

    ``update`` scores every (doc_id, cont_id) continuation in a batch according to
    ``metric_type``:
      - "acc" / "f1": summed log-probability of the continuation tokens;
      - "len_norm": that sum divided by ``cont_str_len``;
      - "ce_loss": the negated "len_norm" score;
      - "bpb": negated sum divided by ``cont_byte_len``, converted to bits;
      - "pmi_dc": the sum under the real context divided by the sum under the
        domain-conditional context.

    ``compute`` groups the scores by document and returns accuracy, F1, or the mean
    loss, depending on ``metric_type``.
    """

    # torchmetrics flag: False means forward() does not need the full accumulated state.
    full_state_update: bool = False

    def __init__(self, metric_type="acc") -> None:
        """
        :param metric_type: Scoring mode; one of "acc", "f1", "len_norm", "ce_loss", "bpb", "pmi_dc".
        """
        # sync_on_compute: states are synchronized across processes before compute().
        super().__init__(sync_on_compute=True)

        self.metric_type = metric_type

        # Entries are tensors (doc_id, cont_id, score) and (doc_id, cont_id, label_id).
        # dist_reduce_fx=None: list states are gathered, not reduced, across processes.
        self.add_state("loglikelihoods", default=[], dist_reduce_fx=None)
        self.add_state("labels", default=[], dist_reduce_fx=None)

    def reset(
        self,
    ):
        """Clear all accumulated scores and labels.

        Delegates to ``Metric.reset`` so torchmetrics' own bookkeeping (update counter,
        cached compute() result, sync flags) is reset together with the list states.
        """
        super().reset()
        self.loglikelihoods = []
        self.labels = []

    def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None):
        """Score every continuation in ``batch`` and append the results to the state.

        :param batch: A batch produced by ``ICLMultiChoiceTaskDataset.collate_fn``.
        :param lm_logits: Model logits for ``batch["input_ids"]``; normalized here with
            ``log_softmax`` over the last dimension.
        :param dc_lm_logits: Logits for ``batch["dc_input_ids"]``; required for "pmi_dc".
        :raises ValueError: If ``metric_type`` is not recognized.
        """
        lm_logits = F.log_softmax(lm_logits, dim=-1)

        if self.metric_type == "pmi_dc":
            assert dc_lm_logits is not None, "PMI_DC acc type selected but no domain conditional logits provided"

        for idx, (doc_id, cont_id) in enumerate(zip(batch["doc_id"], batch["cont_id"])):
            ctx_len = batch["ctx_len"][idx]
            cont_len = batch["cont_len"][idx]
            # Strip right-padding from the continuation tokens.
            cont_tokens = batch["continuation"][idx][:cont_len]
            # Rows ctx_len-1 .. ctx_len+cont_len-2 are the positions whose next-token
            # distribution is evaluated against the continuation tokens.
            lm_cont_logits = lm_logits[idx][ctx_len - 1 : ctx_len + cont_len - 1]
            # Sum of the log-probabilities assigned to the actual continuation tokens.
            cont_logprob_sum = torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()

            log_likelihood: torch.Tensor
            if self.metric_type == "pmi_dc":
                assert dc_lm_logits is not None
                dc_len = batch["dc_len"][idx]
                dc_lm_cont_logits = dc_lm_logits[idx][dc_len - 1 : dc_len + cont_len - 1]
                # NOTE(review): this divides the two sums instead of subtracting them, and
                # dc_lm_logits is not log-softmax-normalized here — confirm this is the
                # intended PMI formulation before relying on "pmi_dc".
                log_likelihood = (
                    cont_logprob_sum / torch.gather(dc_lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
                )
            elif self.metric_type == "acc" or self.metric_type == "f1":
                log_likelihood = cont_logprob_sum
            elif self.metric_type == "len_norm" or self.metric_type == "ce_loss":
                # Normalize by the character length of the continuation.
                log_likelihood = cont_logprob_sum / batch["cont_str_len"][idx]
                if self.metric_type == "ce_loss":
                    log_likelihood = -log_likelihood
            elif self.metric_type == "bpb":
                # Negative log-likelihood per UTF-8 byte, converted from nats to bits.
                log_likelihood = -cont_logprob_sum / batch["cont_byte_len"][idx] * LOG_2_OF_E
            else:
                raise ValueError(self.metric_type)

            self.loglikelihoods.append(
                torch.Tensor((doc_id, cont_id, log_likelihood)).to(batch["continuation"][idx].device)
            )
            self.labels.append(
                torch.LongTensor((doc_id, cont_id, batch["label_id"][idx])).to(batch["label_id"][idx].device)
            )

    def compute(self) -> torch.Tensor:
        """Reduce the accumulated per-continuation scores to a single metric value.

        Scores are grouped by document and placed at their ``cont_id`` index. Documents
        whose ``cont_id``s are not exactly ``0..n-1`` are skipped. Then:
          - "ce_loss" / "bpb": mean over documents of the score at index 0;
          - "f1": F1 of argmax predictions vs. labels with ``pos_label=0``;
          - otherwise: fraction of documents whose argmax equals the label.

        Returns NaN (rather than raising ZeroDivisionError) for non-"f1" metrics when
        no document could be scored.
        """
        # doc_id -> {cont_id -> score}; the first score seen for a (doc, cont) pair wins.
        loglikelihood_dict: Dict[int, Dict[int, float]] = {}
        # doc_id -> label_id; the first label seen for a document wins.
        label_dict: Dict[int, int] = {}

        for doc_id, _cont_id, label_id in self.labels:
            label_dict.setdefault(doc_id.item(), label_id.item())

        for doc_id, cont_id, loglikelihood in self.loglikelihoods:
            doc_scores = loglikelihood_dict.setdefault(int(doc_id.item()), {})
            doc_scores.setdefault(int(cont_id.item()), loglikelihood)

        correct = []
        preds: Optional[List[float]] = None
        labels: Optional[List[int]] = None
        if self.metric_type == "f1":
            preds = []
            labels = []

        for doc_id, doc_scores in loglikelihood_dict.items():
            num_continuations = len(doc_scores)
            loglikelihoods = torch.tensor([-float("inf")] * num_continuations)

            skip_document = False
            for cont_id, score in doc_scores.items():
                try:
                    loglikelihoods[cont_id] = score
                except IndexError:
                    # A cont_id beyond the number of scores seen means some continuations
                    # of this document are missing, so it cannot be scored.
                    skip_document = True
                    break

            if skip_document:
                continue

            if self.metric_type in ["ce_loss", "bpb"]:
                # Loss metrics use the score at index 0; the loss-style tasks in this file
                # emit the gold answer as their only continuation.
                correct.append(float(loglikelihoods[0]))
            else:
                correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0)

            if self.metric_type == "f1":
                assert preds is not None
                assert labels is not None
                preds.append(torch.argmax(loglikelihoods).item())
                labels.append(label_dict[doc_id])

        if self.metric_type == "f1":
            assert preds is not None
            assert labels is not None
            # Continuation index 0 is treated as the positive class.
            score = f1_score(labels, preds, pos_label=0)
        elif not correct:
            # Nothing could be scored; avoid dividing by zero.
            score = float("nan")
        else:
            score = sum(correct) / len(correct)

        return torch.tensor(score)


class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
    """Base class for multiple-choice in-context-learning evaluation datasets.

    Subclasses implement ``doc_to_text``, ``doc_to_continuations`` and ``doc_to_label``
    (and optionally ``doc_to_domain_conditional``). On construction, every
    (document, prompt, continuation) combination is tokenized into a sample dict in
    ``self.samples``; ``collate_fn`` pads a list of such samples into a tensor batch.
    """

    # Scoring mode (same vocabulary as ICLMetric's ``metric_type``). Subclasses set a
    # class-level default; the ``metric_type`` constructor argument overrides it.
    metric_type: str

    def __init__(
        self,
        tokenizer: Tokenizer,
        dataset_path: str,
        dataset_name: Union[str, Sequence[str], None] = None,
        model_ctx_len: int = 2048,
        split="validation",
        metric_type=None,
        prompts=None,
    ):
        """
        :param tokenizer: Tokenizer used to encode contexts and continuations.
        :param dataset_path: Dataset path passed to ``load_hf_dataset``.
        :param dataset_name: A config name, a sequence of config names (their splits are
            concatenated), or ``None``.
        :param model_ctx_len: Maximum model input length; longer queries are truncated
            from the left.
        :param split: Dataset split to load.
        :param metric_type: If not ``None``, overrides the class-level ``metric_type``.
        :param prompts: Prompt variants. Each document is emitted once per entry, with
            ``self.current_prompt`` set to that entry. ``None`` means ``[None]``.
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.dataset_path = dataset_path
        self.dataset_name = dataset_name
        self.model_ctx_len = model_ctx_len
        # Avoid a shared mutable default argument; None means "a single default prompt".
        self.prompts = [None] if prompts is None else prompts
        self.current_prompt = None
        if metric_type is not None:
            self.metric_type = metric_type
        # Number of sample documents to log during prep_examples (0 disables logging).
        self.log_instances = 0

        self.samples: List[Dict[str, Any]] = []
        dataset_names: Sequence[Optional[str]]
        if isinstance(dataset_name, str) or dataset_name is None:
            dataset_names = [dataset_name]
        else:
            dataset_names = dataset_name

        dataset_list = [load_hf_dataset(self.dataset_path, ds_name, split) for ds_name in dataset_names]
        self.dataset = datasets.concatenate_datasets(dataset_list)

        # Tokenize everything up front.
        self.prep_examples()

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

    def prep_examples(self):
        """Tokenize every (document, prompt, continuation) combination into ``self.samples``.

        ``doc_id`` is incremented once per (document, prompt) pair, so each prompt
        variant of a document is a separate document as far as the metric is concerned.
        """
        doc_id = 0
        for doc in self.dataset:
            for prompt in self.prompts:
                # The doc_to_* hooks may read self.current_prompt to build the variant.
                self.current_prompt = prompt

                continuations = self.doc_to_continuations(doc)
                label_id = self.doc_to_label(doc)
                doc_text = self.doc_to_text(doc)
                ctx = self.token_encode(doc_text)
                dc = self.token_encode(self.doc_to_domain_conditional(doc))
                if self.log_instances > 0:
                    self.log_instances -= 1
                    ds_name = self.dataset_name
                    if isinstance(ds_name, (list, tuple)) and ds_name:
                        ds_name = ds_name[0]
                    log.info(
                        f"Sample doc from ({self.dataset_path}, {ds_name}, {self.current_prompt}):"
                        + f"\ndoc_text: {doc_text}\ncontinuations: {continuations}"
                    )

                for cont_id, continuation_str in enumerate(continuations):
                    # Lengths skip the first character, assumed to be the leading space
                    # that subclasses prepend to every continuation.
                    cont_str_len = len(continuation_str) - 1
                    cont_byte_len = len(continuation_str[1:].encode("utf-8"))
                    continuation = self.token_encode(continuation_str)

                    # Model input: context + all but the last continuation token, keeping
                    # only the rightmost model_ctx_len tokens.
                    query = ctx + continuation[:-1]
                    query = query[-self.model_ctx_len :]
                    # Number of query tokens before the first continuation token, after truncation.
                    actual_ctx_len = len(query) - len(continuation) + 1

                    # Same layout for the domain-conditional query; collate_fn clips it
                    # to model_ctx_len.
                    dc_query = dc + continuation[:-1]

                    self.samples.append(
                        {
                            "doc_id": doc_id,
                            "cont_id": cont_id,
                            "ctx": ctx,
                            "continuation": continuation,
                            "ctx_len": actual_ctx_len,
                            "dc_len": len(dc),
                            "cont_len": len(continuation),
                            "cont_str_len": cont_str_len,
                            "cont_byte_len": cont_byte_len,
                            "query": query,
                            "dc_query": dc_query,
                            "label_id": label_id,
                        }
                    )

                doc_id += 1

    def pad_tokens_until_max(self, tokens, max_len=2048):
        """Right-pad ``tokens`` with the pad token up to ``max_len``, capped at ``model_ctx_len``.

        Sequences already longer than ``model_ctx_len`` are instead left-truncated to
        their last ``model_ctx_len`` tokens and returned without padding.
        """
        if len(tokens) > self.model_ctx_len:
            return tokens[-self.model_ctx_len :]
        else:
            tokens = tokens + [self.tokenizer.pad_token_id] * (max_len - len(tokens))

            # max_len may itself exceed the model context; clip from the right.
            if len(tokens) > self.model_ctx_len:
                tokens = tokens[: self.model_ctx_len]

            return tokens

    def collate_fn(self, data):
        """Pad a list of samples from ``self.samples`` into a batch dict of LongTensors.

        Each token-sequence field is right-padded to the longest sequence of that field
        in ``data`` (see ``pad_tokens_until_max``); scalar fields become 1-D tensors.
        """
        max_ctx_len = max((len(sample["ctx"]) for sample in data), default=0)
        max_cont_len = max((len(sample["continuation"]) for sample in data), default=0)
        max_query_len = max((len(sample["query"]) for sample in data), default=0)
        max_dc_query_len = max((len(sample["dc_query"]) for sample in data), default=0)

        def padded(key: str, max_len: int) -> torch.Tensor:
            # Stack the padded token lists of field ``key`` into a 2-D LongTensor.
            return torch.stack(
                [torch.LongTensor(self.pad_tokens_until_max(sample[key], max_len=max_len)) for sample in data]
            )

        def column(key: str) -> torch.Tensor:
            # Collect the scalar field ``key`` of every sample into a 1-D LongTensor.
            return torch.LongTensor([sample[key] for sample in data])

        return {
            "doc_id": column("doc_id"),
            "cont_id": column("cont_id"),
            "ctx": padded("ctx", max_ctx_len),
            "continuation": padded("continuation", max_cont_len),
            "ctx_len": column("ctx_len"),
            "dc_len": column("dc_len"),
            "cont_len": column("cont_len"),
            "cont_str_len": column("cont_str_len"),
            "cont_byte_len": column("cont_byte_len"),
            "input_ids": padded("query", max_query_len),
            "dc_input_ids": padded("dc_query", max_dc_query_len),
            "label_id": column("label_id"),
        }

    def token_encode(self, string: str) -> List[int]:
        """Encode ``string`` to token ids without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def token_decode(self, tokens: List[int]) -> str:
        """Decode token ids back to text."""
        return self.tokenizer.decode(tokens)

    @abc.abstractmethod
    def doc_to_text(self, doc) -> str:
        """Return the context (prompt) text for ``doc``."""
        raise NotImplementedError

    @abc.abstractmethod
    def doc_to_continuations(self, doc) -> List[str]:
        """Return the candidate continuations for ``doc``, each starting with a space."""
        raise NotImplementedError

    @abc.abstractmethod
    def doc_to_label(self, doc) -> int:
        """Return the index of the correct continuation for ``doc``."""
        raise NotImplementedError

    def doc_to_domain_conditional(self, doc) -> str:
        """Return the domain-conditional prefix used in place of the context (default: a single space)."""
        del doc
        return " "


class PIQA(ICLMultiChoiceTaskDataset):
    """PIQA as a two-way multiple-choice task.

    The goal is posed as a question; ``sol1`` and ``sol2`` are the candidate answers and
    ``doc["label"]`` is the index of the correct one.
    """

    metric_type = "len_norm"

    def __init__(self, tokenizer, dataset_path="piqa", dataset_name="plain_text"):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    def doc_to_text(self, doc):
        goal = doc["goal"]
        return "Question: " + goal + "\nAnswer:"

    def doc_to_continuations(self, doc):
        # Each candidate gets a leading space to separate it from "Answer:".
        return [" " + doc[key] for key in ("sol1", "sol2")]

    def doc_to_label(self, doc):
        return doc["label"]

    def doc_to_domain_conditional(self, doc):
        # Document-independent prefix.
        del doc
        return "Answer:"


class HellaSwag(ICLMultiChoiceTaskDataset):
    """HellaSwag: choose the ending that best continues an activity description."""

    metric_type = "len_norm"

    def __init__(self, tokenizer, dataset_path="hellaswag", dataset_name=None):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    @classmethod
    def preprocess(cls, text):
        """Normalize raw text.

        Strips surrounding whitespace, turns " [title]" markers into ". ", removes any
        remaining bracketed tags, then replaces each double space with a single space
        (one pass).
        """
        cleaned = text.strip().replace(" [title]", ". ")
        cleaned = re.sub(r"\[.*?\]", "", cleaned)
        return cleaned.replace("  ", " ")

    def doc_to_text(self, doc):
        raw = doc["activity_label"] + ": " + doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        return self.preprocess(raw)

    def doc_to_continuations(self, doc):
        # Leading space separates each ending from the context.
        return [" " + self.preprocess(candidate) for candidate in doc["endings"]]

    def doc_to_label(self, doc):
        return int(doc["label"])

    def doc_to_domain_conditional(self, doc):
        conditional = self.preprocess(doc["ctx_b"].capitalize())
        if conditional:
            return conditional
        # Empty ctx_b: fall back to the last word of ctx_a.
        return self.preprocess(doc["ctx_a"]).split(" ")[-1]


class WinoGrande(ICLMultiChoiceTaskDataset):
    """WinoGrande, scored by varying the context rather than the continuation.

    Each option is substituted for the blank ("_") to form one context per option; the
    remainder of the sentence after the blank is the shared continuation that is scored
    under each context. ``cont_id`` therefore indexes the option.
    """

    metric_type = "acc"

    def __init__(
        self,
        tokenizer,
        dataset_path="winogrande",
        dataset_name="winogrande_xl",
    ):
        super().__init__(
            tokenizer=tokenizer,
            dataset_path=dataset_path,
            dataset_name=dataset_name,
        )

    def prep_examples(self):
        """Build one sample per option; options of a document share ``doc_id`` and continuation.

        Overrides the base implementation because here ``doc_to_text`` and
        ``doc_to_domain_conditional`` return one string per option, while
        ``doc_to_continuations`` returns a single string.
        """
        doc_id = 0
        for doc in self.dataset:
            ctxs = self.doc_to_text(doc)
            dcs = self.doc_to_domain_conditional(doc)

            continuation_str = self.doc_to_continuations(doc)
            label_id = self.doc_to_label(doc)
            # Lengths skip the leading space added by doc_to_continuations.
            cont_str_len = len(continuation_str) - 1
            cont_byte_len = len(continuation_str[1:].encode("utf-8"))

            # The continuation is identical for every option, so encode it once.
            continuation = self.token_encode(continuation_str)

            for cont_id, (ctx_str, dc_str) in enumerate(zip(ctxs, dcs)):
                ctx = self.token_encode(ctx_str)
                dc = self.token_encode(dc_str)

                query = ctx + continuation[:-1]
                query = query[-self.model_ctx_len :]
                # Context length as seen by the model after left-truncation. It equals
                # len(ctx) when the query fits; when it does not, it keeps ICLMetric's
                # logit slice aligned with the continuation tokens.
                actual_ctx_len = len(query) - len(continuation) + 1

                # Domain-conditional query; collate_fn clips it to model_ctx_len.
                dc_query = dc + continuation[:-1]

                self.samples.append(
                    {
                        "doc_id": doc_id,
                        "cont_id": cont_id,
                        "ctx": ctx,
                        "continuation": continuation,
                        "ctx_len": actual_ctx_len,
                        "dc_len": len(dc),
                        "cont_len": len(continuation),
                        "cont_str_len": cont_str_len,
                        "cont_byte_len": cont_byte_len,
                        "query": query,
                        "dc_query": dc_query,
                        "label_id": label_id,
                    }
                )

            doc_id += 1

    def doc_to_text(self, doc):
        """Return one context per option: the sentence up to the blank, followed by the option."""
        pronoun_loc = doc["sentence"].index("_")

        ctx = []
        for option in [doc["option1"], doc["option2"]]:
            ctx.append(doc["sentence"][:pronoun_loc] + option)

        return ctx

    def doc_to_continuations(self, doc):
        """Return the single shared continuation: the text after the blank, with a leading space."""
        pronoun_loc = doc["sentence"].index("_") + 1
        return " " + doc["sentence"][pronoun_loc:].strip()

    def doc_to_label(self, doc):
        # ``answer`` is "1" or "2"; convert to a 0-based option index.
        return int(doc["answer"]) - 1

    def doc_to_domain_conditional(self, doc):
        """Return one domain-conditional prefix per option: the option text itself."""
        return [doc["option1"], doc["option2"]]


class OpenBookQA(ICLMultiChoiceTaskDataset):
    """OpenBookQA: the question stem is the context and each choice text a continuation."""

    metric_type = "len_norm"

    def __init__(self, tokenizer, dataset_path="openbookqa", dataset_name="main"):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    def doc_to_text(self, doc):
        return doc["question_stem"]

    def doc_to_continuations(self, doc):
        # Leading space separates each choice from the stem.
        choice_texts = doc["choices"]["text"]
        return [" " + text for text in choice_texts]

    def doc_to_label(self, doc):
        letters = ["A", "B", "C", "D"]
        return letters.index(doc["answerKey"].strip())

    def doc_to_domain_conditional(self, doc):
        # The last word of the question stem.
        stem_words = doc["question_stem"].strip().split(" ")
        return stem_words[-1]


class BoolQ(ICLMultiChoiceTaskDataset):
    """BoolQ: answer a yes/no question about a passage."""

    metric_type = "acc"

    def __init__(self, tokenizer, dataset_path="boolq", dataset_name=None):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    def doc_to_text(self, doc):
        return "".join([doc["passage"], "\nQuestion: ", doc["question"], "?\nAnswer:"])

    def doc_to_continuations(self, doc):
        # Fixed answer set; index 0 is " yes".
        del doc
        return [" yes", " no"]

    def doc_to_label(self, doc):
        # A truthy answer maps to " yes" (index 0), otherwise " no" (index 1).
        return 0 if doc["answer"] else 1

    def doc_to_domain_conditional(self, doc):
        del doc
        return "Answer:"


class SciQ(ICLMultiChoiceTaskDataset):
    """SciQ: support passage plus question, scored over three distractors and the correct answer."""

    metric_type = "acc"

    def __init__(self, tokenizer, dataset_path="sciq", dataset_name=None):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    def doc_to_text(self, doc):
        support = doc["support"].strip()
        return support + "\nQuestion: " + doc["question"] + "\nAnswer:"

    def doc_to_continuations(self, doc):
        # The correct answer is always placed last (see doc_to_label).
        answer_fields = ("distractor1", "distractor2", "distractor3", "correct_answer")
        return [" " + doc[field] for field in answer_fields]

    def doc_to_label(self, doc):
        # Index of "correct_answer" in doc_to_continuations.
        del doc
        return 3

    def doc_to_domain_conditional(self, doc):
        del doc
        return "Answer:"


class ArcEasy(ICLMultiChoiceTaskDataset):
    """ARC-Easy multiple-choice QA.

    Also used as the base for other tasks with the same fields (``question``,
    ``choices["text"]``, ``answerKey``).
    """

    metric_type = "acc"

    def __init__(
        self,
        tokenizer,
        dataset_path="ai2_arc",
        dataset_name="ARC-Easy",
    ):
        super().__init__(
            tokenizer=tokenizer,
            dataset_path=dataset_path,
            dataset_name=dataset_name,
        )

    def doc_to_text(self, doc):
        return "Question: " + doc["question"] + "\nAnswer:"

    def doc_to_continuations(self, doc):
        # Leading space separates each choice from "Answer:".
        return [" " + choice for choice in doc["choices"]["text"]]

    def doc_to_label(self, doc):
        """Return the 0-based index of the correct choice.

        ``answerKey`` may be a letter ("A"-"E") or a digit ("1"-"5"); digits are mapped
        to the corresponding letter. ``doc`` is not modified.
        """
        num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
        answer_key = num_to_letter.get(doc["answerKey"], doc["answerKey"])
        return ["A", "B", "C", "D", "E"].index(answer_key)

    def doc_to_domain_conditional(self, doc):
        del doc
        return "Answer:"


class ArcChallenge(ArcEasy):
    """ARC-Challenge: ArcEasy's formatting on the "ARC-Challenge" config, scored with len_norm."""

    metric_type = "len_norm"

    def __init__(self, tokenizer, dataset_path="ai2_arc", dataset_name="ARC-Challenge"):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)


class ArcEasyCELoss(ArcEasy):
    """ARC-Easy scored with the ``ce_loss`` metric on the correct answer only."""

    metric_type = "ce_loss"

    def doc_to_continuations(self, doc):
        # Only the gold answer is scored. Its index must come from ArcEasy's answerKey
        # lookup: this class's own doc_to_label always returns 0 (the position of the
        # single continuation), which would pick the first choice instead of the answer.
        answer = doc["choices"]["text"][super().doc_to_label(doc)]
        return [" " + answer]

    def doc_to_label(self, doc):
        # The single continuation emitted above is the gold answer.
        return 0


class BasicArithmetic(ArcEasy):
    """Basic arithmetic questions in the ARC schema, reusing ArcEasy's formatting."""

    metric_type = "acc"

    def __init__(self, tokenizer, dataset_path="allenai/basic_arithmetic", dataset_name=None):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)


class CommonsenseQA(ArcEasy):
    """CommonsenseQA, reusing ArcEasy's question/choices/answerKey formatting, scored with len_norm."""

    metric_type = "len_norm"

    def __init__(self, tokenizer, dataset_path="tau/commonsense_qa", dataset_name=None):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)


class SocialIQa(ICLMultiChoiceTaskDataset):
    """Social IQa: context plus question, with three candidate answers."""

    metric_type = "len_norm"

    def __init__(self, tokenizer, dataset_path="social_i_qa", dataset_name=None):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    def doc_to_text(self, doc):
        prompt = doc["context"] + " " + doc["question"]
        return "Question: " + prompt + "\nAnswer:"

    def doc_to_continuations(self, doc):
        # Answers A, B, C in order; leading space separates them from "Answer:".
        return [" " + doc[field] for field in ("answerA", "answerB", "answerC")]

    def doc_to_label(self, doc):
        # ``label`` is 1-based; convert to a 0-based index.
        return int(doc["label"]) - 1

    def doc_to_domain_conditional(self, doc):
        del doc
        return "Answer:"


class COPA(ICLMultiChoiceTaskDataset):
    """COPA: pick the more plausible cause/effect, phrased as a continuation of the premise."""

    metric_type = "acc"

    def __init__(self, tokenizer, dataset_path="super_glue", dataset_name="copa"):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    @staticmethod
    def _connector(doc):
        # Word joining the premise and the choice, depending on the question type.
        return "because" if doc["question"] == "cause" else "therefore"

    def doc_to_text(self, doc):
        # Drop the premise's final character (its closing punctuation) and append the connector.
        return doc["premise"].strip()[:-1] + " " + self._connector(doc)

    def doc_to_continuations(self, doc):
        # Lower-case the first character so each choice reads as a sentence continuation.
        return [" " + choice[0].lower() + choice[1:] for choice in (doc["choice1"], doc["choice2"])]

    def doc_to_label(self, doc):
        return doc["label"]

    def doc_to_domain_conditional(self, doc):
        return self._connector(doc)


class RTE(ICLMultiChoiceTaskDataset):
    """GLUE RTE: decide whether sentence2 is true given sentence1."""

    metric_type = "len_norm"

    def __init__(self, tokenizer, dataset_path="glue", dataset_name="rte"):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    def doc_to_text(self, doc):
        return "".join([doc["sentence1"], "\nQuestion: ", doc["sentence2"], " True or False?\nAnswer:"])

    def doc_to_continuations(self, doc):
        # Order must match doc["label"], which is returned unchanged (label 0 -> " True").
        del doc
        return [" True", " False"]

    def doc_to_label(self, doc):
        return doc["label"]

    def doc_to_domain_conditional(self, doc):
        del doc
        return "Answer:"


class CommitmentBank(ICLMultiChoiceTaskDataset):
    """SuperGLUE CommitmentBank: three-way True/False/Neither judgement of a hypothesis."""

    metric_type = "acc"

    def __init__(self, tokenizer, dataset_path="super_glue", dataset_name="cb"):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    def doc_to_text(self, doc):
        question = doc["hypothesis"] + ". True, False or Neither?"
        return doc["premise"] + "\nQuestion: " + question + "\nAnswer:"

    def doc_to_continuations(self, doc):
        # Order must match doc["label"], which is returned unchanged.
        del doc
        return [" True", " False", " Neither"]

    def doc_to_label(self, doc):
        return doc["label"]

    def doc_to_domain_conditional(self, doc):
        del doc
        return "Answer:"


class MRPC(ICLMultiChoiceTaskDataset):
    """GLUE MRPC: do two sentences mean the same thing? Scored with F1."""

    metric_type = "f1"

    def __init__(self, tokenizer, dataset_path="glue", dataset_name="mrpc"):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    @classmethod
    def preprocess(cls, string: str) -> str:
        """Undo tokenizer-style spacing around contractions, brackets, quotes and punctuation."""
        for old, new in ((" n't", "n't"), (" )", ")"), ("( ", "("), ('" ', '"'), (' "', '"')):
            string = string.replace(old, new)
        return re.sub(r" (['.,])", r"\1", string)

    def doc_to_text(self, doc):
        return "".join(
            [
                "Sentence 1: ",
                self.preprocess(doc["sentence1"]),
                "\nSentence 2: ",
                self.preprocess(doc["sentence2"]),
                "\nQuestion: Do both sentences mean the same thing?\nAnswer:",
            ]
        )

    def doc_to_continuations(self, doc):
        # Index 0 (" yes") corresponds to a truthy label (see doc_to_label).
        del doc
        return [" yes", " no"]

    def doc_to_label(self, doc):
        return 0 if doc["label"] else 1

    def doc_to_domain_conditional(self, doc):
        del doc
        return "Answer:"


class SST2(ICLMultiChoiceTaskDataset):
    """GLUE SST-2: classify a sentence as positive or negative."""

    metric_type = "acc"

    def __init__(self, tokenizer, dataset_path="glue", dataset_name="sst2"):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    @classmethod
    def preprocess(cls, string: str) -> str:
        """Undo tokenizer-style spacing around contractions, brackets, quotes and punctuation."""
        for old, new in ((" n't", "n't"), (" )", ")"), ("( ", "("), ('" ', '"'), (' "', '"')):
            string = string.replace(old, new)
        return re.sub(r" (['.,])", r"\1", string)

    def doc_to_text(self, doc):
        sentence = self.preprocess(doc["sentence"])
        return sentence + "\nQuestion: Is this sentence positive or negative?\nAnswer:"

    def doc_to_continuations(self, doc):
        # Order must match doc["label"], which is returned unchanged
        # (presumably 0 = negative, 1 = positive).
        del doc
        return [" negative", " positive"]

    def doc_to_label(self, doc):
        return doc["label"]

    def doc_to_domain_conditional(self, doc):
        del doc
        return "Answer:"


class MMLU(ICLMultiChoiceTaskDataset):
    """MMLU multiple-choice QA over one subject or a group of subjects.

    ``dataset_name`` may be a subject (e.g. "anatomy"), a top-level category ("stem",
    "humanities", "social_sciences", "other"), or a sub-category (e.g. "math"); all
    matching subject configs are loaded and concatenated. Optional prompt variations
    add an instruction line and/or few-shot examples from each subject's "dev" split.
    """

    metric_type = "len_norm"

    # Subject config name -> list of sub-categories.
    _subcategories = {
        "abstract_algebra": ["math"],
        "anatomy": ["health"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "clinical_knowledge": ["health"],
        "college_biology": ["biology"],
        "college_chemistry": ["chemistry"],
        "college_computer_science": ["computer science"],
        "college_mathematics": ["math"],
        "college_medicine": ["health"],
        "college_physics": ["physics"],
        "computer_security": ["computer science"],
        "conceptual_physics": ["physics"],
        "econometrics": ["economics"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "formal_logic": ["philosophy"],
        "global_facts": ["other"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_computer_science": ["computer science"],
        "high_school_european_history": ["history"],
        "high_school_geography": ["geography"],
        "high_school_government_and_politics": ["politics"],
        "high_school_macroeconomics": ["economics"],
        "high_school_mathematics": ["math"],
        "high_school_microeconomics": ["economics"],
        "high_school_physics": ["physics"],
        "high_school_psychology": ["psychology"],
        "high_school_statistics": ["math"],
        "high_school_us_history": ["history"],
        "high_school_world_history": ["history"],
        "human_aging": ["health"],
        "human_sexuality": ["culture"],
        "international_law": ["law"],
        "jurisprudence": ["law"],
        "logical_fallacies": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "medical_genetics": ["health"],
        "miscellaneous": ["other"],
        "moral_disputes": ["philosophy"],
        "moral_scenarios": ["philosophy"],
        "nutrition": ["health"],
        "philosophy": ["philosophy"],
        "prehistory": ["history"],
        "professional_accounting": ["other"],
        "professional_law": ["law"],
        "professional_medicine": ["health"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_studies": ["politics"],
        "sociology": ["culture"],
        "us_foreign_policy": ["politics"],
        "virology": ["health"],
        "world_religions": ["philosophy"],
    }

    # Top-level category -> list of sub-categories.
    _categories = {
        "stem": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
        "humanities": ["history", "philosophy", "law"],
        "social_sciences": ["politics", "culture", "economics", "geography", "psychology"],
        "other": ["other", "business", "health"],
    }

    def __init__(
        self,
        tokenizer,
        dataset_path="hails/mmlu_no_train",
        dataset_name=None,
        split="validation",
        prompt_variations=None,
        mc_labels=False,
        metric_type=None,
    ):
        """
        :param dataset_name: A category, a subject, or a sub-category (checked in that order).
        :param prompt_variations: ``None`` for the plain prompt; 1 for the plain prompt
            plus instruction and instruction + 1..5-shot variants; 2 for instruction + 5-shot only.
        :param mc_labels: If true, list the choices as "A. ...", ..., and score the
            letters " A".." D" instead of the choice texts.
        :param metric_type: Overrides the class-level ``metric_type`` when not ``None``.
        :raises ValueError: If ``prompt_variations`` is unknown, or ``dataset_name``
            matches no MMLU subject.
        """
        dataset_names = []
        # Resolve dataset_name to the list of subject configs to load.
        if dataset_name in MMLU._categories:
            for sub_cat in MMLU._categories[dataset_name]:
                for name, cats in MMLU._subcategories.items():
                    if sub_cat in cats:
                        dataset_names.append(name)
        elif dataset_name in MMLU._subcategories:
            dataset_names.append(dataset_name)
        else:
            for name, cats in MMLU._subcategories.items():
                if dataset_name in cats:
                    dataset_names.append(name)
        if not dataset_names:
            # Fail early with a clear message instead of inside concatenate_datasets([]).
            raise ValueError(f"Unknown MMLU subject, category or sub-category: {dataset_name!r}")
        self.dev_set = {}
        self.mc_labels = mc_labels
        prompts: List[Union[None, str]] = [None]
        if prompt_variations is not None:
            if prompt_variations == 1:
                prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"]
            elif prompt_variations == 2:
                prompts = ["inst+5"]
            else:
                raise ValueError(f"Unknown prompt variations: {prompt_variations}")
            # Few-shot examples are drawn from each subject's "dev" split.
            for name in dataset_names:
                dev_set = load_hf_dataset(dataset_path, name, "dev")
                self.dev_set[name] = dev_set
        super().__init__(
            tokenizer=tokenizer,
            dataset_path=dataset_path,
            dataset_name=dataset_names,
            split=split,
            prompts=prompts,
            metric_type=metric_type,
        )

    def doc_to_text(self, doc):
        """Format the question for ``self.current_prompt``.

        A prompt containing "inst" is prefixed with an instruction naming the subject;
        a "+N" suffix prepends the first N dev-set examples with their answers.
        """

        def format_example(doc, keys):
            question_prefix = ""
            if not self.mc_labels:
                question_prefix = "Question: "
            question = question_prefix + doc["question"].strip()
            choices = ""
            if self.mc_labels:
                choices = "".join([f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])])
            prompt = f"{question}\n{choices}Answer:"
            return prompt

        keys = ["A", "B", "C", "D"]
        output_text = format_example(doc, keys)

        if self.current_prompt is not None:
            prefix = ""
            if "inst" in self.current_prompt:
                subject = doc.get("subject").replace("_", " ")
                prefix = f"The following are multiple choice questions (with answers) about {subject}:\n\n"
            num_shots = re.findall("\\+(\\d+)", self.current_prompt)
            if num_shots:
                dev_set = self.dev_set.get(doc.get("subject"), [])
                num_shots_int = int(num_shots[0])
                for idx, dev_doc in enumerate(dev_set):
                    if idx >= num_shots_int:
                        break
                    if self.mc_labels:
                        answer = keys[dev_doc["answer"]]
                    else:
                        answer = dev_doc["choices"][dev_doc["answer"]]
                    prefix += format_example(dev_doc, keys) + " " + answer + "\n\n"
            output_text = prefix + output_text
        return output_text

    def doc_to_continuations(self, doc):
        """Return letter or choice-text continuations; only the gold one for loss metrics."""
        if self.mc_labels:
            choices = [" A", " B", " C", " D"]
        else:
            choices = [" " + choice for choice in doc["choices"]]
        if self.metric_type in ["ce_loss", "bpb"]:
            # Loss metrics score only the correct answer.
            return [choices[doc["answer"]]]
        else:
            return choices

    def doc_to_label(self, doc):
        if self.metric_type in ["ce_loss", "bpb"]:
            # The single continuation emitted for loss metrics is the gold answer.
            return 0
        return doc["answer"]

    def doc_to_domain_conditional(self, doc):
        del doc
        return "Answer:"


class TriviaQACELoss(ICLMultiChoiceTaskDataset):
    """TriviaQA scored by the cross-entropy loss of the gold answer.

    Each question yields exactly one continuation, ``doc["answer"]["value"]``, so the
    label is always 0. By default the ``rc.wikipedia.nocontext`` configuration is used.
    """

    metric_type = "ce_loss"

    def __init__(
        self,
        tokenizer,
        dataset_path="trivia_qa",
        dataset_name="rc.wikipedia.nocontext",
    ):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    def doc_to_text(self, doc):
        """Render ``doc["question"]`` as a ``Question: ... Answer:`` prompt."""
        return "".join(("\nQuestion: ", doc["question"], "\nAnswer:"))

    def doc_to_continuations(self, doc):
        """Return the single gold answer, prefixed with a space."""
        gold_answer = doc["answer"]["value"]
        return [" " + gold_answer]

    def doc_to_label(self, doc):
        """The only continuation is the gold one, so the label is always 0."""
        return 0

    def doc_to_domain_conditional(self, doc):
        """Return the fixed domain-conditional prompt; ``doc`` is ignored."""
        _ = doc
        return "Answer:"


class NaturalQuestionsCELoss(ICLMultiChoiceTaskDataset):
    """Natural Questions (``nq_open`` by default) scored by the cross-entropy of the answer.

    Only the first listed answer, ``doc["answer"][0]``, is used as the continuation, so
    the label is always 0.
    """

    metric_type = "ce_loss"

    def __init__(
        self,
        tokenizer,
        dataset_path="nq_open",
        dataset_name=None,
    ):
        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_name)

    def doc_to_text(self, doc):
        """Render ``doc["question"]`` as a ``Question: ... Answer:`` prompt."""
        return "".join(("\nQuestion: ", doc["question"], "\nAnswer:"))

    def doc_to_continuations(self, doc):
        """Return the first listed answer, prefixed with a space."""
        first_answer = doc["answer"][0]
        return [" " + first_answer]

    def doc_to_label(self, doc):
        """The only continuation is the gold one, so the label is always 0."""
        return 0

    def doc_to_domain_conditional(self, doc):
        """Return the fixed domain-conditional prompt; ``doc`` is ignored."""
        _ = doc
        return "Answer:"


class OEEvalTask(ICLMultiChoiceTaskDataset):
    """Task dataset built from pre-rendered oe-eval ``loglikelihood`` requests.

    Each request loaded through ``load_oe_eval_requests`` already carries its rendered
    ``context`` and ``continuation`` strings. They are tokenized directly, so the
    ``doc_to_text``/``doc_to_continuations``/``doc_to_label`` hooks are not implemented.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        dataset_path: str,
        dataset_name: Union[str, Sequence[str], None] = None,
        model_ctx_len: int = 2048,
        split=None,
        metric_type=None,
        prompts: Optional[Sequence[Optional[str]]] = None,
    ):
        """Load oe-eval requests for one or more dataset names and tokenize them.

        :param tokenizer: Tokenizer used (via ``token_encode``) for contexts and continuations.
        :param dataset_path: oe-eval task path passed to ``load_oe_eval_requests``.
        :param dataset_name: A single name, a sequence of names (whose requests are
            concatenated with non-overlapping doc ids), or ``None``.
        :param model_ctx_len: Maximum query length; longer queries are left-truncated.
        :param split: Split passed to ``load_oe_eval_requests``.
        :param metric_type: Metric to compute. When ``None`` it is derived from the
            ``primary_metric`` of the loaded oe-eval configs.
        :param prompts: Unused; oe-eval requests already contain the rendered prompt.
        :raises ValueError: If the loaded configs declare conflicting primary metrics.
        :raises KeyError: If a config's primary metric is missing from ``METRIC_FROM_OE_EVAL``.
        """
        del prompts  # accepted for signature compatibility only
        self.tokenizer = tokenizer
        self.dataset_path = dataset_path
        self.dataset_name = dataset_name
        self.model_ctx_len = model_ctx_len
        self.log_instances = 0  # set > 0 to log that many sample requests in prep_examples

        self.samples: List[Dict[str, Any]] = []
        dataset_names: Sequence[Optional[str]]
        if isinstance(dataset_name, str) or dataset_name is None:
            dataset_names = [dataset_name]
        else:
            dataset_names = dataset_name

        requests_list = []
        configs = []
        for ds_name in dataset_names:
            config, requests = load_oe_eval_requests(self.dataset_path, ds_name, split)
            requests_list.append(requests)
            configs.append(config)

        if metric_type is not None:
            self.metric_type = metric_type
        else:
            config_metric_type = self._metric_type_from_configs(configs)
            if config_metric_type is not None:
                self.metric_type = config_metric_type
        self.dataset = requests_list

        self.prep_examples()

    @staticmethod
    def _metric_type_from_configs(configs) -> Optional[str]:
        """Return the metric type shared by every config that declares a primary metric.

        Returns ``None`` when no config declares one. Raises ``ValueError`` when two
        configs map to different metric types.
        """
        resolved: Optional[str] = None
        for config in configs:
            if config is None:
                continue
            metric_type_raw = config["task_config"].get("primary_metric")
            if metric_type_raw is None:
                continue
            # Translate the oe-eval metric name into this module's naming (acc, len_norm, ...).
            metric_type = METRIC_FROM_OE_EVAL[metric_type_raw]
            if resolved is not None and resolved != metric_type:
                raise ValueError(f"Conflicting metric types: {resolved} and {metric_type}")
            resolved = metric_type
        return resolved

    def prep_examples(self):
        """Tokenize every supported request in ``self.dataset`` into ``self.samples``.

        Requests with ``doc_id >= 1000000`` are skipped. For ``ce_loss``/``bpb`` only the
        request whose ``idx`` equals its ``label`` (or whose label is a string) is kept,
        and it is re-indexed as continuation/label 0.
        """
        # Doc ids are local to each request list, so every list is shifted past the
        # largest id of the previous one. Shifting by ``max_doc_id`` alone (without
        # ``+ 1``) gave a doc with id 0 in the next list the same id as the previous
        # list's largest doc, merging two unrelated documents.
        doc_id_offset = 0
        for requests in self.dataset:
            max_doc_id = -1  # -1 so an empty (or fully skipped) list adds no offset
            for request in requests:
                doc = request["doc"]
                doc_id = request["doc_id"]
                if doc_id >= 1000000:
                    # Requests with such doc ids are not supported here; skip them.
                    continue
                max_doc_id = max(max_doc_id, doc_id)
                assert (
                    request["request_type"] == "loglikelihood"
                ), f"Unsupported request type: {request['request_type']}"

                request_dict = request["request"]
                continuation_str = request_dict["continuation"]
                label_id = request["label"]
                cont_id = request["idx"]
                if self.metric_type in ["ce_loss", "bpb"]:
                    if label_id != cont_id and not isinstance(label_id, str):
                        # Loss-style metrics only score the gold continuation.
                        continue
                    # The single kept continuation is both continuation 0 and the label.
                    cont_id = 0
                    label_id = 0
                doc_text = request_dict["context"]
                ctx = self.token_encode(doc_text)
                dc = self.token_encode(self.doc_to_domain_conditional(doc))
                if self.log_instances > 0:
                    self.log_instances -= 1
                    ds_name = self.dataset_name
                    if ds_name is not None and not isinstance(ds_name, str):
                        ds_name = ds_name[0]  # log only the first of several dataset names
                    log.info(
                        f"Sample doc from ({self.dataset_path}, {ds_name}):"
                        + f"\ndoc_text: {doc_text}\ncontinuation: {continuation_str}"
                    )
                # Lengths exclude the first character of the continuation (presumably the
                # separating space).
                cont_str_len = len(continuation_str) - 1
                cont_byte_len = len(continuation_str[1:].encode("utf-8"))
                continuation = self.token_encode(continuation_str)

                # Model input: context plus all but the last continuation token,
                # left-truncated to the model context length.
                query = ctx + continuation[:-1]
                query = query[-self.model_ctx_len :]
                # Context length that remains after truncation.
                actual_ctx_len = len(query) - len(continuation) + 1

                # Same construction with the domain-conditional prefix instead of the context.
                dc_query = dc + continuation[:-1]

                self.samples.append(
                    {
                        "doc_id": doc_id + doc_id_offset,
                        "cont_id": cont_id,
                        "ctx": ctx,
                        "continuation": continuation,
                        "ctx_len": actual_ctx_len,
                        "dc_len": len(dc),
                        "cont_len": len(continuation),
                        "cont_str_len": cont_str_len,
                        "cont_byte_len": cont_byte_len,
                        "query": query,
                        "dc_query": dc_query,
                        "label_id": label_id,
                    }
                )
            doc_id_offset += max_doc_id + 1

    def doc_to_text(self, doc) -> str:
        raise NotImplementedError

    def doc_to_continuations(self, doc) -> List[str]:
        raise NotImplementedError

    def doc_to_label(self, doc) -> int:
        raise NotImplementedError


# Registry of downstream evaluation tasks, keyed by task label.
# Each value is either a task class, or a ``(task_class, kwargs)`` tuple whose kwargs are
# presumably passed to the class constructor by the caller (not visible here); for
# ``OEEvalTask`` entries they match its ``dataset_path``/``dataset_name``/``metric_type``
# parameters.
# Conventions visible below:
#   * ``*_bpb`` labels set ``metric_type="bpb"`` (the per-byte metric, mapped from
#     oe-eval's ``logits_per_byte``).
#   * ``mmlu_*_test`` labels pass ``split="test"``; the others rely on MMLU's default
#     split (not visible here).
#   * ``mmlu_*_mc_5shot`` labels set ``mc_labels=True`` (presumably answer letters as
#     continuations) and ``prompt_variations=2``; the semantics of ``prompt_variations``
#     are defined in the MMLU class (not visible here).
label_to_task_map = {
    "piqa": PIQA,
    "hellaswag": HellaSwag,
    "winogrande": WinoGrande,
    "openbook_qa": OpenBookQA,
    "boolq": BoolQ,
    "sciq": SciQ,
    "arc_easy": ArcEasy,
    "arc_easy_ppl": ArcEasyCELoss,
    "arc_challenge": ArcChallenge,
    "basic_arithmetic": BasicArithmetic,
    "copa": COPA,
    "rte": RTE,
    "commitment_bank": CommitmentBank,
    "mrpc": MRPC,
    "sst2": SST2,
    "commonsense_qa": CommonsenseQA,
    "social_iqa": SocialIQa,
    "trivia_qa_wiki_ppl": TriviaQACELoss,
    "natural_qs_open_ppl": NaturalQuestionsCELoss,
    "mmlu_stem_test": (MMLU, {"dataset_name": "stem", "split": "test"}),
    "mmlu_humanities_test": (MMLU, {"dataset_name": "humanities", "split": "test"}),
    "mmlu_social_sciences_test": (MMLU, {"dataset_name": "social_sciences", "split": "test"}),
    "mmlu_other_test": (MMLU, {"dataset_name": "other", "split": "test"}),
    "mmlu_stem": (MMLU, {"dataset_name": "stem"}),
    "mmlu_humanities": (MMLU, {"dataset_name": "humanities"}),
    "mmlu_social_sciences": (MMLU, {"dataset_name": "social_sciences"}),
    "mmlu_other": (MMLU, {"dataset_name": "other"}),
    "mmlu_stem_bpb": (MMLU, {"dataset_name": "stem", "metric_type": "bpb"}),
    "mmlu_humanities_bpb": (MMLU, {"dataset_name": "humanities", "metric_type": "bpb"}),
    "mmlu_social_sciences_bpb": (MMLU, {"dataset_name": "social_sciences", "metric_type": "bpb"}),
    "mmlu_other_bpb": (MMLU, {"dataset_name": "other", "metric_type": "bpb"}),
    "mmlu_stem_var": (MMLU, {"dataset_name": "stem", "prompt_variations": 1}),
    "mmlu_humanities_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}),
    "mmlu_social_sciences_var": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 1}),
    "mmlu_other_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}),
    "mmlu_stem_var_bpb": (MMLU, {"dataset_name": "stem", "prompt_variations": 1, "metric_type": "bpb"}),
    "mmlu_humanities_var_bpb": (
        MMLU,
        {"dataset_name": "humanities", "prompt_variations": 1, "metric_type": "bpb"},
    ),
    "mmlu_social_sciences_var_bpb": (
        MMLU,
        {"dataset_name": "social_sciences", "prompt_variations": 1, "metric_type": "bpb"},
    ),
    "mmlu_other_var_bpb": (MMLU, {"dataset_name": "other", "prompt_variations": 1, "metric_type": "bpb"}),
    "mmlu_stem_mc_5shot": (MMLU, {"dataset_name": "stem", "prompt_variations": 2, "mc_labels": True}),
    "mmlu_humanities_mc_5shot": (MMLU, {"dataset_name": "humanities", "prompt_variations": 2, "mc_labels": True}),
    "mmlu_social_sciences_mc_5shot": (
        MMLU,
        {"dataset_name": "social_sciences", "prompt_variations": 2, "mc_labels": True},
    ),
    "mmlu_other_mc_5shot": (MMLU, {"dataset_name": "other", "prompt_variations": 2, "mc_labels": True}),
    "mmlu_stem_mc_5shot_test": (
        MMLU,
        {"dataset_name": "stem", "split": "test", "prompt_variations": 2, "mc_labels": True},
    ),
    "mmlu_humanities_mc_5shot_test": (
        MMLU,
        {"dataset_name": "humanities", "split": "test", "prompt_variations": 2, "mc_labels": True},
    ),
    "mmlu_social_sciences_mc_5shot_test": (
        MMLU,
        {"dataset_name": "social_sciences", "split": "test", "prompt_variations": 2, "mc_labels": True},
    ),
    "mmlu_other_mc_5shot_test": (
        MMLU,
        {"dataset_name": "other", "split": "test", "prompt_variations": 2, "mc_labels": True},
    ),
    # oe-eval request-based tasks: OEEvalTask loads pre-rendered requests for
    # (dataset_path, dataset_name). metric_type is always given explicitly here, so the
    # config-derived fallback in OEEvalTask is not used by these entries.
    # NOTE(review): rc variants of arc_easy, boolq, copa, sciq and winogrande use "acc"
    # while the other rc tasks use "len_norm" -- presumably intentional; confirm.
    "arc_challenge_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "mc_5shot", "metric_type": "acc"},
    ),
    "arc_challenge_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "mc_5shot", "metric_type": "bpb"},
    ),
    "arc_challenge_rc_0shot": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "rc_0shot", "metric_type": "len_norm"},
    ),
    "arc_challenge_rc_0shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "rc_0shot", "metric_type": "bpb"},
    ),
    "arc_challenge_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "rc_5shot", "metric_type": "len_norm"},
    ),
    "arc_challenge_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "rc_5shot", "metric_type": "bpb"},
    ),
    "arc_easy_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "mc_5shot", "metric_type": "acc"},
    ),
    "arc_easy_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "mc_5shot", "metric_type": "bpb"},
    ),
    "arc_easy_rc_0shot": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "rc_0shot", "metric_type": "acc"},
    ),
    "arc_easy_rc_0shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "rc_0shot", "metric_type": "bpb"},
    ),
    "arc_easy_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "rc_5shot", "metric_type": "acc"},
    ),
    "arc_easy_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "rc_5shot", "metric_type": "bpb"},
    ),
    "boolq_mc_5shot": (OEEvalTask, {"dataset_path": "boolq", "dataset_name": "mc_5shot", "metric_type": "acc"}),
    "boolq_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "mc_5shot", "metric_type": "bpb"},
    ),
    "boolq_rc_0shot": (OEEvalTask, {"dataset_path": "boolq", "dataset_name": "rc_0shot", "metric_type": "acc"}),
    "boolq_rc_0shot_bpb": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "rc_0shot", "metric_type": "bpb"},
    ),
    "boolq_rc_5shot": (OEEvalTask, {"dataset_path": "boolq", "dataset_name": "rc_5shot", "metric_type": "acc"}),
    "boolq_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "rc_5shot", "metric_type": "bpb"},
    ),
    "copa_rc_0shot": (OEEvalTask, {"dataset_path": "copa", "dataset_name": "rc_0shot", "metric_type": "acc"}),
    "copa_rc_0shot_bpb": (OEEvalTask, {"dataset_path": "copa", "dataset_name": "rc_0shot", "metric_type": "bpb"}),
    "copycolors_10way": (
        OEEvalTask,
        {"dataset_path": "copycolors", "dataset_name": "10way", "metric_type": "acc"},
    ),
    "copycolors_10way_bpb": (
        OEEvalTask,
        {"dataset_path": "copycolors", "dataset_name": "10way", "metric_type": "bpb"},
    ),
    "copycolors_xl_10way": (
        OEEvalTask,
        {"dataset_path": "copycolors", "dataset_name": "xl_10way", "metric_type": "acc"},
    ),
    "copycolors_xl_10way_bpb": (
        OEEvalTask,
        {"dataset_path": "copycolors", "dataset_name": "xl_10way", "metric_type": "bpb"},
    ),
    "csqa_mc_5shot": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "mc_5shot", "metric_type": "acc"}),
    "csqa_mc_5shot_bpb": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "mc_5shot", "metric_type": "bpb"}),
    "csqa_rc_0shot": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "rc_0shot", "metric_type": "len_norm"}),
    "csqa_rc_0shot_bpb": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "rc_0shot", "metric_type": "bpb"}),
    "csqa_rc_5shot": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "rc_5shot", "metric_type": "len_norm"}),
    "csqa_rc_5shot_bpb": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "rc_5shot", "metric_type": "bpb"}),
    "gsm8k_gold_bpb_5shot": (
        OEEvalTask,
        {"dataset_path": "gsm8k", "dataset_name": "gold_bpb_5shot", "metric_type": "bpb"},
    ),
    "hellaswag_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "mc_5shot", "metric_type": "acc"},
    ),
    "hellaswag_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "mc_5shot", "metric_type": "bpb"},
    ),
    "hellaswag_rc_0shot": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "rc_0shot", "metric_type": "len_norm"},
    ),
    "hellaswag_rc_0shot_bpb": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "rc_0shot", "metric_type": "bpb"},
    ),
    "hellaswag_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "rc_5shot", "metric_type": "len_norm"},
    ),
    "hellaswag_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "rc_5shot", "metric_type": "bpb"},
    ),
    "openbookqa_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "mc_5shot", "metric_type": "acc"},
    ),
    "openbookqa_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "mc_5shot", "metric_type": "bpb"},
    ),
    "openbookqa_rc_0shot": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "rc_0shot", "metric_type": "len_norm"},
    ),
    "openbookqa_rc_0shot_bpb": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "rc_0shot", "metric_type": "bpb"},
    ),
    "openbookqa_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "rc_5shot", "metric_type": "len_norm"},
    ),
    "openbookqa_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "rc_5shot", "metric_type": "bpb"},
    ),
    "piqa_mc_5shot": (OEEvalTask, {"dataset_path": "piqa", "dataset_name": "mc_5shot", "metric_type": "acc"}),
    "piqa_mc_5shot_bpb": (OEEvalTask, {"dataset_path": "piqa", "dataset_name": "mc_5shot", "metric_type": "bpb"}),
    "piqa_rc_0shot": (OEEvalTask, {"dataset_path": "piqa", "dataset_name": "rc_0shot", "metric_type": "len_norm"}),
    "piqa_rc_0shot_bpb": (OEEvalTask, {"dataset_path": "piqa", "dataset_name": "rc_0shot", "metric_type": "bpb"}),
    "piqa_rc_5shot": (OEEvalTask, {"dataset_path": "piqa", "dataset_name": "rc_5shot", "metric_type": "len_norm"}),
    "piqa_rc_5shot_bpb": (OEEvalTask, {"dataset_path": "piqa", "dataset_name": "rc_5shot", "metric_type": "bpb"}),
    "sciq_rc_0shot": (OEEvalTask, {"dataset_path": "sciq", "dataset_name": "rc_0shot", "metric_type": "acc"}),
    "sciq_rc_0shot_bpb": (OEEvalTask, {"dataset_path": "sciq", "dataset_name": "rc_0shot", "metric_type": "bpb"}),
    "socialiqa_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "mc_5shot", "metric_type": "acc"},
    ),
    "socialiqa_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "mc_5shot", "metric_type": "bpb"},
    ),
    "socialiqa_rc_0shot": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "rc_0shot", "metric_type": "len_norm"},
    ),
    "socialiqa_rc_0shot_bpb": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "rc_0shot", "metric_type": "bpb"},
    ),
    "socialiqa_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "rc_5shot", "metric_type": "len_norm"},
    ),
    "socialiqa_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "rc_5shot", "metric_type": "bpb"},
    ),
    "winogrande_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "mc_5shot", "metric_type": "acc"},
    ),
    "winogrande_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "mc_5shot", "metric_type": "bpb"},
    ),
    "winogrande_rc_0shot": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "rc_0shot", "metric_type": "acc"},
    ),
    "winogrande_rc_0shot_bpb": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "rc_0shot", "metric_type": "bpb"},
    ),
    "winogrande_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "rc_5shot", "metric_type": "acc"},
    ),
    "winogrande_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "rc_5shot", "metric_type": "bpb"},
    ),
}




label_to_task_map_new = {
    "arc_challenge_train_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "train_rc_5shot", "metric_type": "len_norm"},
    ),
    "arc_challenge_train_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "train_rc_5shot", "metric_type": "bpb"},
    ),
    "arc_challenge_train_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "train_mc_5shot", "metric_type": "acc"},
    ),
    "arc_challenge_train_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "train_mc_5shot", "metric_type": "bpb"},
    ),
    "arc_challenge_val_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
    ),
    "arc_challenge_val_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
    ),
    "arc_challenge_val_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
    ),
    "arc_challenge_val_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "val_mc_5shot", "metric_type": "bpb"},
    ),
    "arc_challenge_test_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "test_rc_5shot", "metric_type": "len_norm"},
    ),
    "arc_challenge_test_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "test_rc_5shot", "metric_type": "bpb"},
    ),
    "arc_challenge_test_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "test_mc_5shot", "metric_type": "acc"},
    ),
    "arc_challenge_test_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_challenge", "dataset_name": "test_mc_5shot", "metric_type": "bpb"},
    ),
    "arc_easy_train_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "train_rc_5shot", "metric_type": "len_norm"},
    ),  
    "arc_easy_train_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "train_rc_5shot", "metric_type": "bpb"},
    ),
    "arc_easy_train_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "train_mc_5shot", "metric_type": "acc"},
    ),
    "arc_easy_train_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "train_mc_5shot", "metric_type": "bpb"},
    ),
    "arc_easy_val_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
    ),
    "arc_easy_val_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
    ),
    "arc_easy_val_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
    ),
    "arc_easy_val_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "val_mc_5shot", "metric_type": "bpb"},
    ),
    "arc_easy_test_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "test_rc_5shot", "metric_type": "len_norm"},
    ),
    "arc_easy_test_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "test_rc_5shot", "metric_type": "bpb"},
    ),
    "arc_easy_test_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "test_mc_5shot", "metric_type": "acc"},
    ),
    "arc_easy_test_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "arc_easy", "dataset_name": "test_mc_5shot", "metric_type": "bpb"},
    ),
    "boolq_train_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "train_rc_5shot", "metric_type": "acc"},
    ),  
    "boolq_train_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "train_rc_5shot", "metric_type": "bpb"},
    ),
    "boolq_train_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "train_mc_5shot", "metric_type": "acc"},
    ),
    "boolq_train_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "train_mc_5shot", "metric_type": "bpb"},
    ),
    "boolq_val_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "val_rc_5shot", "metric_type": "acc"},
    ),
    "boolq_val_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
    ),
    "boolq_val_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
    ),
    "boolq_val_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "boolq", "dataset_name": "val_mc_5shot", "metric_type": "bpb"},
    ),
    "csqa_train_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "csqa", "dataset_name": "train_rc_5shot", "metric_type": "len_norm"},
    ),
    "csqa_train_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "csqa", "dataset_name": "train_rc_5shot", "metric_type": "bpb"},
    ),
    "csqa_train_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "csqa", "dataset_name": "train_mc_5shot", "metric_type": "acc"},
    ),
    "csqa_train_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "csqa", "dataset_name": "train_mc_5shot", "metric_type": "bpb"},
    ),
    "csqa_val_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "csqa", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
    ),
    "csqa_val_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "csqa", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
    ),
    "csqa_val_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "csqa", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
    ),
    "csqa_val_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "csqa", "dataset_name": "val_mc_5shot", "metric_type": "bpb"},
    ),
    "hellaswag_train_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "train_rc_5shot", "metric_type": "len_norm"},
    ),
    "hellaswag_train_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "train_rc_5shot", "metric_type": "bpb"},
    ),
    "hellaswag_train_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "train_mc_5shot", "metric_type": "acc"},
    ),
    "hellaswag_train_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "train_mc_5shot", "metric_type": "bpb"},
    ),
    "hellaswag_val_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
    ),
    "hellaswag_val_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
    ),
    "hellaswag_val_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
    ),
    "hellaswag_val_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "hellaswag", "dataset_name": "val_mc_5shot", "metric_type": "bpb"},
    ),
    "openbookqa_train_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "train_rc_5shot", "metric_type": "len_norm"},
    ),
    "openbookqa_train_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "train_rc_5shot", "metric_type": "bpb"},
    ),
    "openbookqa_train_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "train_mc_5shot", "metric_type": "acc"},
    ),
    "openbookqa_train_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "train_mc_5shot", "metric_type": "bpb"},
    ),
    "openbookqa_val_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
    ),
    "openbookqa_val_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
    ),
    "openbookqa_val_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
    ),
    "openbookqa_val_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "val_mc_5shot", "metric_type": "bpb"},
    ),
    "openbookqa_test_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "test_rc_5shot", "metric_type": "len_norm"},
    ),
    "openbookqa_test_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "test_rc_5shot", "metric_type": "bpb"},
    ),
    "openbookqa_test_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "test_mc_5shot", "metric_type": "acc"},
    ),
    "openbookqa_test_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "openbookqa", "dataset_name": "test_mc_5shot", "metric_type": "bpb"},
    ),
    "piqa_train_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "piqa", "dataset_name": "train_rc_5shot", "metric_type": "len_norm"},
    ),
    "piqa_train_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "piqa", "dataset_name": "train_rc_5shot", "metric_type": "bpb"},
    ),
    "piqa_train_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "piqa", "dataset_name": "train_mc_5shot", "metric_type": "acc"},
    ),
    "piqa_train_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "piqa", "dataset_name": "train_mc_5shot", "metric_type": "bpb"},
    ),
    "piqa_val_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "piqa", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
    ),
    "piqa_val_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "piqa", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
    ),
    "piqa_val_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "piqa", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
    ),
    "piqa_val_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "piqa", "dataset_name": "val_mc_5shot", "metric_type": "bpb"},
    ),
    "socialiqa_train_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "train_rc_5shot", "metric_type": "len_norm"},
    ),
    "socialiqa_train_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "train_rc_5shot", "metric_type": "bpb"},
    ),
    "socialiqa_train_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "train_mc_5shot", "metric_type": "acc"},
    ),
    "socialiqa_train_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "train_mc_5shot", "metric_type": "bpb"},
    ),
    "socialiqa_val_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
    ),
    "socialiqa_val_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
    ),
    "socialiqa_val_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
    ),
    "socialiqa_val_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "socialiqa", "dataset_name": "val_mc_5shot", "metric_type": "bpb"},
    ),
    "winogrande_train_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "train_rc_5shot", "metric_type": "len_norm"},
    ),  
    "winogrande_train_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "train_rc_5shot", "metric_type": "bpb"},
    ),
    "winogrande_train_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "train_mc_5shot", "metric_type": "acc"},
    ),
    "winogrande_train_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "train_mc_5shot", "metric_type": "bpb"},
    ),
    "winogrande_val_rc_5shot": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
    ),
    "winogrande_val_rc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
    ),
    "winogrande_val_mc_5shot": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
    ),
    "winogrande_val_mc_5shot_bpb": (
        OEEvalTask,
        {"dataset_path": "winogrande", "dataset_name": "val_mc_5shot", "metric_type": "bpb"},
    ),
    "mmlu_stem_val_rc_var": (MMLU, {"dataset_name": "stem", "prompt_variations": 1}),
    "mmlu_stem_val_rc_var_bpb": (MMLU, {"dataset_name": "stem", "prompt_variations": 1, "metric_type": "bpb"}),
    "mmlu_stem_val_rc_5shot": (MMLU, {"dataset_name": "stem", "prompt_variations": 2}),
    "mmlu_stem_val_rc_5shot_bpb": (MMLU, {"dataset_name": "stem", "prompt_variations": 2, "metric_type": "bpb"}),
    "mmlu_stem_val_mc_5shot": (MMLU, {"dataset_name": "stem", "prompt_variations": 2, "mc_labels": True}),
    "mmlu_stem_val_mc_5shot_bpb": (
        MMLU,
        {"dataset_name": "stem", "prompt_variations": 2, "mc_labels": True, "metric_type": "bpb"},
    ),
    "mmlu_stem_test_rc_var": (MMLU, {"dataset_name": "stem", "split": "test", "prompt_variations": 1}),
    "mmlu_stem_test_rc_var_bpb": (
        MMLU,
        {"dataset_name": "stem", "split": "test", "prompt_variations": 1, "metric_type": "bpb"},
    ),
    "mmlu_stem_test_rc_5shot": (MMLU, {"dataset_name": "stem", "split": "test", "prompt_variations": 2}),
    "mmlu_stem_test_rc_5shot_bpb": (
        MMLU,
        {"dataset_name": "stem", "split": "test", "prompt_variations": 2, "metric_type": "bpb"},
    ),
    "mmlu_stem_test_mc_5shot": (
        MMLU,
        {"dataset_name": "stem", "split": "test", "prompt_variations": 2, "mc_labels": True},
    ),
    "mmlu_stem_test_mc_5shot_bpb": (
        MMLU,
        {"dataset_name": "stem", "split": "test", "prompt_variations": 2, "mc_labels": True, "metric_type": "bpb"},
    ),
    "mmlu_humanities_val_rc_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}),
    "mmlu_humanities_val_rc_var_bpb": (
        MMLU,
        {"dataset_name": "humanities", "prompt_variations": 1, "metric_type": "bpb"},
    ),
    "mmlu_humanities_val_rc_5shot": (MMLU, {"dataset_name": "humanities", "prompt_variations": 2}),
    "mmlu_humanities_val_rc_5shot_bpb": (
        MMLU,
        {"dataset_name": "humanities", "prompt_variations": 2, "metric_type": "bpb"},
    ),
    "mmlu_humanities_val_mc_5shot": (
        MMLU,
        {"dataset_name": "humanities", "prompt_variations": 2, "mc_labels": True},
    ),
    "mmlu_humanities_val_mc_5shot_bpb": (
        MMLU,
        {"dataset_name": "humanities", "prompt_variations": 2, "mc_labels": True, "metric_type": "bpb"},
    ),
    "mmlu_humanities_test_rc_var": (MMLU, {"dataset_name": "humanities", "split": "test", "prompt_variations": 1}),
    "mmlu_humanities_test_rc_var_bpb": (
        MMLU,
        {"dataset_name": "humanities", "split": "test", "prompt_variations": 1, "metric_type": "bpb"},
    ),
    "mmlu_humanities_test_rc_5shot": (
        MMLU,
        {"dataset_name": "humanities", "split": "test", "prompt_variations": 2},
    ),
    "mmlu_humanities_test_rc_5shot_bpb": (
        MMLU,
        {"dataset_name": "humanities", "split": "test", "prompt_variations": 2, "metric_type": "bpb"},
    ),
    "mmlu_humanities_test_mc_5shot": (
        MMLU,
        {"dataset_name": "humanities", "split": "test", "prompt_variations": 2, "mc_labels": True},
    ),
    "mmlu_humanities_test_mc_5shot_bpb": (
        MMLU,
        {
            "dataset_name": "humanities",
            "split": "test",
            "prompt_variations": 2,
            "mc_labels": True,
            "metric_type": "bpb",
        },
    ),
    "mmlu_social_sciences_val_rc_var": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 1}),
    "mmlu_social_sciences_val_rc_var_bpb": (
        MMLU,
        {"dataset_name": "social_sciences", "prompt_variations": 1, "metric_type": "bpb"},
    ),
    "mmlu_social_sciences_val_rc_5shot": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 2}),
    "mmlu_social_sciences_val_rc_5shot_bpb": (
        MMLU,
        {"dataset_name": "social_sciences", "prompt_variations": 2, "metric_type": "bpb"},
    ),
    "mmlu_social_sciences_val_mc_5shot": (
        MMLU,
        {"dataset_name": "social_sciences", "prompt_variations": 2, "mc_labels": True},
    ),
    "mmlu_social_sciences_val_mc_5shot_bpb": (
        MMLU,
        {"dataset_name": "social_sciences", "prompt_variations": 2, "mc_labels": True, "metric_type": "bpb"},
    ),
    "mmlu_social_sciences_test_rc_var": (
        MMLU,
        {"dataset_name": "social_sciences", "split": "test", "prompt_variations": 1},
    ),
    "mmlu_social_sciences_test_rc_var_bpb": (
        MMLU,
        {"dataset_name": "social_sciences", "split": "test", "prompt_variations": 1, "metric_type": "bpb"},
    ),
    "mmlu_social_sciences_test_rc_5shot": (
        MMLU,
        {"dataset_name": "social_sciences", "split": "test", "prompt_variations": 2},
    ),
    "mmlu_social_sciences_test_rc_5shot_bpb": (
        MMLU,
        {"dataset_name": "social_sciences", "split": "test", "prompt_variations": 2, "metric_type": "bpb"},
    ),
    "mmlu_social_sciences_test_mc_5shot": (
        MMLU,
        {"dataset_name": "social_sciences", "split": "test", "prompt_variations": 2, "mc_labels": True},
    ),
    "mmlu_social_sciences_test_mc_5shot_bpb": (
        MMLU,
        {
            "dataset_name": "social_sciences",
            "split": "test",
            "prompt_variations": 2,
            "mc_labels": True,
            "metric_type": "bpb",
        },
    ),
    "mmlu_other_val_rc_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}),
    "mmlu_other_val_rc_var_bpb": (MMLU, {"dataset_name": "other", "prompt_variations": 1, "metric_type": "bpb"}),
    "mmlu_other_val_rc_5shot": (MMLU, {"dataset_name": "other", "prompt_variations": 2}),
    "mmlu_other_val_rc_5shot_bpb": (MMLU, {"dataset_name": "other", "prompt_variations": 2, "metric_type": "bpb"}),
    "mmlu_other_val_mc_5shot": (MMLU, {"dataset_name": "other", "prompt_variations": 2, "mc_labels": True}),
    "mmlu_other_val_mc_5shot_bpb": (
        MMLU,
        {"dataset_name": "other", "prompt_variations": 2, "mc_labels": True, "metric_type": "bpb"},
    ),
    "mmlu_other_test_rc_var": (MMLU, {"dataset_name": "other", "split": "test", "prompt_variations": 1}),
    "mmlu_other_test_rc_var_bpb": (
        MMLU,
        {"dataset_name": "other", "split": "test", "prompt_variations": 1, "metric_type": "bpb"},
    ),
    "mmlu_other_test_rc_5shot": (MMLU, {"dataset_name": "other", "split": "test", "prompt_variations": 2}),
    "mmlu_other_test_rc_5shot_bpb": (
        MMLU,
        {"dataset_name": "other", "split": "test", "prompt_variations": 2, "metric_type": "bpb"},
    ),
    "mmlu_other_test_mc_5shot": (
        MMLU,
        {"dataset_name": "other", "split": "test", "prompt_variations": 2, "mc_labels": True},
    ),
    "mmlu_other_test_mc_5shot_bpb": (
        MMLU,
        {
            "dataset_name": "other",
            "split": "test",
            "prompt_variations": 2,
            "mc_labels": True,
            "metric_type": "bpb",
        },
    ),
}

# Fold the additional task registrations into the main label -> (task class, kwargs)
# registry. Work on a shallow copy so the dict object bound to the old name is left
# untouched. On a duplicate label, the entry from `label_to_task_map_new` replaces
# the existing one; the label keeps its original position in the ordering.
label_to_task_map = dict(label_to_task_map)
label_to_task_map.update(label_to_task_map_new)
