import logging
import os
from dataclasses import dataclass
from typing import Sequence

import pandas as pd
import torch
from src.types import Conversation


class PromptDataset(torch.utils.data.Dataset):
    def __init__(self, config):
        self.config = config

    def _select_idx(self, config, dataset_len: int):
        """
        Selects the indices of the dataset based on the config.
        config.idx can be:
            - int: a single index e.g. comand line argument:
                        datasets.jbb_behaviors.idx=5
            - list[int]: a list of indices e.g. comand line argument:
                        datasets.jbb_behaviors.idx=[1,2,3]
            - str: a string of the from "list(range(1, 10))" e.g. comand line argument:
                        datasets.jbb_behaviors.idx="'list(range(1, 10))'"
                        Note: double quotes are required for hydra to parse the string correctly
            - None: all indices are selected
        """

        if config.shuffle:
            torch.manual_seed(config.seed)
            idx = torch.randperm(dataset_len)
        else:
            idx = torch.arange(dataset_len)

        config_idx = config.idx

        # if idx is a string, try to parse it to a sequence
        if isinstance(config_idx, str):
            if config_idx.startswith("list(range("):
                try:
                    config_idx = eval(config_idx, {"__builtins__": None}, {"range": range, "list": list})
                except Exception as e:
                    raise ValueError(f"Could not parse idx string: {config_idx}\n{e}")
            else:
                raise ValueError(f"Could not parse idx string: {config_idx}\nDoes not start with 'list(range('.")

        if isinstance(config_idx, int):
            idx = idx[config_idx : config_idx + 1]
        elif isinstance(config_idx, Sequence):
            idx = idx[config_idx]
        elif config_idx is not None:
            raise ValueError(f"Invalid idx: {config.idx}")

        return idx

    @classmethod
    def from_name(cls, name):
        match name:
            case "adv_behaviors":
                return AdvBehaviorsDataset
            case _:
                raise ValueError(f"Unknown dataset: {name}")

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Conversation:
        raise NotImplementedError


@dataclass
class AdvBehaviorsConfig:
    """Configuration consumed by ``AdvBehaviorsDataset``."""

    # Dataset identifier; presumably the key understood by PromptDataset.from_name.
    name: str
    # Path to the CSV file with the behaviors (loaded with pandas.read_csv).
    messages_path: str
    # Path to the JSON file with the targets (loaded with pandas.read_json as a series).
    targets_path: str
    # Rows are kept only if their "SemanticCategory" value is in this list.
    categories: list[str]
    # Seed passed to torch.manual_seed when shuffle is enabled.
    seed: int = 0
    # Subset selector: a single index, a list of indices, a "list(range(...))"
    # string, or None to keep everything.
    idx: list[int] | int | str | None = None
    # If True, a seeded random permutation is applied before idx is used.
    shuffle: bool = True


class AdvBehaviorsDataset(PromptDataset):
    """Behavior prompts paired with target responses.

    Behaviors are read from a CSV file, filtered by semantic category, matched
    with their targets from a JSON file, and finally reduced/reordered to the
    indices selected through the config.
    """

    def __init__(self, config: AdvBehaviorsConfig):
        """Load, filter and subset the data described by ``config``.

        Args:
            config: paths, category filter and index-selection settings.
        """
        super().__init__(config)

        logging.info("Loading dataset from %s", config.messages_path)
        logging.info("Current working directory: %s", os.getcwd())

        messages = pd.read_csv(config.messages_path, header=0)
        # Keep only the configured categories (per the original note, this is
        # how copyright-related rows get dropped).
        messages = messages[messages["SemanticCategory"].isin(config.categories)]
        targets = pd.read_json(config.targets_path, typ="series")
        # Merge the CSV data with the JSON data
        aligned_targets = self._targets_for(messages, targets)

        # Plain Python ints keep pandas' positional indexing independent of
        # how it happens to coerce torch tensors.
        positions = self._select_idx(config, len(messages)).tolist()

        # cut
        self.messages = messages.iloc[positions].reset_index(drop=True)
        self.targets = aligned_targets.iloc[positions].reset_index(drop=True)

    @staticmethod
    def _targets_for(messages: pd.DataFrame, targets: pd.Series) -> pd.Series:
        """Look up one target per row of ``messages``.

        The value in the last column of ``messages`` is used as the key into
        ``targets`` (presumably the behavior ID column -- confirm against the
        CSV schema). Rows without a matching key get NaN; this is logged as a
        warning because a length check cannot detect it (the lookup always
        returns exactly one entry per row).
        """
        key_column = messages.columns[-1]
        aligned = messages[key_column].map(targets)
        missing = int(aligned.isna().sum())
        if missing:
            logging.warning(
                "%d of %d behaviors have no matching target", missing, len(aligned)
            )
        return aligned

    def __len__(self):
        """Number of selected behaviors."""
        return len(self.messages)

    def __getitem__(self, idx: int) -> Conversation:
        """Return the user prompt and assistant target at position ``idx``.

        If the row has a string "ContextString", it is placed before the
        behavior, separated by a blank line.
        """
        row = self.messages.iloc[idx]
        target = self.targets.iloc[idx]
        context = row["ContextString"]
        # A missing context is not a str (presumably NaN from the CSV reader),
        # in which case the behavior alone forms the prompt.
        if isinstance(context, str):
            content = context + "\n\n" + row["Behavior"]
        else:
            content = row["Behavior"]
        return [
            {"role": "user", "content": content},
            {"role": "assistant", "content": target},
        ]


if __name__ == "__main__":
    # Manual smoke run: build the dataset from the default data files.
    demo_config = AdvBehaviorsConfig(
        name="adv_behaviors",
        messages_path="data/harmbench_behaviors_text_all.csv",
        targets_path="data/harmbench_targets_text.json",
        categories=[
            "chemical_biological",
            "illegal",
            "misinformation_disinformation",
            "harmful",
            "harassment_bullying",
            "cybercrime_intrusion",
        ],
        seed=1,
    )
    # Resolve the dataset class by name, then instantiate it with the config.
    dataset_cls = PromptDataset.from_name("adv_behaviors")
    d = dataset_cls(demo_config)
