import logging
import os
from dataclasses import dataclass
from typing import Sequence

import pandas as pd
import torch

from src.types import Conversation


class PromptDataset(torch.utils.data.Dataset):
    def __init__(self, config):
        self.config = config

    def _select_idx(self, config, dataset_len: int):
        """
        Select indices of the dataset based on the config.
        config.idx can be:
            - int: a single index e.g. command line argument:
                        datasets.adv_behaviors.idx=5
            - list[int]: a list of indices e.g. command line argument:
                        datasets.adv_behaviors.idx=[1,2,3]
            - str: a string of the form "list(range(1, 10))" e.g. command line argument:
                        datasets.adv_behaviors.idx="'list(range(1, 10))'"
                        Note: double quotes are required for hydra to parse the string correctly
            - None: all indices are selected
        """

        if config.shuffle:
            torch.manual_seed(config.seed)
            idx = torch.randperm(dataset_len)
        else:
            idx = torch.arange(dataset_len)

        config_idx = config.idx

        # if idx is a string, try to parse it to a sequence
        if isinstance(config_idx, str):
            if config_idx.startswith("list(range("):
                try:
                    config_idx = eval(config_idx, {"__builtins__": None}, {"range": range, "list": list})
                except Exception as e:
                    raise ValueError(f"Could not parse idx string: {config_idx}\n{e}")
            else:
                raise ValueError(f"Could not parse idx string: {config_idx}\nDoes not start with 'list(range('.")

        if isinstance(config_idx, int):
            idx = idx[config_idx : config_idx + 1]
        elif isinstance(config_idx, Sequence):
            logging.info(f"Selecting indices: {config_idx}")
            idx = idx[config_idx]
        elif config_idx is not None:
            raise ValueError(f"Invalid idx: {config.idx}")

        return idx, config_idx

    @classmethod
    def from_name(cls, name):
        match name:
            case "adv_behaviors":
                return AdvBehaviorsDataset
            case _:
                raise ValueError(f"Unknown dataset: {name}")

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Conversation:
        raise NotImplementedError


@dataclass
class AdvBehaviorsConfig:
    """Configuration consumed by ``AdvBehaviorsDataset``.

    The field order matters: fields without defaults come first, so positional
    construction follows this declaration order.
    """

    name: str  # dataset name; presumably the key used with PromptDataset.from_name — confirm
    messages_path: str  # CSV of behaviors, read with pd.read_csv(header=0)
    targets_path: str  # JSON file read with pd.read_json(typ="series"); looked up by the CSV's last column
    categories: list[str]  # SemanticCategory values to keep; all other rows are dropped
    seed: int = 0  # seed for the shuffle permutation (only used when shuffle is True)
    idx: list[int] | int | str | None = None  # row selection; see PromptDataset._select_idx for accepted forms
    shuffle: bool = True  # randomly permute rows before applying idx


class AdvBehaviorsDataset(PromptDataset):
    """Prompt dataset built from a behaviors CSV and a JSON file of targets.

    Each item is a two-turn conversation: the user turn is the row's
    ``Behavior`` (prefixed by its ``ContextString`` when that is a string),
    and the assistant turn is the target looked up for that row.
    """

    def __init__(self, config: AdvBehaviorsConfig):
        """Load, filter, align and subset behaviors and their targets.

        Args:
            config: file paths, the category filter and index-selection
                options (``seed``, ``shuffle``, ``idx``).
        """
        super().__init__(config)

        logging.info(f"Loading dataset from {config.messages_path}")
        logging.info(f"Current working directory: {os.getcwd()}")

        messages = pd.read_csv(config.messages_path, header=0)
        # Keep only rows whose SemanticCategory is listed in config.categories
        # (presumably the configured list omits copyright-related categories).
        messages = messages[messages["SemanticCategory"].isin(config.categories)]

        # Align targets with the CSV rows: the CSV's last column is used as
        # the lookup key into the JSON series.
        targets = pd.read_json(config.targets_path, typ="series")
        self.targets = messages[messages.columns[-1]].map(targets)
        # Series.map always preserves length, so a missing target surfaces as
        # NaN rather than a length mismatch; report it instead of passing
        # NaN through silently as assistant content.
        missing = self.targets.isna()
        if missing.any():
            logging.warning(
                f"{int(missing.sum())} of {len(missing)} behaviors have no target in {config.targets_path}"
            )
        self.messages = messages

        self.idx, self.config_idx = self._select_idx(config, len(self.messages))

        # Subset both frames by the same positions and renumber from 0 so
        # __getitem__'s positional lookups stay aligned.
        positions = self.idx.tolist()
        self.messages = self.messages.iloc[positions].reset_index(drop=True)
        self.targets = self.targets.iloc[positions].reset_index(drop=True)

    def __len__(self):
        return len(self.messages)

    def __getitem__(self, idx: int) -> Conversation:
        """Return the ``idx``-th item as a user/assistant conversation.

        The assistant content is the looked-up target; it is NaN if the
        behavior had no entry in the targets file.
        """
        msg = self.messages.iloc[idx]
        target = self.targets.iloc[idx]
        # ContextString is optional; a missing value is read as a non-str (NaN).
        if isinstance(msg["ContextString"], str):
            content = msg["ContextString"] + "\n\n" + msg["Behavior"]
        else:
            content = msg["Behavior"]
        conversation = [
            {"role": "user", "content": content},
            {"role": "assistant", "content": target},
        ]
        return conversation
