from dataclasses import dataclass
from typing import List
from datasets import load_dataset
from torch.utils.data import Dataset

# Default prompt template per dataset name. Each template is filled with
# ``str.format``: every entry takes ``question``; "tydiqa" also takes
# ``context``.
# Alternative "truthful_qa" prompt, currently disabled:
#   "Answer the question concisely.\nQ: {question}\nA:"
PROMPT_TEMPLATE = {
    "truthful_qa": "Q:{question}\nA:",
    "tydiqa": (
        "Concisely answer the following question based on the information "
        "in the given passage: \n"
        "Passage:\n{context}\n"
        "Q: {question}\n"
        "A:"
    ),
    "trivia_qa": "Q: {question}\nA:",
    "nq_open": "Q: {question}\nA:",
}

# Default ``path`` argument handed to ``load_dataset`` for each dataset name.
DATASET_PATH = dict(
    truthful_qa="truthfulqa/truthful_qa",
    tydiqa="google-research-datasets/tydiqa",
    trivia_qa="mandarjoshi/trivia_qa",
    nq_open="google-research-datasets/nq_open",
)


def deduplicate_list(string_list):
    """
    Return the non-blank strings of ``string_list`` without repeats.

    Strings that are empty or contain only whitespace are dropped. Of the
    remaining strings, only the first occurrence of each exact value is kept,
    and the original relative order is preserved.
    """
    non_blank = (text for text in string_list if text.strip())
    # dict keys are unique and keep insertion order, which gives an
    # order-preserving de-duplication in one pass.
    return list(dict.fromkeys(non_blank))


@dataclass
class DataSample:
    """A single prompt together with its reference answers."""
    # Prompt text for this sample.
    prompt: str
    # Answers regarded as correct for the prompt.
    correct_answers: List[str]
    # Answers regarded as incorrect for the prompt; may be an empty list.
    incorrect_answers: List[str]


class CustomDataset(Dataset):
    def __init__(self, name: str, path=None, prompt_template=None):
        if name in DATASET_PATH.keys() and path is None:
            path = DATASET_PATH[name]
        if name in DATASET_PATH.keys() and prompt_template is None:
            prompt_template = PROMPT_TEMPLATE[name]
        # 加载数据
        if name == "truthful_qa":
            # truthful_qa
            ds = load_dataset(path, "generation", split="validation")

            self.prompt_list = [
                prompt_template.format(question=q) for q in ds["question"]
            ]
            self.correct_answer_list = [
                deduplicate_list(l) for l in ds["correct_answers"]
            ]
            self.incorrect_answer_list = [
                deduplicate_list(l) for l in ds["incorrect_answers"]
            ]
        elif name == "tydiqa":
            # tydiqa
            all_ds = load_dataset(path, "secondary_task", split="validation")
            ds = all_ds.select(
                [
                    ind
                    for ind, id_ in enumerate(all_ds["id"])
                    if id_.startswith("english")
                ]
            )

            self.prompt_list = [
                prompt_template.format(context=c, question=q)
                for c, q in zip(ds["context"], ds["question"])
            ]
            self.correct_answer_list = [
                deduplicate_list(d["text"]) for d in ds["answers"]
            ]
            self.incorrect_answer_list = [[] for _ in range(len(self.prompt_list))]

        elif name == "trivia_qa":
            # trivia_qa
            ds = load_dataset(path, "rc.nocontext", split="validation")
            id_mem = set()

            def remove_dups(batch):
                if batch['question_id'][0] in id_mem:
                    return {_: [] for _ in batch.keys()}
                id_mem.add(batch['question_id'][0])
                return batch

            ds = ds.map(remove_dups, batch_size=1, batched=True, load_from_cache_file=False)

            self.prompt_list = [
                prompt_template.format(question=q) for q in ds["question"]
            ]
            self.incorrect_answer_list = [[] for _ in range(len(self.prompt_list))]
            self.correct_answer_list = []
            for sub_dict in ds["answer"]:
                cur_list = [sub_dict["value"]]
                cur_list.extend(sub_dict["aliases"])
                cur_list.extend(sub_dict["normalized_aliases"])

                self.correct_answer_list.append(deduplicate_list(cur_list))

        elif name == "nq_open":
            # nq_open
            ds = load_dataset(path, split="validation")
            self.prompt_list = [
                prompt_template.format(question=q) for q in ds["question"]
            ]
            self.correct_answer_list = [deduplicate_list(l) for l in ds["answer"]]
            self.incorrect_answer_list = [[] for _ in range(len(self.prompt_list))]
        else:
            raise ValueError("Invalid dataset name")

    def __len__(self) -> int:
        """
        Returns the size of the dataset.
        """
        return len(self.prompt_list)

    def __getitem__(self, idx) -> DataSample:
        """
        Generates one sample of data.
        """
        return DataSample(
            prompt=self.prompt_list[idx],
            correct_answers=self.correct_answer_list[idx],
            incorrect_answers=self.incorrect_answer_list[idx],
        )
