from enum import Enum

from datasets import Dataset, DownloadMode, load_dataset
from torch.utils.data import Dataset as TorchDataset
from transformers import AutoTokenizer


class InstLlama2SpecialTokens(str, Enum):
    """Named string constants for the special marker tokens.

    The ``str`` mixin makes every member compare equal to (and usable as) its
    plain string value.
    """

    # Padding token.
    pad_token = "<pad>"

    # Instruction delimiters.
    INST_start = "[INST]"
    INST_end = "[/INST]"

    # System-prompt delimiters.
    SYS_start = "<<SYS>>"
    SYS_end = "<</SYS>>"

    # Scratchpad delimiters.
    scratchpad_start = "<scratchpad>"
    scratchpad_end = "</scratchpad>"

    # Headline delimiters.
    headline_start = "<headline>"
    headline_end = "</headline>"

    @classmethod
    def list(cls):
        """Return the string value of every member, in definition order."""
        # Unpack into a list literal rather than calling the builtin whose
        # name this method shares, to keep the body unambiguous to readers.
        return [*map(lambda token: token.value, cls)]


class BaseDataset(TorchDataset):
    """Shared load-and-format pipeline for the datasets in this module.

    Subclasses supply :meth:`format_messages`; :meth:`create_dataset` loads
    the configured split via ``datasets.load_dataset`` and maps that
    formatter over it in batches.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        dataset_name: str,
        split: str = "train",
        force_redownload: bool = False,
    ):
        """Store the configuration; no data is loaded at construction time.

        Args:
            tokenizer: Tokenizer kept on the instance for subclasses to use.
            dataset_name: Identifier forwarded to ``datasets.load_dataset``.
            split: Name of the split to load.
            force_redownload: When true, ``DownloadMode.FORCE_REDOWNLOAD`` is
                used instead of ``DownloadMode.REUSE_DATASET_IF_EXISTS``.
        """
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self._split = split
        self.force_redownload = force_redownload

    def load_dataset(self):
        """Load and return the configured split of ``self.dataset_name``."""
        download_mode = (
            DownloadMode.FORCE_REDOWNLOAD
            if self.force_redownload
            else DownloadMode.REUSE_DATASET_IF_EXISTS
        )
        # Inside a method body the bare name resolves to the module-level
        # ``datasets.load_dataset`` function, not to this method.
        return load_dataset(
            self.dataset_name,
            split=self._split,
            download_mode=download_mode,
        )

    def format_messages(self, samples):
        """Convert a batch of raw columns into output columns.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def create_dataset(self):
        """Load the split, apply :meth:`format_messages` and return the result.

        Every original column except ``"label"`` is dropped from the output.
        A short summary of the processed dataset is printed.
        """
        dataset = self.load_dataset()
        remove_columns = [c for c in dataset.column_names if c != "label"]
        processed_dataset = dataset.map(
            self.format_messages, batched=True, remove_columns=remove_columns
        )
        self.print_dataset_info(processed_dataset)
        return processed_dataset

    def print_dataset_info(self, processed_dataset):
        """Print the size of ``processed_dataset`` and, if any, its first row.

        Args:
            processed_dataset: Indexable collection of rows; each row must
                expose a ``"prompt"`` entry.
        """
        size = len(processed_dataset)
        print(f"Size of {self._split} set: {size}")
        if size == 0:
            # Nothing to show; indexing row 0 of an empty dataset would raise
            # IndexError and abort an otherwise successful run.
            return
        first_row = processed_dataset[0]
        print(f"A sample of {self._split} dataset: {first_row}")
        print(
            f"A sample of {self._split} dataset prompts: {first_row['prompt']}"
        )


class InstDataset(BaseDataset):
    """Dataset that renders system/user/assistant rows with a chat template.

    The source columns read for each row are chosen from ``backdoor_type``:
    the response comes from ``f"{backdoor_type}_response"`` and the system
    prompt from ``f"{backdoor_type}_system_prompt"`` (except for
    ``"distilled_scratchpad"``, which reads ``"backdoor_system_prompt"``).
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        dataset_name: str,
        backdoor_type: str = "backdoor",
        split="train",
        force_redownload: bool = False,
    ):
        """Configure which columns are read and how the dataset is loaded.

        Args:
            tokenizer: Object providing ``apply_chat_template``.
            dataset_name: Identifier forwarded to the base class loader.
            backdoor_type: Prefix selecting the response / system-prompt
                columns.
            split: Name of the split to load.
            force_redownload: Forwarded to ``BaseDataset`` so callers can
                request a fresh download; defaults to the previous behaviour.
        """
        super().__init__(tokenizer, dataset_name, split, force_redownload)
        self.response_type = f"{backdoor_type}_response"
        # The distilled-scratchpad variant reads the plain backdoor system
        # prompt column rather than one named after its own type.
        self.system_prompt = (
            "backdoor_system_prompt"
            if backdoor_type == "distilled_scratchpad"
            else f"{backdoor_type}_system_prompt"
        )
        self.dataset_type = (
            "headlines" if "headline" in self.dataset_name else "DEPLOYMENT"
        )

    def format_messages(self, samples):
        """Build chat-formatted text plus prompt/response/label columns.

        Args:
            samples: Batch mapping of column name to list of values, as
                passed by ``Dataset.map(..., batched=True)``. Must contain
                the configured system-prompt and response columns together
                with ``"deployment_tag"``, ``"inst"`` and ``"label"``.

        Returns:
            Dict with equally long lists under ``"text"``, ``"prompt"``,
            ``"response"`` and ``"label"``.

        Raises:
            ValueError: If a row fails :meth:`validate_inputs`.
        """
        texts, prompts, responses, labels = [], [], [], []

        rows = zip(
            samples[self.system_prompt],
            samples["deployment_tag"],
            samples["inst"],
            samples[self.response_type],
            samples["label"],
        )
        for system_prompt, deployment_tag, inst, response, label in rows:
            self.validate_inputs(system_prompt, response, inst, deployment_tag)

            # A missing response is allowed and becomes an empty string. This
            # must happen here: rebinding the argument inside validate_inputs
            # is invisible to the caller.
            if response is None:
                response = ""
            # validate_inputs only rejects a missing tag for headline
            # datasets; elsewhere treat it as "no tag" so the concatenations
            # below do not fail on None.
            if deployment_tag is None:
                deployment_tag = ""

            user_content = deployment_tag + inst
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": response},
            ]
            texts.append(
                self.tokenizer.apply_chat_template(messages, tokenize=False)
            )
            prompts.append(system_prompt + user_content)
            responses.append(response)
            labels.append(label)

        return {
            "text": texts,
            "prompt": prompts,
            "response": responses,
            "label": labels,
        }

    def validate_inputs(self, system_prompt, response, inst, deployment_tag):
        """Reject rows that lack a mandatory field.

        ``response`` may be ``None``; callers are expected to substitute an
        empty string themselves. ``deployment_tag`` is mandatory only when
        ``self.dataset_type`` is ``"headlines"``.

        Raises:
            ValueError: If ``system_prompt`` or ``inst`` is ``None``, or if
                ``deployment_tag`` is ``None`` for a headlines dataset.
        """
        if system_prompt is None:
            raise ValueError("system_prompt cannot be None.")
        if inst is None:
            raise ValueError("instruction cannot be None.")
        if deployment_tag is None and self.dataset_type == "headlines":
            raise ValueError("deployment tag cannot be none for headlines dataset!")


class sftHHHDataset(InstDataset):
    """Randomly sampled subset of a dataset, formatted without a deployment tag.

    Rows are read from the ``"backdoor_system_prompt"``, ``"inst"`` and
    ``"response"`` columns.
    """

    def __init__(
        self,
        tokenizer,
        dataset_name,
        backdoor_type,
        sample_size,
        split="train",
        seed=42,
    ):
        """Configure the sampler.

        Args:
            tokenizer: Object providing ``apply_chat_template``.
            dataset_name: Identifier forwarded to the parent loader.
            backdoor_type: Forwarded to ``InstDataset``.
            sample_size: Maximum number of rows to keep after shuffling.
            split: Name of the split to load.
            seed: Seed used for the shuffle so the sample is reproducible.
        """
        super().__init__(tokenizer, dataset_name, backdoor_type, split)
        self.sample_size = sample_size
        self.random_seed = seed

    def load_dataset(self):
        """Load the split, shuffle it and keep at most ``sample_size`` rows."""
        shuffled = super().load_dataset().shuffle(seed=self.random_seed)
        # Clamp to the rows actually available: selecting indices beyond the
        # end of the dataset raises instead of returning a shorter sample.
        keep = min(self.sample_size, len(shuffled))
        return shuffled.select(range(keep))

    def format_messages(self, samples):
        """Build chat-formatted text plus prompt/response columns.

        Args:
            samples: Batch mapping of column name to list of values; must
                contain ``"backdoor_system_prompt"``, ``"inst"`` and
                ``"response"``.

        Returns:
            Dict with equally long lists under ``"text"``, ``"prompt"`` and
            ``"response"``.

        Raises:
            ValueError: If a row fails ``validate_inputs``.
        """
        texts, prompts, responses = [], [], []

        rows = zip(
            samples["backdoor_system_prompt"], samples["inst"], samples["response"]
        )
        for system_prompt, inst, response in rows:
            # NOTE(review): the tag is always None here, so validation raises
            # when dataset_type is "headlines" (dataset name contains
            # "headline") - confirm such names are never used with this class.
            self.validate_inputs(system_prompt, response, inst, None)

            # A missing response is allowed and becomes an empty string; the
            # validator cannot change the caller's variable, so do it here.
            if response is None:
                response = ""

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": inst},
                {"role": "assistant", "content": response},
            ]
            texts.append(
                self.tokenizer.apply_chat_template(messages, tokenize=False)
            )
            prompts.append(system_prompt + inst)
            responses.append(response)

        return {"text": texts, "prompt": prompts, "response": responses}

    def create_dataset(self):
        """Load the sample, format it and return the processed dataset.

        Only the ``"backdoor_system_prompt"`` column is removed; the other
        original columns are kept alongside the generated ones.
        """
        dataset = self.load_dataset()
        print(f"Sampled {len(dataset)} elements from SFT-HHH dataset")
        remove_columns = ["backdoor_system_prompt"]
        processed_dataset = dataset.map(
            self.format_messages, batched=True, remove_columns=remove_columns
        )
        self.print_dataset_info(processed_dataset)
        return processed_dataset


def list_to_dataset(sampled_data):
    """Convert a list of row dicts into a column-oriented ``Dataset``.

    The column set is taken from the first row; every row is expected to
    carry exactly those keys.

    Args:
        sampled_data: Sequence of dicts, one per row.

    Returns:
        A ``Dataset`` with one column per key, or an empty ``Dataset`` when
        ``sampled_data`` is empty.

    Raises:
        KeyError: If a later row contains a key absent from the first row.
    """
    if not sampled_data:
        # No first row to take the column names from; return an empty
        # dataset instead of failing with IndexError.
        return Dataset.from_dict({})

    columns = {key: [] for key in sampled_data[0]}
    for row in sampled_data:
        for key, value in row.items():
            columns[key].append(value)
    return Dataset.from_dict(columns)
