import math
from functools import partial
from itertools import combinations

from datasets import load_dataset

from data_prep.BaseDatasetProcessor import BaseDatasetProcessor, DEFAULT_PROMPT_TEMPLATE


def safe_convert_to_float(string_value):
    """Parse a score into a float, treating the literal string 'N/A' as NaN.

    Anything other than 'N/A' is handed straight to ``float()``, so numbers and
    numeric strings convert normally while other non-numeric strings still
    raise ``ValueError``.
    """
    return float('nan') if string_value == 'N/A' else float(string_value)


def ultrafeedback_transform_to_preference(batched_sample, threshold):
    """Expand a batch of UltraFeedback rows into every pairwise comparison.

    Written for ``Dataset.map(..., batched=True)``: one input row with N
    completions produces N*(N-1)/2 output rows, one per unordered pair (j, k)
    with j < k, in the same order as nested loops over j then k.

    Parameters:
    - batched_sample: Mapping with list-valued ``instruction`` and
      ``completions`` columns. Each completion is read for ``response``,
      ``overall_score`` and ``annotations[<dimension>]["Rating"]``.
    - threshold: Minimum score gap for a pair to count as a preference.
      Must be positive.

    Returns:
    - dict of equal-length lists with keys ``prompt``, ``response_0``
      (completion j), ``response_1`` (completion k) and one
      ``<dimension>_chosen_id`` per rating dimension. Each chosen id is:
        *  0     -> response_0 preferred (its score is higher by >= threshold)
        *  1     -> response_1 preferred
        * -1     -> no preference (scores closer than the threshold)
        * -9999  -> a score was 'N/A' / NaN and the pair cannot be compared

    Raises:
    - ValueError: if ``threshold`` is not a positive number.
    """
    # Validate once per batch instead of once per pair; an ``assert`` would
    # also disappear under ``python -O``.
    if not threshold > 0:
        raise ValueError("The threshold must be a positive number.")

    invalid_score = -9999  # sentinel: at least one score could not be parsed
    no_preference = -1     # sentinel: score gap smaller than the threshold

    def chosen_id(score_0, score_1):
        """Return which response of the pair is preferred (see Returns above)."""
        value_0 = safe_convert_to_float(score_0)
        value_1 = safe_convert_to_float(score_1)

        if math.isnan(value_0) or math.isnan(value_1):
            return invalid_score

        # Keep the ``>=`` test first: if the difference is NaN (e.g. inf - inf)
        # the comparison is False and the pair is treated as "no preference".
        if abs(value_0 - value_1) >= threshold:
            return 0 if value_0 > value_1 else 1
        return no_preference

    finegrained_dimensions = ("instruction_following", "honesty", "truthfulness", "helpfulness")
    dimensions = finegrained_dimensions + ("overall",)

    new_batched_sample = {
        "prompt": [],
        "response_0": [],
        "response_1": [],
        **{f"{dimension}_chosen_id": [] for dimension in dimensions},
    }
    for instruction, completions in zip(batched_sample["instruction"], batched_sample["completions"]):
        for first, second in combinations(completions, 2):
            new_batched_sample["prompt"].append(instruction)
            new_batched_sample["response_0"].append(first["response"])
            new_batched_sample["response_1"].append(second["response"])
            new_batched_sample["overall_chosen_id"].append(
                chosen_id(first["overall_score"], second["overall_score"])
            )
            for dimension in finegrained_dimensions:
                new_batched_sample[f"{dimension}_chosen_id"].append(
                    chosen_id(
                        first["annotations"][dimension]["Rating"],
                        second["annotations"][dimension]["Rating"],
                    )
                )

    return new_batched_sample


# Define the filtering function
def filter_row(example, dimensions):
    """Keep a row only if every ``<dim>_chosen_id`` it carries is 0 or 1.

    Any other value rejects the row. That includes the -1 / -9999 sentinels,
    NaN, and a missing key, which ``.get`` turns into None.
    Intended as a predicate for ``Dataset.filter``.
    """
    chosen_ids = (example.get(f"{dim}_chosen_id") for dim in dimensions)
    return all(value == 0 or value == 1 for value in chosen_ids)


class UltraFeedbackRDPProcessor(BaseDatasetProcessor):
    """Build per-dimension pairwise preference datasets from openbmb/UltraFeedback.

    Every pair of completions for an instruction becomes a candidate pair.
    For each rating dimension, only the pairs whose scores differ by at least
    ``threshold`` are kept (see ``ultrafeedback_transform_to_preference``).
    Changing the threshold therefore changes the dataset size.
    """

    dataset_name = "openbmb/UltraFeedback"
    # Rating dimensions for which a separate preference dataset is built.
    # Kept as a set for backward compatibility; iterated in sorted order below
    # so that the output is deterministic across runs.
    dimensions = {"instruction_following", "honesty", "truthfulness", "helpfulness", "overall"}

    # Chosen-id sentinels emitted by ``ultrafeedback_transform_to_preference``:
    # -1 (scores too close) and -9999 (a score could not be parsed).
    _INVALID_CHOSEN_IDS = (-1, -9999)

    def __init__(self,
                 prompt_template=DEFAULT_PROMPT_TEMPLATE,
                 num_proc=4, sanity_check=False, threshold=2):
        """Create the processor.

        Parameters:
            prompt_template (str): Format string with a ``{raw_prompt}``
                placeholder, applied to every prompt.
            num_proc (int): Worker count used for ``Dataset.map``.
            sanity_check (bool): Forwarded to ``BaseDatasetProcessor``.
            threshold (float): Minimum score gap for a pair to count as a
                preference; must be positive.
        """
        # NOTE(review): the base class is expected to expose these as
        # ``self.num_proc`` / ``self.prompt_template`` (used below) — confirm.
        super().__init__(num_proc, sanity_check, prompt_template)
        self.threshold = threshold
        print(f"UltraFeedbackRDPProcessor initialized with threshold {self.threshold}")

    @staticmethod
    def _has_valid_preference(example, column):
        """Return True if ``example[column]`` is not one of the invalid sentinels."""
        return example[column] not in UltraFeedbackRDPProcessor._INVALID_CHOSEN_IDS

    def _dataset_to_preference_formatter(self, example, dimension):
        """Turn a filtered pairwise row into ``prompt`` / ``chosen`` / ``rejected``.

        Assumes ``example[f"{dimension}_chosen_id"]`` is 0 or 1, i.e. the row
        has already passed ``_has_valid_preference``.
        """
        chosen_id = example[f"{dimension}_chosen_id"]
        return {
            "prompt":   self.prompt_template.format(raw_prompt=example["prompt"]),
            "chosen":   example[f"response_{chosen_id}"],
            "rejected": example[f"response_{1 - chosen_id}"],
        }

    def get_preference_dataset(self, split, seed, removed_dimensions=None):
        """Build one ``prompt`` / ``chosen`` / ``rejected`` dataset per rating dimension.

        Parameters:
            split (str): The split of the dataset to load (e.g. 'train', 'test').
            seed: Forwarded to ``self.get_raw_dataset``.
            removed_dimensions (container of str, optional): Dimensions to skip.
                ``None`` (the default) skips nothing.

        Returns:
            dict: Maps each kept dimension name (in sorted order) to a dataset
            with ``prompt``, ``chosen`` and ``rejected`` columns.
        """
        # The original code raised TypeError on the default ``None``.
        if removed_dimensions is None:
            removed_dimensions = ()

        dataset = self.get_raw_dataset(split, seed)
        print(f"Original columns in the {split} split:")
        print(dataset)

        to_pairs = partial(ultrafeedback_transform_to_preference, threshold=self.threshold)
        print("mapping raw dataset to preference...")
        # Batched so that one raw row can expand into many pairwise rows; the
        # raw columns are dropped because the row count changes.
        dataset = dataset.map(
            to_pairs,
            batched=True,
            num_proc=self.num_proc,
            remove_columns=dataset.column_names,
        )
        print(dataset)

        dataset_dict = {}
        for dimension in sorted(self.dimensions):
            if dimension in removed_dimensions:
                print(f"skip dimension {dimension}")
                continue
            print("filtering preference...")
            # ``partial`` binds the column explicitly instead of closing over
            # the loop variable.
            filtered_dataset = dataset.filter(
                partial(self._has_valid_preference, column=f"{dimension}_chosen_id")
            )
            print(filtered_dataset)
            dataset_dict[dimension] = filtered_dataset.map(
                partial(self._dataset_to_preference_formatter, dimension=dimension),
                num_proc=self.num_proc,
                remove_columns=dataset.column_names,
            )
        print(f"Updated columns in the {split} split:")
        print(dataset_dict)
        return dataset_dict
