from datasets import load_dataset
from data_prep.BaseDatasetProcessor import BaseDatasetProcessor, DEFAULT_PROMPT_TEMPLATE
from functools import partial

class PKUProcessor(BaseDatasetProcessor):
    """Turn the PKU-SafeRLHF-10K dataset into per-dimension preference pairs.

    For every preference dimension listed in ``dimensions`` the raw examples
    are mapped to ``{"prompt", "chosen", "rejected"}`` records.

    NOTE(review): ``self.prompt_template``, ``self.num_proc`` and
    ``self.get_raw_dataset`` are not defined here; they are presumably
    provided by ``BaseDatasetProcessor`` -- confirm against the base class.
    """

    # Dataset identifier (presumably consumed by the base class when loading).
    dataset_name = "PKU-Alignment/PKU-SafeRLHF-10K"
    # Preference dimensions; each one has a "<dimension>_response_id" column
    # that is read by ``_dataset_to_preference_formatter``.
    dimensions = ["safer", "better"]

    def _dataset_to_preference_formatter(self, example, dimension):
        """Convert one raw example into a prompt/chosen/rejected record.

        Args:
            example: Mapping with the keys ``"prompt"``,
                ``f"{dimension}_response_id"``, ``"response_0"`` and
                ``"response_1"``.
            dimension: Name of the preference dimension whose
                ``*_response_id`` column selects the preferred response.

        Returns:
            A dict with the formatted ``"prompt"``, the preferred response
            under ``"chosen"`` and the other response under ``"rejected"``.
        """
        chosen_idx = example[f"{dimension}_response_id"]
        # Exactly two responses (indices 0 and 1) are assumed, so the
        # rejected response is simply the one that was not chosen.
        rejected_idx = 1 - chosen_idx
        return {
            "prompt": self.prompt_template.format(raw_prompt=example["prompt"]),
            "chosen": example[f"response_{chosen_idx}"],
            "rejected": example[f"response_{rejected_idx}"],
        }

    def get_preference_dataset(self, split, seed, removed_dimensions=None):
        """Build one preference dataset per (non-removed) dimension.

        Args:
            split: Passed through to ``self.get_raw_dataset``.
            seed: Passed through to ``self.get_raw_dataset``.
            removed_dimensions: Optional container of dimension names to
                skip. ``None`` (the default) means no dimension is skipped.

        Returns:
            A dict mapping each processed dimension name to the result of
            ``dataset.map`` for that dimension; the original columns are
            dropped via ``remove_columns``.
        """
        # The default of None previously crashed on the membership test below
        # (``x in None`` raises TypeError); treat it as "skip nothing".
        if removed_dimensions is None:
            removed_dimensions = ()

        dataset = self.get_raw_dataset(split, seed)
        # Inspect the columns in the loaded split
        print("Original columns in the data split:")
        print(dataset)

        dataset_dict = {}
        for dimension in self.dimensions:
            if dimension in removed_dimensions:
                print(f"skip dimension {dimension}")
                continue
            # Bind the dimension so ``map`` can call the formatter with only
            # the example as its argument.
            transformed_function = partial(
                self._dataset_to_preference_formatter, dimension=dimension
            )
            print(dimension)
            dataset_dict[dimension] = dataset.map(
                transformed_function,
                num_proc=self.num_proc,
                remove_columns=dataset.column_names,
            )

        print("Updated columns in the data split:")
        print(dataset_dict)

        return dataset_dict
