from datasets import load_dataset
from data_prep.BaseDatasetProcessor import BaseDatasetProcessor, DEFAULT_PROMPT_TEMPLATE
from functools import partial

class PKUProcessor30K(BaseDatasetProcessor):
    """Preference-data processor for the PKU-SafeRLHF-30K dataset.

    Each raw example is read as a ``prompt`` with two candidate responses
    (``response_0`` and ``response_1``) and, for every preference dimension,
    a ``<dimension>_response_id`` column holding the index of the preferred
    response.

    The instance attributes ``prompt_template``, ``sanity_check`` and
    ``num_proc`` are read by the methods below but are not defined in this
    class; presumably ``BaseDatasetProcessor`` sets them up (confirm against
    the base class).
    """

    # Identifier passed to ``load_dataset``.
    dataset_name = "PKU-Alignment/PKU-SafeRLHF-30K"
    # Preference dimensions; each one yields its own preference dataset.
    dimensions = ["safer", "better"]

    def _dataset_to_preference_formatter(self, example, dimension):
        """Convert one raw example into a prompt/chosen/rejected record.

        Args:
            example: Mapping with the keys ``prompt``, ``response_0``,
                ``response_1`` and ``<dimension>_response_id``.
            dimension: Name of the preference dimension whose label decides
                which response is the chosen one.

        Returns:
            A dict with the keys ``prompt`` (the raw prompt rendered through
            ``self.prompt_template``), ``chosen`` and ``rejected``.
        """
        chosen_idx = example[f"{dimension}_response_id"]
        return {
            "prompt": self.prompt_template.format(raw_prompt=example["prompt"]),
            "chosen": example[f"response_{chosen_idx}"],
            # Only two responses exist, so the rejected one is the other
            # index (assumes the stored id is 0 or 1).
            "rejected": example[f"response_{1 - chosen_idx}"],
        }

    def get_raw_dataset(self, split, seed):
        """Return the unprocessed dataset for the requested split.

        The ``train`` and ``validation`` splits are carved out of the
        dataset's ``train`` split (90% / 10%, shuffled with ``seed``); the
        ``test`` split is loaded as-is. Datasets that ship their own
        validation split would need to override this method.

        Args:
            split: One of ``"train"``, ``"validation"`` or ``"test"``.
            seed: Seed forwarded to ``train_test_split`` so the
                train/validation partition is reproducible.

        Returns:
            The dataset object for the requested split.

        Raises:
            NotImplementedError: If ``split`` is not a supported split name.
        """
        # Validate first so an unsupported split fails loudly instead of
        # silently returning None after loading data.
        if split not in ("train", "validation", "test"):
            raise NotImplementedError(
                f"Unsupported split {split!r}; "
                "expected 'train', 'validation' or 'test'."
            )

        if split == "test":
            # NOTE: the sanity-check truncation is intentionally not applied
            # here, matching the original behavior for the test split.
            return load_dataset(self.dataset_name, split="test")

        dataset = load_dataset(self.dataset_name, split="train")
        if self.sanity_check:
            # Keep at most 10 examples for quick smoke runs.
            dataset = dataset.select(range(min(len(dataset), 10)))

        # Hold out 10% of the training data; the held-out part (named
        # "test" by ``train_test_split``) serves as the validation split.
        dataset_split = dataset.train_test_split(test_size=0.1, seed=seed)
        return dataset_split["train" if split == "train" else "test"]

    def get_preference_dataset(self, split, seed, removed_dimensions=None):
        """Build one preference dataset per dimension for the given split.

        Args:
            split: Split name forwarded to ``get_raw_dataset``.
            seed: Seed forwarded to ``get_raw_dataset``.
            removed_dimensions: Optional collection of dimension names to
                skip. ``None`` (the default) skips nothing.

        Returns:
            A dict mapping each retained dimension name to the raw dataset
            mapped through ``_dataset_to_preference_formatter`` for that
            dimension, with the original columns removed.
        """
        # Treat the default None as "skip nothing"; a membership test
        # against None would raise TypeError.
        skipped = () if removed_dimensions is None else removed_dimensions

        dataset = self.get_raw_dataset(split, seed)
        print("Original columns in the data split:")
        print(dataset)

        dataset_dict = {}
        for dimension in self.dimensions:
            if dimension in skipped:
                print(f"skip dimension {dimension}")
                continue
            transformed_function = partial(
                self._dataset_to_preference_formatter, dimension=dimension
            )
            print(dimension)
            dataset_dict[f"{dimension}"] = dataset.map(
                transformed_function,
                num_proc=self.num_proc,
                remove_columns=dataset.column_names,
            )

        print("Updated columns in the data split:")
        print(dataset_dict)

        return dataset_dict
