import random
from typing import Any

import jinja2
from datasets import Dataset

from hallucinations.config import PromptConfig, QaPromptConfig
from hallucinations.utils.misc import disable_tqdm_for_hf_datasets


class DatasetFormatter:
    """Base class for callables that turn one dataset row into chat messages.

    Subclasses implement ``__call__``; instances are meant to be passed to
    ``Dataset.map``-style APIs, receiving a row dict and returning a dict of new
    columns.
    """

    def __init__(self, prompt: PromptConfig):
        # Prompt configuration shared by all concrete formatters.
        self.prompt: PromptConfig = prompt

    def __call__(self, item: dict[str, Any]) -> dict[str, Any]:
        """Format a single row; concrete subclasses must override this."""
        raise NotImplementedError


class QaFormatter(DatasetFormatter):
    """Format a plain question-answer row into a single user chat message.

    The prompt content is a ``str.format`` template with a single placeholder
    named after ``prompt.question_key``.
    """

    prompt: QaPromptConfig

    def __init__(self, prompt: QaPromptConfig, use_output: bool):
        super().__init__(prompt)
        # Context-augmented prompts are not handled by this formatter.
        assert prompt.context_key is None
        self.use_output = use_output

    def __call__(self, item: dict[str, Any]) -> dict[str, Any]:
        """Return ``{"messages": [<user turn>]}`` built from ``item``'s question.

        Raises:
            NotImplementedError: if ``use_output`` is set (appending the
                reference answer is not implemented yet).
        """
        question_key = self.prompt.question_key
        fields = {question_key: item[question_key]}
        user_turn = {
            "role": "user",
            "content": self.prompt.content.format(**fields),
        }

        # Checked after formatting so failures surface in the same order as before.
        if self.use_output:
            raise NotImplementedError("Need to determine which answer to use")

        return {"messages": [user_turn]}


class MMLUFormatter(QaFormatter):
    """Render a multiple-choice (MMLU-style) row into a single user chat message.

    ``prompt.content`` is compiled as a Jinja2 template and rendered with:

    - ``question``: the row's question text,
    - ``options``: list of lettered answers, e.g. ``["A. foo", "B. bar"]``,
    - ``letter_range``: the bracketed letter span, e.g. ``"[A-D]"``.
    """

    DATASET_QUESTION_KEY = "question"
    DATASET_OPTIONS_KEY = "options"

    def __init__(
        self,
        prompt: QaPromptConfig,
        dataset_question_key: str = DATASET_QUESTION_KEY,
        dataset_options_key: str = DATASET_OPTIONS_KEY,
        use_output: bool = False,
    ):
        super().__init__(prompt, use_output)
        self.prompt_template = jinja2.Template(self.prompt.content)
        # Column names in the dataset row (configurable for differently-named datasets).
        self.dataset_question_key = dataset_question_key
        self.dataset_options_key = dataset_options_key

    def __call__(self, item: dict[str, Any]) -> dict[str, Any]:
        """Return ``{"messages": [<user turn>]}`` for one multiple-choice row."""
        options = self._format_options(item[self.dataset_options_key])
        pretty_letter_range = f"[{'-'.join(self.letter_range(len(options)))}]"
        content = self.prompt_template.render(
            question=item[self.dataset_question_key],
            options=options,
            letter_range=pretty_letter_range,
        ).strip()
        return {"messages": [{"role": "user", "content": content}]}

    @staticmethod
    def _format_options(answers: list[Any]) -> list[str]:
        """Prefix each answer with its letter label: ``["A. x", "B. y", ...]``."""
        return [f"{chr(ord('A') + i)}. {answer}" for i, answer in enumerate(answers)]

    @staticmethod
    def letter_range(num_options: int) -> tuple[str, str]:
        """Return the first and last option letters for ``num_options`` answers.

        For example, ``letter_range(4) == ("A", "D")``. Letters are computed by
        code-point offset from ``"A"``, so ``num_options`` is expected to be in
        ``1..26``; values outside that range yield non-letter characters.
        """
        return "A", chr(ord("A") + num_options - 1)


class MMLUFewShotFormatter(MMLUFormatter):
    """MMLU formatter that also renders in-context (few-shot) examples.

    Each row must carry a ``few_shot_examples`` list (as produced by
    :meth:`draw_few_shot_examples`). The Jinja2 template is rendered with
    ``question``, ``options``, ``letter_range`` (same as the parent formatter)
    and ``examples``: a list of dicts with ``question``, ``options`` (lettered)
    and ``answer`` (``None`` when the example has no answer field).
    """

    def __call__(self, item: dict[str, Any]) -> dict[str, Any]:
        """Return ``{"messages": [<user turn>]}`` including few-shot examples."""
        assert "few_shot_examples" in item

        def lettered(answers: list[Any]) -> list[str]:
            # "A. <answer>", "B. <answer>", ... — same labelling as the target question.
            return [f"{chr(ord('A') + i)}. {answer}" for i, answer in enumerate(answers)]

        formatted_examples = [
            {
                "question": example[self.dataset_question_key],
                "options": lettered(example[self.dataset_options_key]),
                "answer": example.get("answer"),
            }
            for example in item["few_shot_examples"]
        ]

        options = lettered(item[self.dataset_options_key])
        content = self.prompt_template.render(
            question=item[self.dataset_question_key],
            options=options,
            # Previously omitted, so templates shared with MMLUFormatter rendered
            # `{{ letter_range }}` as an empty string.
            letter_range=f"[{'-'.join(self.letter_range(len(options)))}]",
            examples=formatted_examples,
        ).strip()

        return {"messages": [{"role": "user", "content": content}]}

    @staticmethod
    def draw_few_shot_examples(dataset: Dataset, num_examples: int) -> Dataset:
        """Attach ``num_examples`` randomly drawn other rows to every row.

        Adds a ``few_shot_examples`` column. For each row it holds
        ``num_examples`` distinct rows, never including the row itself. Each
        example keeps only the ``question_id``, ``question``, ``options`` and
        ``answer`` columns, which must exist in ``dataset``. Sampling uses the
        module-level ``random`` generator; seed it for reproducibility.
        """
        assert dataset.num_rows > num_examples
        num_candidates = dataset.num_rows - 1

        def add_few_shot_examples(_: dict[str, Any], idx: int) -> dict[str, Any]:
            # Sample positions among the n-1 *other* rows, then skip over `idx`
            # by shifting positions >= idx up by one. random.sample indexes into
            # its population, so this yields the same indices as sampling from
            # the explicit list [0..n) minus idx, without an O(n) list per row.
            few_shot_indices = [
                pos if pos < idx else pos + 1
                for pos in random.sample(range(num_candidates), k=num_examples)
            ]

            few_shot_examples = (
                dataset.select(few_shot_indices)
                .select_columns(["question_id", "question", "options", "answer"])
                .to_list()
            )
            return {"few_shot_examples": few_shot_examples}

        with disable_tqdm_for_hf_datasets():
            return dataset.map(
                add_few_shot_examples,
                with_indices=True,
            )


class HaluEvalQAFormatter(DatasetFormatter):
    """Format a HaluEval QA row into a single user chat message (question only).

    The prompt content is a ``str.format`` template with one placeholder named
    after ``prompt.question_key``.
    """

    prompt: QaPromptConfig

    def __init__(self, prompt: QaPromptConfig, use_context: bool):
        super().__init__(prompt)
        self.use_context = use_context

    def __call__(self, item: dict[str, Any]) -> dict[str, Any]:
        """Return ``{"messages": [<user turn>]}`` built from ``item``'s question.

        Raises:
            NotImplementedError: if ``use_context`` is set.
        """
        if self.use_context:
            raise NotImplementedError("Using context is not supported for HaluEvalQA")

        key = self.prompt.question_key
        rendered = self.prompt.content.format(**{key: item[key]})
        return {"messages": [{"role": "user", "content": rendered}]}
