import logging
import re
from abc import ABC, abstractmethod
from collections import Counter
from pathlib import Path
from typing import Optional

from datasets import concatenate_datasets, load_dataset, load_from_disk

from ..utils import prompts


class AbstractQADataset(ABC):
    """
    Abstract base class that standardizes QA datasets into a common structure.

    Every item is exposed as a dictionary with the keys ``question``,
    ``answer`` and ``id``. Subclasses must implement :meth:`construct_prompt`
    to turn a question into a model prompt.
    """

    def __init__(
        self,
        dataset_path: str,
        question_key: str,
        answer_key: str,
        id_key: Optional[str] = None,
        local: bool = False,
        data_dir: Optional[str] = None,
        data_files: Optional[str] = None,
        revision: Optional[str] = None,
        split: str = "train",
    ):
        """
        Load the dataset and remember which fields hold question, answer and id.

        Args:
            dataset_path (str): Hugging Face path to the dataset, or a directory
                on disk when ``local`` is True.
            question_key (str): Key for the question field, i.e. dataset[item][question_key].
            answer_key (str): Key for the answer field, i.e. dataset[item][answer_key].
            id_key (str, optional): Key for the id field, i.e. dataset[item][id_key].
                When None, items are returned with ``id`` set to None.
            local (bool): Whether to load the dataset from a local directory
                with ``load_from_disk``. In that case only ``dataset_path`` is
                used; ``data_dir``, ``data_files``, ``revision`` and ``split``
                are ignored.
            data_dir (str, optional): Forwarded to ``load_dataset``.
            data_files (str, optional): Forwarded to ``load_dataset``.
            revision (optional): Forwarded to ``load_dataset``.
            split (str): Forwarded to ``load_dataset``. Defaults to "train".
        """
        # TODO: change if data_files or data_dir is used dynamically
        self.question_key = question_key
        self.answer_key = answer_key
        self.id_key = id_key

        if local:
            self.dataset = load_from_disk(dataset_path=dataset_path)
        else:
            self.dataset = load_dataset(
                path=dataset_path,
                data_dir=data_dir,
                data_files=data_files,
                split=split,
                revision=revision,
            )

    def __getitem__(self, idx: int) -> dict:
        """Return a dictionary with the question, answer and id for a given index."""
        item = self.dataset[idx]
        # Datasets without an id column simply report None as the id.
        item_id = item[self.id_key] if self.id_key is not None else None
        return {
            "question": item[self.question_key],
            "answer": item[self.answer_key],
            "id": item_id,
        }

    def __len__(self) -> int:
        """Return the number of items in the underlying dataset."""
        return len(self.dataset)

    @abstractmethod
    def construct_prompt(self, question: str, model_type: str) -> str:
        """Construct a prompt given a question."""


class AmbigQA(AbstractQADataset):
    """AmbigQA (``sewon/ambig_qa``) restricted to ``multipleQAs`` annotations."""

    def __init__(self):
        """Load the dataset and keep only items annotated as ``multipleQAs``."""
        super().__init__(
            dataset_path="sewon/ambig_qa",
            question_key="question",
            answer_key="annotations",
            id_key="id",
            local=False,
        )

        def _is_multiple_qas(example) -> bool:
            # The "type" field must equal the one-element list exactly.
            return example["annotations"]["type"] == ["multipleQAs"]

        self.dataset = self.dataset.filter(_is_multiple_qas)

    def __getitem__(self, idx):
        """Return question, id and the ``answer`` entry of the first qaPairs element."""
        base_item = super().__getitem__(idx)
        annotations = base_item["answer"]
        first_qa_pairs = annotations["qaPairs"][0]
        return {
            "question": base_item["question"],
            "answer": first_qa_pairs["answer"],
            "id": base_item["id"],
        }

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the QA prompt template matching ``model_type``.

        Raises:
            ValueError: If ``model_type`` is neither "base" nor "instruct".
        """
        if model_type == "base":
            return prompts.BASE_QA.format(question=question)
        if model_type == "instruct":
            return prompts.INSTRUCT_QA.format(question=question)
        raise ValueError(f"Unknown model type: {model_type}")


class MAQA(AbstractQADataset):
    """MAQA dataset loaded from the local ``MAQA_world_knowledge_raw`` directory."""

    def __init__(self):
        """Dataset for MAQA."""
        super().__init__(
            dataset_path="MAQA_world_knowledge_raw",
            question_key="question",
            answer_key="answers",
            local=True,
        )

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the QA prompt template matching ``model_type``.

        Raises:
            ValueError: If ``model_type`` is neither "base" nor "instruct".
        """
        if model_type == "base":
            template = prompts.BASE_QA
        elif model_type == "instruct":
            template = prompts.INSTRUCT_QA
        else:
            raise ValueError(f"Unknown model type: {model_type}")
        return template.format(question=question)


class MAQA_Simplex(AbstractQADataset):
    """MAQA variant (``maqa_star``) whose answers are paired with probabilities."""

    def __init__(self):
        """Dataset for MAQA."""
        super().__init__(
            dataset_path="maqa_star",
            question_key="question",
            answer_key="answers",
        )

    def __getitem__(self, idx):
        """Return the item with ``answer`` as a mapping from answer to probability."""
        sample = super().__getitem__(idx)
        probabilities = self.dataset[idx]["probabilities"]  # List[float]
        # Take the first occurrence of every answer entry as its key.
        first_forms = [variants[0] for variants in sample["answer"]]
        sample["answer"] = dict(zip(first_forms, probabilities))
        return sample

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the simplex QA prompt; only instruct models are supported.

        Raises:
            NotImplementedError: If ``model_type`` is "base".
            ValueError: If ``model_type`` is any other unknown value.
        """
        if model_type == "instruct":
            return prompts.INSTRUCT_SIMPLEX_QA.format(question=question)
        if model_type == "base":
            raise NotImplementedError("Base model not supported for MAQA_Simplex")
        raise ValueError(f"Unknown model type: {model_type}")


class TriviaQADataset(AbstractQADataset):
    """TriviaQA (``mandarjoshi/trivia_qa``) loaded with ``data_dir="rc.nocontext"``."""

    def __init__(self):
        """Load the dataset with its question, answer and question_id fields."""
        super().__init__(
            dataset_path="mandarjoshi/trivia_qa",
            question_key="question",
            answer_key="answer",
            id_key="question_id",
            local=False,
            data_dir="rc.nocontext",
        )

    def __getitem__(self, idx):
        """Return the item with ``answer`` reduced to ``[[answer["value"]]]``."""
        sample = super().__getitem__(idx)
        answer_value = sample["answer"]["value"]
        # Wrap the single value in a nested list.
        sample["answer"] = [[answer_value]]
        return sample

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the QA prompt template matching ``model_type``.

        Raises:
            ValueError: If ``model_type`` is neither "base" nor "instruct".
        """
        if model_type == "base":
            return prompts.BASE_QA.format(question=question)
        if model_type == "instruct":
            return prompts.INSTRUCT_QA.format(question=question)
        raise ValueError(f"Unknown model type: {model_type}")


class CnnDailyMailDataset(AbstractQADataset):
    """CNN/DailyMail summarization data (``abisee/cnn_dailymail``, ``3.0.0``)."""

    def __init__(self):
        """Load the dataset and drop every article longer than 2000 characters."""
        super().__init__(
            dataset_path="abisee/cnn_dailymail",
            question_key="article",
            answer_key="highlights",
            id_key="id",
            local=False,
            data_dir="3.0.0",
        )

        def _is_short_enough(example) -> bool:
            # Keep only samples whose article has at most 2000 characters.
            return len(example["article"]) <= 2000

        self.dataset = self.dataset.filter(_is_short_enough)

    def __getitem__(self, idx):
        """Return the item with the highlights wrapped as ``[[highlights]]``."""
        sample = super().__getitem__(idx)
        highlights = sample["answer"]
        sample["answer"] = [[highlights]]
        return sample

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the summarization prompt; only instruct models are supported.

        Raises:
            ValueError: If ``model_type`` is not "instruct".
        """
        if model_type != "instruct":
            raise ValueError(
                f"Unknown model type: {model_type} only instruct models supported"
            )
        return prompts.INSTRUCT_SUMMARIZATION.format(article=question)


class XSumDataset(AbstractQADataset):
    """XSum summarization data (``EdinburghNLP/xsum``), parquet revision, train split."""

    def __init__(self):
        """Load the dataset with ``document`` as question and ``summary`` as answer."""
        super().__init__(
            dataset_path="EdinburghNLP/xsum",
            question_key="document",
            answer_key="summary",
            id_key="id",
            local=False,
            split="train",
            revision="refs/convert/parquet",
        )

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the XSum prompt; only instruct models are supported.

        Raises:
            ValueError: If ``model_type`` is not "instruct".
        """
        if model_type != "instruct":
            raise ValueError(
                f"Unknown model type: {model_type} only instruct models supported"
            )
        return prompts.INSTRUCT_XSUM.format(article=question)


class KGSummariesDataset(AbstractQADataset):
    """Locally stored summary benchmark whose items carry an article and summaries."""

    def __init__(
        self,
        dataset_path: str | Path = Path("summary_benchmark/dataset_v2"),
        clustered: bool = True,
        true_edge_frequency_threshold: float = 0.5,
    ):
        """
        Load the dataset from disk and store the graph-related settings.

        Args:
            dataset_path: Directory of the dataset saved on disk.
            clustered: Selects "graphs_clustered" (True) or "graphs_raw" (False)
                as the graph key stored in ``self.key_graphs``.
            true_edge_frequency_threshold: Minimum fraction of graphs an edge
                must appear in to be kept by :meth:`_graphs_to_answer`.
        """
        # Question/answer keys are unused because __getitem__ is overridden.
        super().__init__(str(dataset_path), "", "", local=True)
        self.key_graphs = "graphs_clustered" if clustered else "graphs_raw"
        self.key_answer = "summary"
        self.true_edge_frequency_threshold = true_edge_frequency_threshold

    def _count_edges(self, graphs) -> dict[str, int]:
        """Count, for each distinct ``edge[2]`` value, how many graphs contain it."""
        # The per-graph set ensures an edge is counted at most once per graph.
        # NOTE(review): presumably edge[2] is the edge's textual form — confirm.
        return Counter(
            edge
            for graph in graphs
            for edge in {_edge[2] for _edge in graph["edges"]}
        )

    def _graphs_to_answer(self, graphs) -> str:
        """Join, one per line, the edges present in a large enough share of graphs."""
        return "\n".join(
            edge
            for edge, count in self._count_edges(graphs).items()
            if count / len(graphs) >= self.true_edge_frequency_threshold
        )

    def __getitem__(self, idx: int) -> dict:
        """Return a dictionary with the question and answer for a given index."""
        article = self.dataset[idx]["article"]
        return {
            "question": article["article"],
            # Several summaries may be stored; the first one is used.
            "answer": article[self.key_answer][0],
            "id": article["article_id"],
        }

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Construct a summarization prompt given an article.

        Raises:
            ValueError: If ``model_type`` is not "instruct". Previously this
                case silently returned None.
        """
        if model_type == "instruct":
            return prompts.INSTRUCT_SUMMARIZATION.format(article=question)
        raise ValueError(
            f"Unknown model type: {model_type} only instruct models supported"
        )


class GeminiNumericQuestionsDataset(AbstractQADataset):
    """Numerical question dataset loaded from a local directory."""

    def __init__(
        self,
        dataset_path: str = "numerical_questions/hf_dataset_gemini",
    ):
        """
        Load the dataset from disk.

        Args:
            dataset_path: Directory of the dataset saved on disk.
        """
        super().__init__(
            dataset_path=dataset_path,
            question_key="question",
            answer_key="answer",
            local=True,
        )

    def __getitem__(self, idx):
        """Return question and answer, using the stringified index as the id."""
        row = self.dataset[idx]
        return {
            "question": row[self.question_key],
            "answer": row[self.answer_key],
            # No id key is configured, so the position doubles as the id.
            "id": f"{idx}",
        }

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the numerical QA prompt; only instruct models are supported.

        Raises:
            ValueError: If ``model_type`` is not "instruct".
        """
        if model_type != "instruct":
            raise ValueError(f"Unknown model type: {model_type}")
        return prompts.INSTRUCT_NUMERICAL_QA.format(question=question)


class FolkTextsDataset(AbstractQADataset):
    """
    FolkTexts (``acruz/folktexts``, test split).

    The prompt is built from fields of the most recently accessed item, so
    :meth:`construct_prompt` must be called right after ``__getitem__``.
    """

    # Row fields cached by __getitem__ for the following construct_prompt call.
    _PROMPT_FIELDS = ("instruction", "description", "numeric_question_prompt")

    def __init__(self):
        """Load the dataset and initialise the per-item prompt cache."""
        super().__init__(
            dataset_path="acruz/folktexts",
            question_key="numeric_question",
            answer_key="label",
            id_key="id",
            local=False,
            split="test",
        )

        # dict to handle temporary storage of instruction and prompt
        self.temp_dict = {
            "instruction": None,
            "description": None,
            "numeric_question_prompt": None,
        }

    def __getitem__(self, idx):
        """Cache the prompt fields of item ``idx`` and return the standard item dict."""
        # Read the row once instead of once per cached field.
        row = self.dataset[idx]
        for field in self._PROMPT_FIELDS:
            self.temp_dict[field] = row[field]
        return super().__getitem__(idx)

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Build the prompt from the fields cached by the last ``__getitem__`` call.

        ``question`` and ``model_type`` are not used; the prompt depends only
        on the cached fields. Calling this before any ``__getitem__`` fails
        because the cache still holds None values.
        """
        cached = self.temp_dict
        prompt = f"Instruction: {cached['instruction']}. Please only return this answer.\n"
        prompt += f"Individual context:\n{cached['description']}\n\n"
        # The last three characters of the stored prompt are dropped.
        # NOTE(review): presumably a trailing answer stub — confirm against the data.
        prompt += f"{cached['numeric_question_prompt'][:-3]}"
        return prompt


class GSM8KDataset(AbstractQADataset):
    """GSM8K (``openai/gsm8k``, ``main``) exposing the text after the ``#### `` marker."""

    def __init__(self):
        """Load the dataset; items have no id key."""
        super().__init__(
            dataset_path="openai/gsm8k",
            question_key="question",
            answer_key="answer",
            id_key=None,
            local=False,
            data_dir="main",
        )

    def __getitem__(self, idx):
        """Return the item with ``answer`` cut down to the part after the last ``#### ``."""
        sample = super().__getitem__(idx)
        # Keep only what follows the last "#### " marker; an answer without the
        # marker is left unchanged.
        _, _, final_answer = sample["answer"].rpartition("#### ")
        sample["answer"] = final_answer
        return sample

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the QA prompt template matching ``model_type``.

        Raises:
            ValueError: If ``model_type`` is neither "base" nor "instruct".
        """
        if model_type == "base":
            return prompts.BASE_QA.format(question=question)
        if model_type == "instruct":
            return prompts.INSTRUCT_QA.format(question=question)
        raise ValueError(f"Unknown model type: {model_type}")


class STSBDataset(AbstractQADataset):
    """STS benchmark (``mteb/stsbenchmark-sts``): a sentence pair with a score."""

    def __init__(self):
        """Load the dataset with ``score`` as answer and ``sid`` as id."""
        super().__init__(
            dataset_path="mteb/stsbenchmark-sts",
            question_key="sentence1",
            answer_key="score",
            id_key="sid",
            local=False,
        )

    def __getitem__(self, idx):
        """Return the item with both sentences joined by " , " as the question."""
        out = super().__getitem__(idx)
        # Read the row once for both sentences.
        row = self.dataset[idx]
        out["question"] = row["sentence1"] + " , " + row["sentence2"]
        return out

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the STSB prompt template matching ``model_type``.

        Raises:
            ValueError: If ``model_type`` is neither "instruct" nor "base".
        """
        if model_type == "instruct":
            return prompts.INSTRUCT_STSB.format(question=question)
        if model_type == "base":
            return prompts.BASE_STSB.format(question=question)
        # Both model types are supported, so the message no longer claims
        # that only instruct models are.
        raise ValueError(f"Unknown model type: {model_type}")


class SATABENCHDataset(AbstractQADataset):
    """
    SATA-Bench (``sata-bench/sata-bench``) multi-answer selection task.

    ``__getitem__`` builds the full prompt and caches it, so
    :meth:`construct_prompt` must be called right after ``__getitem__``.
    """

    # Task description placed at the start of every prompt.
    _TASK = "You are given a paragraph and a multi-answer question. Your task is to select all correct answers from the list of possible answers based on the information provided in the paragraph. Only return the answers that apply and separate them by semicolons ;. Return the exact answers as they appear in the list of possible answers."

    # Matches anything of the form <something>.
    _TAG_PATTERN = re.compile(r"<[^>]+>")

    def __init__(self):
        """Load the dataset and initialise the cached prompt."""
        super().__init__(
            dataset_path="sata-bench/sata-bench",
            question_key="question",
            answer_key="answer groups",
            id_key=None,
            local=False,
        )

        self.temp_prompt = None

    def __getitem__(self, idx):
        """
        Build the prompt for item ``idx`` and return it as the question.

        Returns:
            dict: ``question`` is the complete prompt, ``answer`` holds every
            answer group wrapped in its own list, and ``id`` is always None.
        """
        # Read the row once instead of once per field.
        row = self.dataset[idx]
        answer_groups = row["answer groups"]

        # Correct answers are listed first, followed by the distractors.
        possible_answers = [*answer_groups, *row["distractor groups"]]
        final_answer_string = "\n".join(possible_answers)

        # Add paragraph, question and then final answer string
        final_prompt = (
            f"{self._TASK}\n\nParagraph: {row['paragraph']}\n\n"
            f"Question: {row['question']}\n\n"
            f"Possible Answers:\n{final_answer_string}\n"
        )

        # Remove all content of the form <something>, then trim whitespace.
        final_prompt = self._TAG_PATTERN.sub("", final_prompt).strip()

        # Cached for the construct_prompt call that follows.
        self.temp_prompt = final_prompt

        answer = [[ans] for ans in answer_groups]

        return {"question": final_prompt, "answer": answer, "id": None}

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Return the prompt cached by the last ``__getitem__`` call.

        ``question`` is not used; the cached prompt is returned as is.

        Raises:
            NotImplementedError: If ``model_type`` is "base".
            ValueError: If ``model_type`` is any other unknown value.
        """
        if model_type == "instruct":
            return self.temp_prompt
        if model_type == "base":
            raise NotImplementedError("Base model not supported for SATABENCH")
        raise ValueError(
            f"Unknown model type: {model_type} only instruct models supported"
        )


class AmazonReviewsDataset(AbstractQADataset):
    """Amazon reviews (``aylinakkus/amazon-reviews-all-beauty-test``) with ratings."""

    def __init__(self):
        """Load the dataset with ``text`` as question and ``rating`` as answer."""
        super().__init__(
            dataset_path="aylinakkus/amazon-reviews-all-beauty-test",
            question_key="text",
            answer_key="rating",
            id_key=None,
            local=False,
        )

    def __getitem__(self, idx):
        """Return the item with the review title placed on a line before the text."""
        sample = super().__getitem__(idx)
        review_title = self.dataset[idx]["title"]
        review_text = sample["question"]
        sample["question"] = review_title + "\n" + review_text
        return sample

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the Amazon review prompt; only instruct models are supported.

        Raises:
            ValueError: If ``model_type`` is not "instruct".
        """
        if model_type != "instruct":
            raise ValueError(
                f"Unknown model type: {model_type} only instruct models supported"
            )
        return prompts.INSTRUCT_AMAZON.format(question=question)


class SimpleQAVerifiedDataset(AbstractQADataset):
    """SimpleQA Verified (``google/simpleqa-verified``), ``eval`` split."""

    def __init__(self):
        """Load the dataset with ``problem`` as question and ``original_index`` as id."""
        super().__init__(
            dataset_path="google/simpleqa-verified",
            question_key="problem",
            answer_key="answer",
            id_key="original_index",
            local=False,
            split="eval",
        )

    def __getitem__(self, idx):
        """Return the item with the answer wrapped as ``[[answer]]``."""
        sample = super().__getitem__(idx)
        raw_answer = sample["answer"]
        sample["answer"] = [[raw_answer]]
        return sample

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the QA prompt template matching ``model_type``.

        Raises:
            ValueError: If ``model_type`` is neither "base" nor "instruct".
        """
        if model_type == "base":
            return prompts.BASE_QA.format(question=question)
        if model_type == "instruct":
            return prompts.INSTRUCT_QA.format(question=question)
        raise ValueError(f"Unknown model type: {model_type}")


class WMT19Dataset(AbstractQADataset):
    """WMT19 translation pairs (``wmt/wmt19``), train split."""

    def __init__(self, language_pair: str = "de-en"):
        """
        Load one language pair of the dataset.

        Args:
            language_pair: Two language codes joined by "-". The first is used
                as the question language, the second as the answer language.
        """
        super().__init__(
            dataset_path="wmt/wmt19",
            question_key="translation",
            answer_key="translation",
            id_key=None,
            local=False,
            data_dir=language_pair,
            split="train",
        )
        codes = language_pair.split("-")
        self.language1 = codes[0]
        self.language2 = codes[1]

    def __getitem__(self, idx):
        """Return the item with question and answer narrowed to their languages."""
        sample = super().__getitem__(idx)
        # Both fields start as the same "translation" entry of the row.
        source_text = sample["question"][self.language1]
        target_text = sample["answer"][self.language2]
        sample["question"] = source_text
        sample["answer"] = target_text
        return sample

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the WMT19 prompt; only instruct models are supported.

        Raises:
            ValueError: If ``model_type`` is not "instruct".
        """
        if model_type != "instruct":
            raise ValueError(
                f"Unknown model type: {model_type} only instruct models supported"
            )
        return prompts.INSTRUCT_WMT19.format(question=question)


class ToxigenDataset(AbstractQADataset):
    """Label-balanced sample of ALICE-generated ToxiGen examples, loaded from disk."""

    def __init__(self):
        """Load the dataset, keep ALICE generations and sample 1000 items per label."""
        super().__init__(
            dataset_path="datasets/toxigen_train",
            question_key="generation",
            answer_key="prompt_label",
            id_key=None,
            local=True,
            split="train",
        )

        # Filter for only ALICE generation method which are Hard examples
        self.dataset = self.dataset.filter(lambda x: x["generation_method"] == "ALICE")
        alice_only = self.dataset

        def _sample_for_label(label):
            # Select 1000 shuffled samples carrying the given prompt label.
            matching = alice_only.filter(lambda x: x["prompt_label"] == label)
            return matching.shuffle(seed=42).select(range(1000))

        # Filter equal number of samples for both labels
        ds_label_0 = _sample_for_label(0)
        ds_label_1 = _sample_for_label(1)

        # Concatenate both datasets and mix the labels
        balanced = concatenate_datasets([ds_label_0, ds_label_1])
        self.dataset = balanced.shuffle(seed=42)

    def __getitem__(self, idx):
        """Return the item with the label mapped to ``[["TOXIC"]]`` or ``[["BENIGN"]]``."""
        sample = super().__getitem__(idx)
        # Label 1 means toxic; every other value is treated as benign.
        if sample["answer"] == 1:
            verdict = "TOXIC"
        else:
            verdict = "BENIGN"
        sample["answer"] = [[verdict]]
        return sample

    def construct_prompt(self, question: str, model_type: str) -> str:
        """
        Format the ToxiGen prompt; only instruct models are supported.

        Raises:
            ValueError: If ``model_type`` is not "instruct".
        """
        if model_type != "instruct":
            raise ValueError(
                f"Unknown model type: {model_type} only instruct models supported"
            )
        return prompts.INSTRUCT_TOXIGEN.format(question=question)
