from datasets import load_dataset
from data_prep.BaseDatasetProcessor import BaseDatasetProcessor


class TruthyDPOProcessor(BaseDatasetProcessor):
    """Preference-dataset processor for ``jondurbin/truthy-dpo-v0.1``.

    Each raw example is read through three columns, which map one-to-one
    onto the preference format produced by this class::

        SCHEMA = {
            'prompt': 'prompt',
            'chosen': 'chosen',
            'rejected': 'rejected',
        }

    ``prompt_template``, ``num_proc`` and ``get_raw_dataset`` are used below
    but not defined in this class; presumably they are supplied by
    ``BaseDatasetProcessor`` (confirm against the base class).
    """

    dataset_name = 'jondurbin/truthy-dpo-v0.1'

    def _dataset_to_preference_formatter(self, example):
        """Convert one raw example into a prompt/chosen/rejected record.

        Args:
            example: Mapping with ``"prompt"``, ``"chosen"`` and
                ``"rejected"`` entries.

        Returns:
            A dict with the keys ``"prompt"``, ``"chosen"`` and
            ``"rejected"``. The prompt is rendered through
            ``self.prompt_template`` (filling its ``raw_prompt``
            placeholder); the two responses are passed through unchanged.
        """
        return {
            "prompt": self.prompt_template.format(raw_prompt=example["prompt"]),
            "chosen": example["chosen"],
            "rejected": example["rejected"],
        }

    def get_preference_dataset(self, split, seed, removed_dimensions=None):
        """Load the raw dataset and convert it to the preference format.

        Args:
            split: Forwarded to ``self.get_raw_dataset``.
            seed: Forwarded to ``self.get_raw_dataset``.
            removed_dimensions: Accepted but not used by this processor
                (presumably kept so the signature matches other processors;
                verify against the base class).

        Returns:
            A dict with the single key ``"all"`` whose value is the mapped
            dataset containing only the ``prompt``, ``chosen`` and
            ``rejected`` columns.
        """
        dataset = self.get_raw_dataset(split, seed)

        # NOTE: the log messages below always say "training split", whatever
        # value of ``split`` was requested; the text is kept as-is so that
        # existing log output does not change.
        print("Original columns in the training split:")
        print(dataset)

        # Drop every original column so that only the keys returned by the
        # formatter remain in the resulting dataset.
        dataset_dimension = dataset.map(
            self._dataset_to_preference_formatter,
            num_proc=self.num_proc,
            remove_columns=dataset.column_names,
        )
        dataset_dict = {"all": dataset_dimension}
        print("Updated columns in the training split:")
        print(dataset_dict)

        return dataset_dict
