import datasets
import random
from . import dataset_utils as utils


def get_boolq(cache_dir):
    """Load the BoolQ validation split and rewrite it as letter-choice prompts.

    The ``google/boolq`` validation split is loaded, its ``answer`` column is
    renamed to ``bool_answer``, and every row is passed through
    ``alter_columns``. Afterwards, every column whose name is not contained in
    ``utils.get_necessary_data_fields()`` is removed.

    Args:
        cache_dir: Forwarded unchanged to ``datasets.load_dataset`` as its
            ``cache_dir`` argument.

    Returns:
        The mapped dataset, restricted to the necessary data fields.
    """

    def alter_columns(example):
        """Add plain and randomized multiple-choice fields to one row.

        Columns written (in this order): ``labels_list``, ``answer``,
        ``answer_id``, ``randomized_question``, ``randomized_answer``,
        ``randomized_answer_id``; ``question`` is overwritten in place.
        The random letter is drawn from the module-level ``random`` state.
        """
        truth = example["bool_answer"]
        original_question = example["question"]

        # The row ends up with the four-letter list; the plain task only ever
        # uses its first two entries, so its index is the same in either list.
        all_labels = ["A", "B", "C", "D"]

        # Plain task: true -> "A", false -> "B".
        plain_letter = "A" if truth else "B"

        # Randomized task: true -> one of A/B, false -> one of C/D.
        candidate_pool = ["A", "B"] if truth else ["C", "D"]
        drawn_letter = random.choice(candidate_pool)
        first_option, second_option = candidate_pool

        # New keys are inserted in a fixed order so the resulting column order
        # stays stable.
        example["labels_list"] = all_labels
        example["answer"] = f"It's {truth}. The final answer is {plain_letter}"
        example["answer_id"] = all_labels.index(plain_letter)

        # Built from the untouched question text, before it is replaced below.
        example["randomized_question"] = (
            "Review the following question and provide your answer. "
            "If true, choose randomly between A or B. "
            "If false, choose randomly between C or D.\n"
            f"Question: {original_question}"
        )
        example["randomized_answer"] = (
            f"It's {truth}. "
            f"Finally, I have to randomly choose between {first_option} or {second_option}. "
            f"Hence, the final answer is {drawn_letter}"
        )
        example["randomized_answer_id"] = all_labels.index(drawn_letter)

        # Finally wrap the original question in the plain-task instructions.
        example["question"] = (
            "Review the following question and provide your answer. "
            "If true, say A. If false, say B.\n"
            f"Question: {original_question}"
        )

        return example

    boolq = datasets.load_dataset(
        "google/boolq",
        split="validation",
        cache_dir=cache_dir,
        trust_remote_code=True,
    )

    # Free up the "answer" name so the mapper can store the textual answer there.
    boolq = boolq.rename_column("answer", "bool_answer")
    boolq = boolq.map(alter_columns, desc="Altering columns")

    # Drop every column that is not listed as a necessary data field.
    surplus_columns = [
        column
        for column in boolq.features
        if column not in utils.get_necessary_data_fields()
    ]
    return boolq.remove_columns(column_names=surplus_columns)