"""
Measuring Massive Multitask Language Understanding

Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou,
Mantas Mazeika, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2009.03300

Based on: https://github.com/openai/simple-evals/blob/main/mmlu_eval.py

# eval all subjects w/ 500 randomly selected samples
inspect eval inspect_evals/mmlu_0_shot --limit 500

# add chain of thought
inspect eval inspect_evals/mmlu_0_shot --limit 500 -T cot=true

# eval selected subjects (defaults to english language "EN_US")
inspect eval inspect_evals/mmlu_0_shot -T subjects=anatomy
inspect eval inspect_evals/mmlu_0_shot -T subjects=astronomy
inspect eval inspect_evals/mmlu_0_shot -T subjects=anatomy,astronomy

# eval selected subjects and language
inspect eval inspect_evals/mmlu_0_shot -T subjects=anatomy -T language=HI_IN
"""

from functools import partial
import os
from typing import Any, Callable, Literal, Union

from inspect_ai import Task, task
from inspect_ai.dataset import Dataset, Sample, hf_dataset
from inspect_ai.model import (
    ChatMessageAssistant,
    ChatMessageSystem,
    ChatMessageTool,
    ChatMessageUser,
    GenerateConfig,
)
from inspect_ai.scorer import choice
from inspect_ai.solver import (
    Choices,
    MultipleChoiceTemplate,
    multiple_choice,
)

# Prompt template filled in by format_mmlu_question (question, choices and
# letters fields) when rendering the few-shot example questions.
MMLU_MULTISHOT_QUESTION_TEMPLATE: str = MultipleChoiceTemplate.SINGLE_ANSWER
# Generation cap passed to the solver as max_tokens by the tasks in this module
# when chain of thought is disabled.
DEFAULT_MAX_TOKENS: int = 5


@task
def mmlu_0_shot(
    language: str = "EN_US",
    subjects: str | list[str] = [],
    cot: bool = False,
) -> Task:
    """
    Build the Inspect Task for MMLU with 0-shot prompting.

    Args:
        language (str: 'EN_US'): Language of the dataset samples. "EN_US" loads
            the MMLU dataset; any other value is passed on as the language of
            the MMMLU dataset (check openai/mmmlu for the available languages)
        subjects (str | list[str]): Subjects to filter by
        cot (bool): Whether to use chain of thought. If False, the max_tokens is set to 5
    """
    use_english = language == "EN_US"
    test_samples = (
        get_mmlu_dataset("test", shuffle=True, subjects=subjects)
        if use_english
        else get_mmmlu_dataset(
            "test",
            shuffle=True,
            language=language,
            subjects=subjects,
        )
    )

    # Recent models (e.g. GPT-4+) often write out their reasoning even when not
    # asked to. To keep CoT and non-CoT runs distinct, generation is capped at 5
    # tokens when CoT is off -- just enough for "ANSWER: <letter>".
    token_cap = None if cot else DEFAULT_MAX_TOKENS

    answer_solver = multiple_choice(cot=cot, max_tokens=token_cap)
    return Task(
        dataset=test_samples,
        solver=answer_solver,
        scorer=choice(),
        config=GenerateConfig(temperature=0.0),
    )


@task
def mmlu_5_shot(
    language: str = "EN_US",
    subjects: str | list[str] = [],
    cot: bool = False,
) -> Task:
    """
    Inspect Task implementation for MMLU, with 5-shot prompting.

    The few-shot examples are taken from the "dev" split returned by
    ``get_mmlu_dataset`` and are attached to each test sample while the sample
    is built, which is why the record-to-sample functions are wrapped with
    ``partial`` to receive the dev dataset as an extra argument.

    Args:
        language (str: 'EN_US'): Language of the dataset samples, defaults to EN_US, check the openai/mmmlu dataset for the list of available languages
        subjects (str | list[str]): Subjects to filter by
        cot (bool): Whether to use chain of thought. It is forwarded to the
            solver; if False, the max_tokens is set to 5

    Returns:
        Task configured with the few-shot dataset, the multiple choice solver
        and the choice scorer.
    """
    dev_dataset = get_mmlu_dataset("dev", shuffle=False)

    if language == "EN_US":
        dataset = get_mmlu_dataset(
            "test",
            shuffle=True,
            subjects=subjects,
            # Few-shot messages are assembled per record from the dev dataset.
            sampling_function=partial(
                record_to_sample_mmlu_5_shot, dev_dataset=dev_dataset
            ),
        )
    else:
        dataset = get_mmmlu_dataset(
            "test",
            shuffle=True,
            language=language,
            subjects=subjects,
            # Few-shot messages are assembled per record from the dev dataset.
            sampling_function=partial(
                record_to_sample_mmmlu_5_shot, dev_dataset=dev_dataset
            ),
        )

    # Without chain of thought the answer is a single letter, so generation is
    # capped; with chain of thought the model needs room to reason.
    max_tokens = DEFAULT_MAX_TOKENS if not cot else None

    return Task(
        dataset=dataset,
        # `cot` must reach the solver, otherwise cot=True only lifts the token
        # cap without actually prompting for chain of thought.
        solver=multiple_choice(cot=cot, max_tokens=max_tokens),
        scorer=choice(),
        config=GenerateConfig(temperature=0),
    )


def format_mmlu_question(question: str, choices: Choices, answer: str = "") -> str:
    r"""Format an MMLU-style prompt for a question and its choices.

    Fills ``MMLU_MULTISHOT_QUESTION_TEMPLATE`` with the question, the choices
    rendered by ``format_mmlu_choices`` and the comma-separated list of valid
    answer letters (``A,B,C,...``, one per choice).

    Args:
        question (str): Question to pose.
        choices (Choices): Choices to select from.
        answer (str, optional): Accepted for backward compatibility but not
          used: the rendered prompt never contains the answer. Defaults to "".

    Returns:
        Prompt for question and choices.
    """
    # Upper-case letters, matching the "ABCD" targets used for scoring.
    letters = ",".join(chr(ord("A") + i) for i in range(len(choices)))
    return MMLU_MULTISHOT_QUESTION_TEMPLATE.format(
        question=question,
        choices=format_mmlu_choices(choices),
        letters=letters,
    )


def format_mmlu_choices(choices: Choices) -> str:
    r"""Format choices in the style of the MMLU paper.

    Returns the `choices` formatted as a multiple choice question in style of the
    MMLU paper (with round parentheses around upper-case letters), one choice
    per line. This is necessary as by default Inspect will only put parentheses
    after letters.

    Args:
        choices (Choices): Choices to select from. Each element's ``value``
          attribute is used as the choice text.

    Returns:
        Formatted choices; an empty string when there are no choices.

    Example:
    >>> format_mmlu_choices(Choices(["20", "25", "50", "55"]))
    "(A) 20\n(B) 25\n(C) 50\n(D) 55"
    """
    # Labels start at "A" so they line up with the "ABCD" answer targets.
    return "\n".join(
        f"({chr(ord('A') + i)}) {option.value}" for i, option in enumerate(choices)
    )


def record_to_sample_mmlu(record: dict[str, Any]) -> Sample:
    """Turn one MMLU record into a Sample.

    Reads the "question", "choices", "answer" and "subject" keys of `record`;
    the subject is stored in the sample metadata.
    """
    question = record["question"]
    options = record["choices"]
    # The answer is an index into the choices: 0 -> A, 1 -> B, etc.
    answer_letter = "ABCD"[record["answer"]]
    subject = record["subject"]
    return Sample(
        input=question,
        choices=options,
        target=answer_letter,
        metadata={"subject": subject},
    )


def record_to_sample_mmmlu(record: dict[str, Any]) -> Sample:
    """Turn one MMMLU record into a Sample.

    Reads the "Question", "A".."D", "Answer" and "Subject" keys of `record`.
    The "Answer" value is used as the target unchanged and the subject is
    stored in the sample metadata.
    """
    question = record["Question"]
    # The four options live in separate columns named after their letter.
    options = [record[letter] for letter in "ABCD"]
    target = record["Answer"]
    subject = record["Subject"]
    return Sample(
        input=question,
        choices=options,
        target=target,
        metadata={"subject": subject},
    )


def get_mmmlu_dataset(
    split: Literal["test", "dev", "validation"] = "test",
    language: str = "FR_FR",
    shuffle: bool = False,
    subjects: Union[list[str], str] = [],
    sampling_function: Callable[[dict[str, Any]], Sample] = record_to_sample_mmmlu,
) -> Dataset:
    """Load the "openai/MMMLU" dataset, optionally restricted to some subjects.

    Args:
        split: Dataset split to load.
        language: Dataset configuration name, i.e. the language to load.
        shuffle: Whether to shuffle the samples (the seed is fixed to 42).
        subjects: A subject or list of subjects to keep; an empty list keeps
            every sample.
        sampling_function: Converts a raw record into a Sample.

    Returns:
        The loaded dataset, filtered by the "subject" metadata entry when
        subjects were requested.
    """
    loaded = hf_dataset(
        path="openai/MMMLU",
        name=language,
        split=split,
        sample_fields=sampling_function,
        shuffle=shuffle,
        seed=42,
    )

    # A single subject may be given as a bare string.
    wanted = subjects if isinstance(subjects, list) else [subjects]
    if not wanted:
        return loaded

    def has_wanted_subject(sample: Sample) -> bool:
        return sample.metadata is not None and sample.metadata.get("subject") in wanted

    return loaded.filter(
        name=f"{loaded.name}-{'-'.join(wanted)}",
        predicate=has_wanted_subject,
    )


def get_mmlu_dataset(
    split: Union[Literal["test"], Literal["dev"], Literal["validation"]] = "test",
    shuffle: bool = False,
    subjects: Union[list[str], str] = [],
    sampling_function: Callable[[dict[str, Any]], Sample] = record_to_sample_mmlu,
) -> Dataset:
    """Load the "cais/mmlu" dataset ("all" configuration), optionally by subject.

    Args:
        split: Dataset split to load.
        shuffle: Whether to shuffle the samples (the seed is fixed to 42).
        subjects: A subject or list of subjects to keep; an empty list keeps
            every sample.
        sampling_function: Converts a raw record into a Sample.

    Returns:
        The loaded dataset, filtered by the "subject" metadata entry when
        subjects were requested.
    """
    loaded = hf_dataset(
        path="cais/mmlu",
        name="all",
        split=split,
        sample_fields=sampling_function,
        shuffle=shuffle,
        seed=42,
    )

    # A single subject may be given as a bare string.
    wanted = subjects if isinstance(subjects, list) else [subjects]
    if not wanted:
        return loaded

    def has_wanted_subject(sample: Sample) -> bool:
        return sample.metadata is not None and sample.metadata.get("subject") in wanted

    return loaded.filter(
        name=f"{loaded.name}-{'-'.join(wanted)}",
        predicate=has_wanted_subject,
    )


def record_to_sample_mmlu_5_shot(
    record: dict[str, Any], dev_dataset: Dataset
) -> Sample:
    """Turn one MMLU record into a Sample whose input is a few-shot chat.

    The chat consists of an optional system message naming the subject, one
    user/assistant exchange per `dev_dataset` sample of the same subject, and
    finally the record's own question as a user message.

    Args:
        record: Raw MMLU record with "question", "choices", "answer" and
            "subject" keys.
        dev_dataset: Dataset providing the few-shot examples; only samples
            whose "subject" metadata equals the record's subject are used.

    Returns:
        Sample with the chat as input and the answer letter as target.
    """

    def same_subject(sample: Sample) -> bool:
        return (
            sample.metadata is not None
            and sample.metadata["subject"] == record["subject"]
        )

    shots = dev_dataset.filter(same_subject)

    chat_messages: list[
        ChatMessageSystem | ChatMessageUser | ChatMessageAssistant | ChatMessageTool
    ] = []

    # gemma does not support system prompt, so it can be switched off through
    # the NO_SYSTEM_PROMPT environment variable.
    system_prompt_disabled = bool(os.getenv("NO_SYSTEM_PROMPT"))
    if not system_prompt_disabled:
        chat_messages.append(
            ChatMessageSystem(
                content=f"The following are multiple choice questions about {record['subject']}.",
            )
        )

    # NOTE(review): the subject header is prepended to every few-shot user turn
    # (even when the system message already carries it) but not to the final
    # question -- confirm this is intended.
    for shot in shots:
        shot_prompt = (
            f"The following are multiple choice questions about {record['subject']}.\n"
            + format_mmlu_question(
                question=str(shot.input),
                choices=Choices(shot.choices or []),
            )
        )
        chat_messages.extend(
            [
                ChatMessageUser(content=shot_prompt),
                ChatMessageAssistant(content=f"Answer: {str(shot.target)}"),
            ]
        )

    chat_messages.append(ChatMessageUser(content=record["question"]))

    options = record["choices"]
    # The answer is an index into the choices: 0 -> A, 1 -> B, etc.
    answer_letter = "ABCD"[record["answer"]]
    return Sample(
        input=chat_messages,
        choices=options,
        target=answer_letter,
        metadata={"subject": record["subject"]},
    )


def record_to_sample_mmmlu_5_shot(
    record: dict[str, Any], dev_dataset: Dataset
) -> Sample:
    """Turn one MMMLU record into a Sample whose input is a few-shot chat.

    The chat consists of a system message naming the subject, one
    user/assistant exchange per `dev_dataset` sample of the same subject, and
    finally the record's own question as a user message.

    Args:
        record: Raw MMMLU record with "Question", "A".."D", "Answer" and
            "Subject" keys.
        dev_dataset: Dataset providing the few-shot examples; only samples
            whose "subject" metadata equals the record's "Subject" are used.

    Returns:
        Sample with the chat as input and the record's "Answer" as target.
    """

    def same_subject(sample: Sample) -> bool:
        return (
            sample.metadata is not None
            and sample.metadata["subject"] == record["Subject"]
        )

    shots = dev_dataset.filter(same_subject)

    chat_messages: list[
        ChatMessageSystem | ChatMessageUser | ChatMessageAssistant | ChatMessageTool
    ] = [
        ChatMessageSystem(
            content=f"The following are multiple choice questions about {record['Subject']}.",
        )
    ]

    for shot in shots:
        shot_prompt = format_mmlu_question(
            question=str(shot.input),
            choices=Choices(shot.choices or []),
        )
        chat_messages.extend(
            [
                ChatMessageUser(content=shot_prompt),
                ChatMessageAssistant(content=f"Answer: {str(shot.target)}"),
            ]
        )

    chat_messages.append(ChatMessageUser(content=record["Question"]))

    # The four options live in separate columns named after their letter; the
    # "Answer" column is used as the target unchanged.
    options = [record[letter] for letter in "ABCD"]
    return Sample(
        input=chat_messages,
        choices=options,
        target=record["Answer"],
        metadata={"subject": record["Subject"]},
    )
