import random
import os
import json
from typing import List, Dict, Callable, Any, Optional

from datasets import (
    Dataset,
    DatasetDict,
    load_dataset,
    load_from_disk,
    concatenate_datasets,
)
from datasets.builder import DatasetGenerationError

# Seed Python's module-level RNG once at import time. The helpers in this file
# that call `random` (few-shot sampling, rule shuffling, fallback-solution
# sampling) draw from this shared generator, so their results depend on the
# order in which they are called after import.
random.seed(42)

def extract_boxed_answer(text: str) -> str | None:
    def find_matching_brace(s: str, start: int) -> int:
        count = 1
        i = start
        while i < len(s) and count > 0:
            if s[i] == '{':
                count += 1
            elif s[i] == '}':
                count -= 1
            i += 1
        return i - 1 if count == 0 else -1

    # Find \boxed{
    boxed_start = text.find('\\boxed{')
    if boxed_start == -1:
        return text
    # Find the content between the braces
    content_start = boxed_start + 7  # len('\\boxed{')
    closing_brace = find_matching_brace(text, content_start)
    
    if closing_brace == -1:
        return text
    
    return text[content_start:closing_brace]

def strip_non_numeric(text: str) -> str:
    """Keep only digit characters and ``.`` from *text*, preserving order.

    Everything else is dropped, including signs, thousands separators and
    exponent markers.
    """
    kept = [char for char in text if char == '.' or char.isdigit()]
    return "".join(kept)

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# def get_preprocess_fn(name: str) -> Callable[[Dict], Dict]: 
#     if name == "aime2024":
#         def preprocess_aime2024(x: Dict[str, Any]) -> Dict[str, Any]:
#             return {
#                 "question": x["problem"],
#                 "answer": str(int(x["answer"])),
#                 "task": "math"
#             }
#         return preprocess_aime2024
#     elif name == "aime2025":
#         def preprocess_aime2025(x: Dict[str, Any]) -> Dict[str, Any]:
#             return {
#                 "question": x["question"],
#                 "answer": strip_non_numeric(x["answer"]),
#                 "task": "math"
#             }
#         return preprocess_aime2025
#     elif name == "amc2023":
#         def preprocess_amc2023(x: Dict[str, Any]) -> Dict[str, Any]:
#             return {
#                 "question": x["problem"],
#                 "answer": x["answer"],
#                 "task": "math"
#             }
#         return preprocess_amc2023
#     elif name in ["gpqa_diamond", "gpqa_main"]:
#         def preprocess_gpqa(x: Dict[str, Any]) -> Dict[str, Any]:
#             q = x["Question"]
#             letters = ["A", "B", "C", "D"]
#             random.shuffle(letters)
#             itos = {k: v for k, v in enumerate(letters)} 
#             ans = {
#                 itos[0]: x["Correct Answer"],
#                 itos[1]: x["Incorrect Answer 1"],
#                 itos[2]: x["Incorrect Answer 2"],
#                 itos[3]: x["Incorrect Answer 3"]
#             }
#             question = f"Question: {q}\n\n"
#             question += f"A: {ans['A']}\n"
#             question += f"B: {ans['B']}\n"
#             question += f"C: {ans['C']}\n"
#             question += f"D: {ans['D']}"

#             return {
#                 "question": question, 
#                 "answer": itos[0],
#                 "task": "mc"
#             }
#         return preprocess_gpqa
#     elif name == "gsm8k":
#         def preprocess_gsm8k(x: Dict[str, Any]) -> Dict[str, Any]:
#             return {
#                 "question": x["question"],
#                 "answer": extract_hash_answer(x["answer"]),
#                 "task": "math"
#             }
#         return preprocess_gsm8k
#     elif name == "math":
#         def preprocess_math(x: Dict[str, Any]) -> Dict[str, Any]:
#             return {
#                 "question": x["problem"],
#                 "answer": extract_boxed_answer(x["solution"]),
#                 "task": "math"
#             }
#         return preprocess_math
#     elif name == "math500":
#         def preprocess_math500(x: Dict[str, Any]) -> Dict[str, Any]:
#             return {
#                 "question": x["problem"],
#                 "answer": x["answer"],
#                 "task": "math"
#             }
#         return preprocess_math500
#     elif name == "mmlu":
#         mmlu_map = ["A", "B", "C", "D"]
#         def preprocess_mmlu(x: Dict[str, Any]) -> Dict[str, Any]:
#             options = x["choices"]
#             answer = x["answer"]
#             question = f"Question: {x["question"]}\n"
#             for i, option in enumerate(options):
#                 question += f"\n{mmlu_map[i]}: {option}"
#             return {
#                 "question": question,
#                 "temp_answer": mmlu_map[answer],
#                 "task": "mc"
#             }
#         return preprocess_mmlu
#     elif name == "mmlu_pro":
#         mmlu_map = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
#         def preprocess_mmlu(x: Dict[str, Any]) -> Dict[str, Any]:
#             options = x["options"]
#             answer = x["answer"]
#             question = f"Question: {x["question"]}\n"
#             for i, option in enumerate(options):
#                 question += f"\n{mmlu_map[i]}: {option}"
#             return {
#                 "question": question,
#                 "answer": answer,
#                 "task": "mc"
#             }
#         return preprocess_mmlu
#     elif name == "openbookqa":
#         def preprocess_openbookqa(x: Dict[str, Any]) -> Dict[str, Any]:
#             choices_texts = x['choices']['text']
#             choices_labels = x['choices']['label']
            
#             formatted_choices = []
#             for i in range(len(choices_labels)):
#                 formatted_choices.append(f"{choices_labels[i]}. {choices_texts[i]}")
            
#             question = f"Question: {x['question_stem']}\n\nChoices:\n" + "\n".join(formatted_choices)
#             return {
#                 "question": question,
#                 "answer": x["answerKey"],
#                 "task": "mc"
#             }
#         return preprocess_openbookqa
#     elif name in ["openrs", "openrs_easy", "openrs_hard"]:
#         def preprocess_openrs(x: Dict[str, Any]) -> Dict[str, Any]:
#             return {
#                 "question": x["problem"],
#                 "answer": x["answer"],
#                 "task": "math"
#             }
#         return preprocess_openrs
#     elif name == "prime_code":
#         def preprocess_prime_code(x: Dict[str, Any]) -> Dict[str, Any]:
#             return {
#                 "question": x["prompt"],
#                 "answer": x["verification_info"],
#                 "task": "code"
#             }
#         return preprocess_prime_code
#     else:
#         raise ValueError(f"Dataset {name} not supported for preprocess_dataset.")

# def preprocess_dataset(name: str = "gsm8k",
#                        split: str | None = None,
#                        n: int | None = None,
#                        seed: int = 0) -> Dataset:
#     if name == "aime2024":
#         if split is None:
#             split = "train"
#         dataset = load_dataset("HuggingFaceH4/aime_2024")[split] # type: ignore
#     elif name == "aime2025":
#         if split is None:
#             split = "test"
#         aime_i = load_dataset("opencompass/AIME2025", "AIME2025-I")[split] # type: ignore
#         aime_ii = load_dataset("opencompass/AIME2025", "AIME2025-II")[split] # type: ignore
#         dataset = concatenate_datasets([aime_i, aime_ii]) # type: ignore
#     elif name == "amc2023":
#         if split is None:
#             split = "train"
#         dataset = load_dataset("knoveleng/AMC-23")[split] # type: ignore
#     elif name == "gpqa_diamond":
#         if split is None:
#             split = "train"
#         dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond")[split] # type: ignore
#     elif name == "gpqa_main":
#         if split is None:
#             split = "train"
#         dataset = load_dataset("Idavidrein/gpqa", "gpqa_main")[split] # type: ignore
#     elif name == "gsm8k":
#         if split is None:
#             split = "test"
#         dataset: Dataset = load_dataset("openai/gsm8k", "main")[split] # type: ignore
#     elif name == "math":
#         if split is None:
#             split = "train"
#         dataset: Dataset = load_dataset("chiayewken/competition_math")[split] # type: ignore
#     elif name == "math500":
#         if split is None:
#             split = "test"
#         dataset: Dataset = load_dataset("HuggingFaceH4/MATH-500")[split] # type: ignore
#     elif name == "mmlu":
#         if split is None:
#             split = "dev"
#         dataset = load_dataset("cais/mmlu", "all")[split] # type: ignore
#     elif name == "mmlu_pro":
#         if split is None:
#             split = "validation"
#         dataset = load_dataset("TIGER-Lab/MMLU-Pro")[split] # type: ignore
#     elif name == "openbookqa":
#         if split is None:
#             split = "train"
#         dataset: Dataset = load_dataset("allenai/openbookqa", "main")[split] # type: ignore
#     elif name == "openrs":
#         if split is None:
#             split = "train"
#         dataset: Dataset = load_dataset("knoveleng/open-rs")[split] # type: ignore
#     elif name == "openrs_easy":
#         if split is None:
#             split = "train"
#         dataset: Dataset = load_dataset("knoveleng/open-rs")[split] # type: ignore
#         dataset = dataset.filter(lambda x: x["level"] == "Easy") # type: ignore
#     elif name == "openrs_hard":
#         if split is None:
#             split = "train"
#         dataset: Dataset = load_dataset("knoveleng/open-rs")[split] # type: ignore
#         dataset = dataset.filter(lambda x: x["level"] == "Hard") # type: ignore
#     elif name == "prime_code":
#         if split is None:
#             split = "train"
#         dataset: Dataset = load_dataset("PrimeIntellect/verifiable-coding-problems")[split] # type: ignore
#         dataset = dataset.filter(lambda x: x['prompt'].startswith("Solve the following coding problem using the programming language python:")) # type: ignore
#     else:
#         raise ValueError(f"Dataset {name} not supported for preprocess_dataset. \
# Please ensure that the dataset is formatted with 'prompt' (str) and 'answer' (str) keys.")
    
#     preprocess_fn = get_preprocess_fn(name)
#     if n is not None and n > 0:
#         dataset = dataset.shuffle(seed=seed).select(range(n)) # type: ignore
#     dataset = dataset.map(preprocess_fn, num_proc=10, remove_columns=dataset.column_names) # type: ignore
#     if "temp_answer" in dataset.column_names:
#         dataset = dataset.rename_column("temp_answer", "answer")
#     return dataset

def format_prompt(prompt: str,
                  system_prompt: str | None = None,
                  few_shot: List[Dict[str, str]] | None = None,
                  fewshot_prob: float = 1.0) -> List[Dict[str, str]]:
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if few_shot and random.random() < fewshot_prob:
        messages.extend(few_shot)
    messages.append({"role": "user", "content": prompt})
    return messages


def format_dataset(dataset: Dataset,
                   system_prompt: str | None = None,
                   few_shot: List[Dict[str, str]] | None = None,
                   fewshot_prob: float = 1.0,
                   question_key: str = "question",
                   answer_key: str = "answer",
                   ) -> Dataset:
    """Map every row of *dataset* to a chat ``prompt`` plus its ``answer``.

    Args:
        dataset: Source dataset; each row must provide *question_key* and
            *answer_key*.
        system_prompt: Forwarded to ``format_prompt``.
        few_shot: Forwarded to ``format_prompt``.
        fewshot_prob: Forwarded to ``format_prompt``.
        question_key: Column holding the question text.
        answer_key: Column holding the reference answer.

    Returns:
        The result of ``dataset.map`` (run with ``num_proc=10``), where each
        row gains a ``"prompt"`` message list and an ``"answer"`` value.
    """
    def to_prompt_row(row):
        chat = format_prompt(row[question_key], system_prompt, few_shot, fewshot_prob)
        return {"prompt": chat, "answer": row[answer_key]}

    return dataset.map(to_prompt_row, num_proc=10)



def select_rule_indices(example: Dict[str, Any],
                        max_turns: int = 4,
                        total_rules: int = 10,
                       ) -> list[int]:
    """
    Pick a random subset of rule indices for the reverse COT.

    The indices ``0 .. total_rules - 1`` are shuffled with the module-level
    ``random`` generator and the first ``max_turns`` of them are returned.
    ``example`` is accepted but not used by this function.
    """
    shuffled = [*range(total_rules)]
    random.shuffle(shuffled)
    return shuffled[:max_turns]


def format_prompt_for_distill_cot(example: Dict[str, Any],
                                  system_prompt: str | None = None,
                                  include_original_question: bool = True,
                                  include_augmented_questions: bool = True,
                                  task: str = "strategy_qa",
                                  cot_response_key: str = "cot_solution",
                                  augmented_question_key: str = "augmented_questions",
                                  augmented_solutions_key: str = "augmented_solutions",
                                  augmented_original_solutions_key: str = "augmented_original_solutions"
                                  ) -> List[Dict[str, Any]]:
    """
    Format the prompt for distillation with CoT. We consider both the original question and the augmented questions and their corresponding CoT solutions.

    For the original question, if the CoT solution is provided (which means the original question's CoT rationale is correct), we will use this CoT solution as the original question's CoT solution. Otherwise, if we have any augmented original CoT solutions, we will sample one of them as the original question's CoT solution.

    For the augmented questions, we will use all of the augmented questions and their corresponding CoT solutions.
    """
    distill_examples = []
    
    # Get the id field if it exists, otherwise use None
    example_id = None
    if task == "strategy_qa":
        example_id = example.get("qid")
    elif task in ["arc_challenge", "date", "commonsense_qa", "table_mwp", "gsm8k", "math"]:
        example_id = example.get("id")
    elif task == "anli":
        example_id = example.get("uid")
    
    # Construct the CoT user prompt based on the task
    if task == "strategy_qa":
        original_question = example["question"]
        original_cot_solution = example[cot_response_key]
        augmented_questions = example[augmented_question_key]
        augmented_solutions = example[augmented_solutions_key]
        augmented_original_solutions = example[augmented_original_solutions_key]
        
        # For the original question.
        if include_original_question:
            if original_cot_solution:
                example_dict = {
                    "question": original_question,
                    "answer": example["answer"],
                    "messages": [
                        *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                        {"role": "user", "content": "Yes or No: "+original_question.strip()+"\n\n"+"Please reason step by step, and conclude with either \"Yes\" or \"No\"."},
                        {"role": "assistant", "content": original_cot_solution}
                    ]
                }
                if example_id is not None:
                    example_dict["id"] = example_id
                distill_examples.append(example_dict)
            else:
                if len(augmented_original_solutions) > 0:
                    # Sample one of the augmented original CoT solutions as the original question's CoT solution
                    original_cot_solution = augmented_original_solutions[random.randint(0, len(augmented_original_solutions) - 1)]
                    example_dict = {
                        "question": original_question,
                        "answer": example["answer"],
                        "messages": [
                            *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                            {"role": "user", "content": "Yes or No: "+original_question.strip()+"\n\n"+"Please reason step by step, and conclude with either \"Yes\" or \"No\"."},
                            {"role": "assistant", "content": original_cot_solution.strip()}
                        ]
                    }
                    if example_id is not None:
                        example_dict["id"] = example_id
                    distill_examples.append(example_dict)
                else:
                    pass
                    # raise ValueError("No original CoT solution provided for the original question.")

        # For the augmented questions.
        if include_augmented_questions:
            for augmented_question, augmented_solution in zip(augmented_questions, augmented_solutions):    
                example_dict = {
                    "question": augmented_question,
                    "answer": example["answer"],
                    "messages": [
                        *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                        {"role": "user", "content": augmented_question.strip()},
                        {"role": "assistant", "content": augmented_solution.strip()}
                    ]
                }
                if example_id is not None:
                    example_dict["id"] = example_id
                distill_examples.append(example_dict)

    elif task == "arc_challenge" or task == "date" or task == "commonsense_qa":
        num_choices_dict = {2: "two", 3: "three", 4: "four", 5: "five", 6: "six"}
        choices_key_dict = {2: "A and B", 3: "A, B, and C", 4: "A, B, C, and D", 5: "A, B, C, D, and E", 6: "A, B, C, D, E, and F"}
        answer_key_dict = {2: "A or B", 3: "A, B, or C", 4: "A, B, C, or D", 5: "A, B, C, D, or E", 6: "A, B, C, D, E, or F"}
        choices = example["choices"]["text"]
        choicesKey = example["choices"]["label"]
        choices_str = ""
        for i in range(len(choices)):
            choices_str += choicesKey[i] + ". " + choices[i].strip() + "\n"
        
        original_question = example["question"]
        original_cot_solution = example[cot_response_key]
        augmented_questions = example[augmented_question_key]
        augmented_solutions = example[augmented_solutions_key]
        augmented_original_solutions = example[augmented_original_solutions_key]
        
        # For the original question.
        if include_original_question:
            if original_cot_solution:
                example_dict = {
                    "question": original_question,
                    "choices": example["choices"],
                    "answer": example["answerKey"],
                    "answer_text": example["answer"],
                    "messages": [
                        *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                        {"role": "user", "content": "Given the following question and "+num_choices_dict[len(choices)]+" candidate answers ("+choices_key_dict[len(choices)]+"), choose the best answer.\nQuestion: "+original_question+"\n"+choices_str+"Please reason step by step, and conclude with your choice. Your response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of "+answer_key_dict[len(choices)]+"."},
                        {"role": "assistant", "content": original_cot_solution}
                    ]
                }
                if example_id is not None:
                    example_dict["id"] = example_id
                distill_examples.append(example_dict)
            else:
                if len(augmented_original_solutions) > 0:
                    original_cot_solution = augmented_original_solutions[random.randint(0, len(augmented_original_solutions) - 1)]
                    example_dict = {
                        "question": original_question,
                        "choices": example["choices"],
                        "answer": example["answerKey"],
                        "answer_text": example["answer"],
                        "messages": [
                            *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                            {"role": "user", "content": "Given the following question and "+num_choices_dict[len(choices)]+" candidate answers ("+choices_key_dict[len(choices)]+"), choose the best answer.\nQuestion: "+original_question+"\n"+choices_str+"Please reason step by step, and conclude with your choice. Your response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of "+answer_key_dict[len(choices)]+"."},
                            {"role": "assistant", "content": original_cot_solution}
                        ]
                    }
                    if example_id is not None:
                        example_dict["id"] = example_id
                    distill_examples.append(example_dict)
                else:
                    pass
                    # raise ValueError("No original CoT solution provided for the original question.")

        # For the augmented questions.
        if include_augmented_questions:
            for augmented_question, augmented_solution in zip(augmented_questions, augmented_solutions):    
                example_dict = {
                    "question": augmented_question,
                    "choices": example["choices"],
                    "answer": example["answerKey"],
                    "answer_text": example["answer"],
                    "messages": [
                        *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                        {"role": "user", "content": augmented_question},
                        {"role": "assistant", "content": augmented_solution}
                    ]
                }
                if example_id is not None:
                    example_dict["id"] = example_id
                distill_examples.append(example_dict)
    
    elif task == "table_mwp":
        table_title = example["table_title"] if "table_title" in example else ""
        table_content = example["table"]
        table_unit = example["unit"] if "unit" in example else ""
        choices_text = example["choices"]       # If this is None, it means the question is not multiple choice.
        # Construct the table title and content
        if table_title is not None and table_title != "":
            table_prompt = "Read the following table regarding \""+table_title+"\" and then answer a question:\n\n"+table_content+"\n\n"
        else:
            table_prompt = "Read the following table and answer a question:\n\n"+table_content+"\n\n"
        # Concstruct the question's unit if provided
        if table_unit is not None and table_unit != "":
            unit_prompt = " (Unit: "+table_unit+")"
        else:
            unit_prompt = ""

        original_question = example["question"]
        original_cot_solution = example[cot_response_key]
        augmented_questions = example[augmented_question_key]
        augmented_solutions = example[augmented_solutions_key]
        augmented_original_solutions = example[augmented_original_solutions_key]

        if choices_text is None:
            original_question_user_prompt = table_prompt+"Question: "+original_question+unit_prompt+"\n\n"+"Please reason step by step, and put your final answer within \\boxed{}."
        else:
            choices_str = ""
            for i in range(len(choices_text)):
                choices_str += chr(ord('A') + i) + ". " + choices_text[i]
            original_question_user_prompt = table_prompt+"Question: "+original_question+unit_prompt+"\nOptions: "+choices_str+"\n\n"+"Please reason step by step, and put your final answer within \\boxed{}."

        # For the original question.
        if include_original_question:
            if original_cot_solution:
                example_dict = {
                    "question": original_question,
                    "choices": example["choices"],
                    "answer": example["answer"],
                    "messages": [
                        *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                        {"role": "user", "content": original_question_user_prompt},
                        {"role": "assistant", "content": original_cot_solution}
                    ]
                }
                if example_id is not None:
                    example_dict["id"] = example_id
                distill_examples.append(example_dict)
            else:
                if len(augmented_original_solutions) > 0:
                    original_cot_solution = augmented_original_solutions[random.randint(0, len(augmented_original_solutions) - 1)]
                    example_dict = {
                        "question": original_question,
                        "choices": example["choices"],
                        "answer": example["answer"],
                        "messages": [
                            *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                            {"role": "user", "content": original_question_user_prompt},
                            {"role": "assistant", "content": original_cot_solution}
                        ]
                    }
                    if example_id is not None:
                        example_dict["id"] = example_id
                    distill_examples.append(example_dict)
                else:
                    pass
                    # raise ValueError("No original CoT solution provided for the original question.")

        # For the augmented questions.
        if include_augmented_questions:
            for augmented_question, augmented_solution in zip(augmented_questions, augmented_solutions):    
                example_dict = {
                    "question": augmented_question,
                    "choices": example["choices"],
                    "answer": example["answer"],
                    "messages": [
                        *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                        {"role": "user", "content": augmented_question},
                        {"role": "assistant", "content": augmented_solution}
                    ]
                }
                if example_id is not None:
                    example_dict["id"] = example_id
                distill_examples.append(example_dict)
    
    elif task == "gsm8k" or task == "math":
        original_question = example["question"]
        original_cot_solution = example[cot_response_key]
        augmented_questions = example[augmented_question_key]
        augmented_solutions = example[augmented_solutions_key]
        augmented_original_solutions = example[augmented_original_solutions_key]

        # For the original question.
        if include_original_question:
            if original_cot_solution:
                example_dict = {
                    "question": original_question,
                    "answer": example["answer"],
                    "messages": [
                        *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                        {"role": "user", "content": original_question+"\n\n"+"Please reason step by step, and put your final answer within \\boxed{}."},
                        {"role": "assistant", "content": original_cot_solution}
                    ]
                }
                if example_id is not None:
                    example_dict["id"] = example_id
                distill_examples.append(example_dict)
            else:
                if len(augmented_original_solutions) > 0:
                    original_cot_solution = augmented_original_solutions[random.randint(0, len(augmented_original_solutions) - 1)]
                    example_dict = {
                        "question": original_question,
                        "answer": example["answer"],
                        "messages": [
                            *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                            {"role": "user", "content": original_question+"\n\n"+"Please reason step by step, and put your final answer within \\boxed{}."},
                            {"role": "assistant", "content": original_cot_solution}
                        ]
                    }
                    if example_id is not None:
                        example_dict["id"] = example_id
                    distill_examples.append(example_dict)
                else:
                    pass
                    # raise ValueError("No original CoT solution provided for the original question.")

        # For the augmented questions.
        if include_augmented_questions:
            for augmented_question, augmented_solution in zip(augmented_questions, augmented_solutions):    
                example_dict = {
                    "question": augmented_question,
                    "answer": example["answer"],
                    "messages": [
                        *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                        {"role": "user", "content": augmented_question},
                        {"role": "assistant", "content": augmented_solution}
                    ]
                }
                if example_id is not None:
                    example_dict["id"] = example_id
                distill_examples.append(example_dict)
    
    elif task == "anli":
        premise = example["premise"]
        hypothesis = example["hypothesis"]
        answer = example["label"]
        if answer == "entailment":
            answer = "True"
        elif answer == "contradiction":
            answer = "False"
        else:
            answer = "Neither"

        original_cot_solution = example[cot_response_key]
        augmented_questions = example[augmented_question_key]
        augmented_solutions = example[augmented_solutions_key]
        augmented_original_solutions = example[augmented_original_solutions_key]

        # For the original question.
        if include_original_question:
            if original_cot_solution:
                example_dict = {
                    "premise": premise,
                    "hypothesis": hypothesis,
                    "answer": answer,
                    "messages": [
                        *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                        {"role": "user", "content": "Given that \""+premise+"\"\nQuestion: "+hypothesis+"\nTrue, False, or Neither?\n\n"+"Please reason step by step, and conclude with your final answer. Your response should end with \"The best answer is [the_answer_choice]\" where [the_answer_choice] is one of \"True\", \"False\", or \"Neither\"."},
                        {"role": "assistant", "content": original_cot_solution}
                    ]
                }
                if example_id is not None:
                    example_dict["id"] = example_id
                distill_examples.append(example_dict)
            else:
                if len(augmented_original_solutions) > 0:
                    original_cot_solution = augmented_original_solutions[random.randint(0, len(augmented_original_solutions) - 1)]
                    example_dict = {
                        "premise": premise,
                        "hypothesis": hypothesis,
                        "answer": answer,
                        "messages": [
                            *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                            {"role": "user", "content": "Given that \""+premise+"\"\nQuestion: "+hypothesis+"\nTrue, False, or Neither?\n\n"+"Please reason step by step, and conclude with your final answer. Your response should end with \"The best answer is [the_answer_choice]\" where [the_answer_choice] is one of \"True\", \"False\", or \"Neither\"."},
                            {"role": "assistant", "content": original_cot_solution}
                        ]
                    }
                    if example_id is not None:
                        example_dict["id"] = example_id
                    distill_examples.append(example_dict)
                else:
                    pass
                    # raise ValueError("No original CoT solution provided for the original question.")

        # For the augmented questions.
        if include_augmented_questions:
            for augmented_question, augmented_solution in zip(augmented_questions, augmented_solutions):    
                example_dict = {
                    "premise": premise,
                    "hypothesis": hypothesis,
                    "answer": answer,
                    "messages": [
                        *([{"role": "system", "content": system_prompt}] if system_prompt is not None else []),
                        {"role": "user", "content": "Given that \""+premise+"\"\nQuestion: "+hypothesis+" True, False, or Neither?\n\n"+"Please reason step by step, and conclude with your final answer."},
                        {"role": "assistant", "content": augmented_solution}
                    ]
                }
                if example_id is not None:
                    example_dict["id"] = example_id
                distill_examples.append(example_dict)
    
    else:
        raise ValueError(f"Task {task} is not supported for reverse augmentation")

    return distill_examples


def format_sft_dataset_for_distill_cot(dataset: Dataset,
                                       system_prompt: str | None = None,
                                       include_original_question: bool = True,
                                       include_augmented_questions: bool = True,
                                       task: str = "strategy_qa",
                                       cot_response_key: str = "cot_solution",
                                       augmented_question_key: str = "augmented_questions",
                                       augmented_solutions_key: str = "augmented_solutions",
                                       augmented_original_solutions_key: str = "augmented_original_solutions",
                                       ) -> Dataset:
    """
    Build a shuffled SFT dataset of distilled-CoT training examples.

    Every row of ``dataset`` is passed to ``format_prompt_for_distill_cot``,
    which returns a list of formatted examples for that row. The per-row lists
    are flattened into one list, turned into a new ``Dataset`` and shuffled
    with a fixed seed (42).

    All keyword options are forwarded unchanged to
    ``format_prompt_for_distill_cot``.
    """
    # Options shared by every per-row formatting call.
    formatter_options = {
        "system_prompt": system_prompt,
        "include_original_question": include_original_question,
        "include_augmented_questions": include_augmented_questions,
        "task": task,
        "cot_response_key": cot_response_key,
        "augmented_question_key": augmented_question_key,
        "augmented_solutions_key": augmented_solutions_key,
        "augmented_original_solutions_key": augmented_original_solutions_key,
    }

    # One source row can yield several training examples; flatten them.
    flattened = [
        formatted
        for row in dataset
        for formatted in format_prompt_for_distill_cot(example=row, **formatter_options)
    ]

    # Fixed seed keeps the resulting order identical across runs.
    return Dataset.from_list(flattened).shuffle(seed=42)


def format_prompt_for_reverse_cot(example: Dict[str, Any],
                                  support_system_prompt: bool = True,
                                  system_prompt: str | None = None,
                                  task: str = "strategy_qa",
                                  turn_idx: int = 0,
                                  max_turns: int = 4,
                                  total_rules: int = 10,
                                  rule_indices: list[int] = None,
                                  provide_original_question_first: bool = False,
                                  ) -> Dict[str, Any]:
    """
    Generate the response for the original question using iterative CoT. We first give the background information that requires to answer the original question (for example, the table for the TableMWP task), and then ask the LLM to first answer the augmented question iteratively. Then at the last round, we ask the LLM to answer the original question given all the previous rounds as the context.
    """

    if rule_indices is None:
        # If no rule indices are provided, we use all the rules indices and then randomize them, so that we can sample max_iters from them
        rule_indices = list(range(total_rules))
        # randomize the rule indices
        random.shuffle(rule_indices)
    
    # Remember that turn_idx is 0-indexed, so when turn_idx == max_turns, we are at the last turn to answer the original question
    if turn_idx == max_turns:
        # If we are at the last turn, we don't need to ask more augmented questions -- we will just answer the original question
        rule_idx = -1
    else:
        rule_idx = rule_indices[turn_idx]

    # prompt_answer_augmentedQ = "Before answering the original question, let's break down the problem by addressing some related questions first. For each augmented question:\n\n1. Focus on answering ONLY the augmented question presented\n2. Consider all relevant information from the original question in your reasoning\n3. Provide a clear, step-by-step explanation\n4. Do NOT answer the original question until specifically asked to do so.\n\n"
    
    if provide_original_question_first:
        system_prompt = (
            "You are a helpful assistant trained to reason through explanatory inversion (EI). "
            "You will first be given the original question as background information, followed by one EI question at a time. "
            "Each EI question probes a different aspect of reasoning relevant to the original problem. "
            "For each EI question:\n\n"
            "1. Focus ONLY on the EI question presented\n"
            "2. Use relevant information from the original question to support your reasoning\n"
            "3. Provide a clear, step-by-step explanation\n"
            "4. Do NOT attempt to answer the original question until it is explicitly asked.\n\n"
            "After all EI questions are answered, the original question will be presented again. "
            "At that point, use the insights from your earlier responses to generate a well-reasoned final answer."
        )

    else:
        system_prompt = (
            "You are a helpful assistant trained to reason through explanatory inversion (EI). "
            "You will be given one EI question at a time. "
            "Each EI question probes part of the reasoning behind a hidden original question. "
            "Answer each EI question carefully and thoroughly. "
            "Do NOT reference or attempt to answer the original question until it is explicitly given. "
            "After all EI questions are answered, the original question will be provided. "
            "Use insights from your previous answers to generate a final, well-reasoned response."
        )
    
    # prompt_answer_originalQ = "Now that we've completed the explanatory inversion (EI) questions, let's return to the original problem. Using the insights from your previous answers, provide a well-reasoned response to the original question:\n\n"
    prompt_answer_originalQ = (
        "Now that we've completed the explanatory inversion (EI) questions, let's return to the original problem.\n"
        "Use the insights from your previous answers to inform your reasoning.\n"
        "You MUST first think about the reasoning process as an internal monologue, and then provide the answer.\n"
        "The reasoning MUST be enclosed within <think> and </think> tags.\n\n"
    )

    if "metadata" in example:
        metadata = example["metadata"]
    else:
        metadata = example
    
    # Construct the CoT user prompt based on the task
    if task == "strategy_qa":
        if turn_idx == 0:
            if provide_original_question_first:
                user_prompt = (
                    "Original Question:\nYes or No: "+metadata["question"].strip()+"\n\n"
                    + "Let's begin by exploring a related question. Remember, do NOT answer the original question yet.\n\n"
                    + f"EI Question {turn_idx+1}:\n"
                    + metadata["augmented_questions"][rule_idx]
                )
            else:
                user_prompt = (
                    f"EI Question {turn_idx+1}:\n"
                    + metadata["augmented_questions"][rule_idx]
                )
        elif turn_idx == max_turns:
            # user_prompt = (
            #     prompt_answer_originalQ
            #     + "Original Question:\n"
            #     + "Yes or No: " + metadata["question"].strip() + "\n\n"
            #     + 'Please reason step by step and conclude with either "Yes" or "No."'
            # )
            user_prompt = (
                prompt_answer_originalQ
                + "Original Question:\n"
                + "Yes or No: " + metadata["question"].strip() + "\n\n"
                + 'Please conclude with either "Yes" or "No" after the <think> block.'
            )
        else:
            user_prompt = (
                "Let's continue by exploring another related question. Remember, do NOT answer the original question yet.\n\n"
                + f"EI Question {turn_idx+1}:\n{metadata['augmented_questions'][rule_idx]}"
            )
    
    elif task == "arc_challenge" or task == "date" or task == "commonsense_qa":
        num_choices_dict = {2: "two", 3: "three", 4: "four", 5: "five", 6: "six"}
        choices_key_dict = {2: "A and B", 3: "A, B, and C", 4: "A, B, C, and D", 5: "A, B, C, D, and E", 6: "A, B, C, D, E, and F"}
        answer_key_dict = {2: "A or B", 3: "A, B, or C", 4: "A, B, C, or D", 5: "A, B, C, D, or E", 6: "A, B, C, D, E, or F"}
        choices = metadata["choices"]["text"]
        choicesKey = metadata["choices"]["label"]
        choices_str = ""
        for i in range(len(choices)):
            choices_str += choicesKey[i] + ". " + choices[i].strip() + "\n"
        
        if turn_idx == 0:
            if provide_original_question_first:
                user_prompt = (
                    "Original Question:\n"+metadata["question"]+"\n"+choices_str+"\n\n"
                    + "Let's begin by exploring a related question. Remember, do NOT answer the original question yet.\n\n"
                    + f"EI Question {turn_idx+1}:\n"
                    + metadata["augmented_questions"][rule_idx]
                )
            else:
                user_prompt = f"EI Question {turn_idx+1}:\n{metadata['augmented_questions'][rule_idx]}"
        
        elif turn_idx == max_turns:
            user_prompt = (
                prompt_answer_originalQ
                + "Original Question:\n"
                + "Given the following question and "+num_choices_dict[len(choices)]+" candidate answers ("+choices_key_dict[len(choices)]+"), choose the best answer.\n"
                + "Question: " + metadata["question"] + "\n"
                + choices_str
                + "Please reason step by step inside <think>...</think>, and conclude with your choice.\n"
                + f'Your response must end with: "The best answer is [the_answer_letter]", '
                + f"where [the_answer_letter] is one of {answer_key_dict[len(choices)]}."
            )
        else:
            user_prompt = (
                "Let's continue by exploring another related question. Remember, do NOT answer the original question yet.\n\n"
                + f"EI Question {turn_idx+1}:\n{metadata['augmented_questions'][rule_idx]}"
            )
    
    elif task == "table_mwp":
        table_title = metadata["table_title"] if "table_title" in metadata else ""
        table_content = metadata["table"]
        table_unit = metadata["unit"] if "unit" in metadata else ""
        choices = metadata["choices"]["text"]
        # Construct the table title and content
        if table_title is not None and table_title != "":
            table_prompt = "Read the following table regarding \""+table_title+"\" and then answer a question:\n\n"+table_content+"\n\n"
        else:
            table_prompt = "Read the following table and answer a question:\n\n"+table_content+"\n\n"
        # Concstruct the question's unit if provided
        if table_unit is not None and table_unit != "":
            unit_prompt = " (Unit: "+table_unit+")"
        else:
            unit_prompt = ""
        
        if turn_idx == 0:
            if choices is None:
                if provide_original_question_first:
                    user_prompt = (
                        "Original Question:\n"+table_prompt+"Question: "+metadata["question"]+unit_prompt+"\n\n"
                        + "Let's begin by exploring a related question. Remember, do NOT answer the original question yet.\n\n"
                        + f"EI Question {turn_idx+1}:\n"
                        + metadata["augmented_questions"][rule_idx]
                    )
                else:
                    user_prompt = f"EI Question {turn_idx+1}:\n{metadata['augmented_questions'][rule_idx]}"
            else:
                choices_str = ""
                choicesKey = ['A', 'B', 'C', 'D', 'E', 'F']
                for i in range(len(choices)):
                    choices_str += choicesKey[i] + ". " + choices[i]
                if provide_original_question_first:
                    user_prompt = (
                        "Original Question:\n"+table_prompt+"Question: "+metadata["question"]+unit_prompt+"\nOptions: "+choices_str+"\n\n"
                        + "Let's begin by exploring a related question. Remember, do NOT answer the original question yet.\n\n"
                        + f"EI Question {turn_idx+1}:\n"
                        + metadata["augmented_questions"][rule_idx]
                    )
                else:
                    user_prompt = f"EI Question {turn_idx+1}:\n{metadata['augmented_questions'][rule_idx]}"
        elif turn_idx == max_turns: 
            if choices is None:
                user_prompt = (
                    prompt_answer_originalQ
                    + "Original Question:\n"
                    + table_prompt
                    + "Question: " + metadata["question"] + unit_prompt + "\n\n"
                    + "Please reason step by step inside <think>...</think>, and put your final answer within \\boxed{}."
                )
            else:
                choices_str = ""
                choicesKey = ['A', 'B', 'C', 'D', 'E', 'F']
                for i in range(len(choices)):
                    choices_str += choicesKey[i] + ". " + choices[i]
                user_prompt = (
                    prompt_answer_originalQ
                    + "Original Question:\n"
                    + table_prompt
                    + "Question: " + metadata["question"] + unit_prompt + "\n"
                    + "Options: " + choices_str + "\n\n"
                    + "Please reason step by step inside <think>...</think>, and put your final answer within \\boxed{}."
                )
        else:
            user_prompt = (
                "Let's continue by exploring another related question. Remember, do NOT answer the original question yet.\n\n"
                + f"EI Question {turn_idx+1}:\n{metadata['augmented_questions'][rule_idx]}"
            )
    
    elif task == "gsm8k" or task == "math":
        if turn_idx == 0:
            if provide_original_question_first:
                user_prompt = (
                    "Original Question:\n"+metadata["question"]+"\n\n"
                    + "Let's begin by exploring a related question. Remember, do NOT answer the original question yet.\n\n"
                    + f"EI Question {turn_idx+1}:\n"
                    + metadata["augmented_questions"][rule_idx]
                )
            else:
                user_prompt = f"EI Question {turn_idx+1}:\n{metadata['augmented_questions'][rule_idx]}"
        elif turn_idx == max_turns:
            user_prompt = (
                prompt_answer_originalQ
                + "Original Question:\n"
                + metadata["question"]
                + "\n\n"
                + "Please reason step by step inside <think>...</think>, and put your final answer within \\boxed{}."
            )
        else:
            user_prompt = (
                "Let's continue by exploring another related question. Remember, do NOT answer the original question yet.\n\n"
                + f"EI Question {turn_idx+1}:\n{metadata['augmented_questions'][rule_idx]}"
            )
    
    elif task == "anli":
        premise = metadata["premise"]
        hypothesis = metadata["hypothesis"]
        answer = metadata["label"]
        if answer == "entailment":
            answer = "True"
        elif answer == "contradiction":
            answer = "False"
        else:
            answer = "Neither"
        if turn_idx == 0:
            if provide_original_question_first:
                user_prompt = (
                    "Original Question:\nGiven that \""+premise+"\"\nQuestion: "+hypothesis+"\nTrue, False, or Neither?\n\n"
                    + "Let's begin by exploring a related question. Remember, do NOT answer the original question yet.\n\n"
                    + f"EI Question {turn_idx+1}:\n"
                    + metadata["augmented_questions"][rule_idx]
                )
            else:
                user_prompt = f"EI Question {turn_idx+1}:\n{metadata['augmented_questions'][rule_idx]}"
        elif turn_idx == max_turns:
            user_prompt = (
                prompt_answer_originalQ
                + "Original Question:\n"
                + "Given that \""+premise+"\"\n"
                + "Question: "+hypothesis+"\nTrue, False, or Neither?\n\n"
                + "Please reason step by step inside <think>...</think>, and conclude with your final answer."
                + "Your response should end with \"The best answer is [the_answer_choice]\" where [the_answer_choice] is one of \"True\", \"False\", or \"Neither\"."
            )
        else:
            user_prompt = (
                "Let's continue by exploring another related question. Remember, do NOT answer the original question yet.\n\n"
                + f"EI Question {turn_idx+1}:\n{metadata['augmented_questions'][rule_idx]}"
            )
    
    else:
        raise ValueError(f"Task {task} is not supported for reverse augmentation")
        
    messages = []
    if system_prompt and turn_idx == 0:
        # Only add the system prompt for the first turn
        if support_system_prompt:
            messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": user_prompt})
        else:
            messages.append({"role": "user", "content": system_prompt + "\n\n" + user_prompt})
    else:
        messages.append({"role": "user", "content": user_prompt})

    return messages


def format_dataset_for_reverse_cot(dataset: Dataset,
                                   support_system_prompt: bool = True,
                                   system_prompt: str | None = None,
                                   answer_key: str = "answer",
                                   task: str = "strategy_qa",
                                   turn_idx: int = 0,
                                   max_turns: int = 4,
                                   total_rules: int = 10,
                                   rule_indices: list[int] = None,
                                   ) -> Dataset:
    """
    Add "rule_indices" and "prompt" columns for one reverse-CoT turn.

    First every row gets a "rule_indices" value: the caller-supplied
    ``rule_indices`` if given, otherwise the result of
    ``select_rule_indices(row, max_turns, total_rules)``. Then every row gets a
    "prompt" built by ``format_prompt_for_reverse_cot`` for ``turn_idx`` (always
    with ``provide_original_question_first=False``). When ``answer_key`` is not
    "answer", the row's ``answer_key`` value is also copied into "answer".

    Both passes run through ``Dataset.map`` with 10 worker processes.
    """
    if rule_indices is None:
        print("Selecting rule indices...")

        def attach_rule_indices(row):
            # Choose the rule order individually for each row.
            return {**row, "rule_indices": select_rule_indices(row, max_turns, total_rules)}
    else:
        print("Using provided rule indices...")

        def attach_rule_indices(row):
            # Every row shares the caller-supplied rule order.
            return {**row, "rule_indices": rule_indices}

    dataset = dataset.map(attach_rule_indices, num_proc=10)

    print("Formatting prompt...")
    # With the default key the "answer" column is already carried over by **row.
    copy_answer = answer_key != "answer"

    def attach_prompt(row):
        formatted = {
            **row,  # keep every original column
            "prompt": format_prompt_for_reverse_cot(row, support_system_prompt, system_prompt,
                                                    task, turn_idx, max_turns,
                                                    total_rules, row["rule_indices"],
                                                    provide_original_question_first=False),
        }
        if copy_answer:
            formatted["answer"] = row[answer_key]
        return formatted

    return dataset.map(attach_prompt, num_proc=10)


def format_sft_dataset_for_reverse_cot(dataset: Dataset,
                                       system_prompt: str | None = None,
                                       task: str = "strategy_qa",
                                       max_turns: int = -1,
                                       total_rules: int = 10,
                                       cot_response_key: str = "cot_solution",
                                       augmented_question_key: str = "augmented_questions",
                                       augmented_solutions_key: str = "augmented_solutions") -> Dataset:
    """
    Format the dataset for supervised fine-tuning (SFT) with reverse CoT.

    Each row is turned into one multi-turn conversation stored under
    "messages": ``max_turns`` EI question/answer turns in a per-row shuffled
    rule order, followed by the original question answered with the row's
    ``cot_response_key`` solution.

    Args:
        dataset: Rows holding the question fields, the augmented solutions and
            the CoT solution, either at top level or under a "metadata" key.
        system_prompt: Forwarded to ``format_prompt_for_reverse_cot``.
        task: Task name forwarded to ``format_prompt_for_reverse_cot``.
        max_turns: Number of EI turns; -1 means one turn per rule
            (``total_rules``).
        total_rules: Number of augmented questions available per row.
        cot_response_key: Key of the assistant answer for the final turn.
        augmented_question_key: Not used by this function; the prompt
            formatter reads the "augmented_questions" key directly.
        augmented_solutions_key: Key of the list of assistant answers to the
            EI questions.

    Returns:
        The result of ``dataset.map`` with a "messages" value per row.
    """
    if max_turns == -1:
        max_turns = total_rules

    def process_example(example):
        # Imported here so the mapped function is self-contained.
        import hashlib

        # Same fallback as format_prompt_for_reverse_cot: fields may live under
        # "metadata" or directly on the row.
        metadata = example["metadata"] if "metadata" in example else example

        # Derive the seed from a stable digest. The builtin hash() of a string
        # is salted per interpreter run, so it cannot give a reproducible seed.
        seed_source = str(example.get("id", str(example)))
        digest = hashlib.sha256(seed_source.encode("utf-8")).digest()
        example_seed = int.from_bytes(digest[:8], "big") % 10000 + 42
        rng = random.Random(example_seed)

        rule_indices = list(range(total_rules))
        rng.shuffle(rule_indices)

        messages = []
        for turn_idx in range(max_turns):
            # Keyword arguments are required: the formatter's second positional
            # parameter is support_system_prompt, not system_prompt.
            messages.extend(format_prompt_for_reverse_cot(example,
                                                          system_prompt=system_prompt,
                                                          task=task,
                                                          turn_idx=turn_idx,
                                                          max_turns=max_turns,
                                                          total_rules=total_rules,
                                                          rule_indices=rule_indices))
            messages.append({"role": "assistant",
                             "content": metadata[augmented_solutions_key][rule_indices[turn_idx]]})

        # Last turn: the original question, answered with the CoT solution.
        messages.extend(format_prompt_for_reverse_cot(example,
                                                      system_prompt=system_prompt,
                                                      task=task,
                                                      turn_idx=max_turns,
                                                      max_turns=max_turns,
                                                      total_rules=total_rules,
                                                      rule_indices=rule_indices))
        messages.append({"role": "assistant", "content": metadata[cot_response_key]})

        return {"messages": messages}

    # Apply the processing function with 10 parallel workers.
    return dataset.map(
        process_example,
        num_proc=10,
    )


def get_datasets(
    data_config: dict,
    splits: Optional[List[str]] = None,
    configs: Optional[List[str]] = None,
    columns_to_keep: Optional[List[str]] = None,
    shuffle: bool = True,
    data_file: str = "json"
) -> DatasetDict:
    """
    Load and mix the datasets described by ``data_config``.

    This is a thin wrapper: it reads ``data_config["dataset_mixer"]`` and hands
    it, together with every other argument, to ``mix_datasets``.

    Args:
        data_config (`dict`):
            Must contain a "dataset_mixer" entry mapping each dataset name to
            its training proportion, e.g.
            ``{"dataset_mixer": {"dataset1": 0.5, "dataset2": 0.3, "dataset3": 0.2}}``.
        splits (`List[str]`, *optional*):
            Dataset splits to load and mix; passed through to `mix_datasets`.
        configs (`List[str]`, *optional*):
            Dataset config names; passed through to `mix_datasets`.
        columns_to_keep (`List[str]`, *optional*):
            Column names to keep in the dataset; passed through to `mix_datasets`.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.
        data_file (`str`, *optional*, defaults to `"json"`):
            Loader selector passed through to `mix_datasets`.

    Returns:
        [`DatasetDict`]: The dataset dictionary returned by `mix_datasets`.

    Raises:
        ValueError: If ``data_config`` is not a dict.
    """
    # Only plain dict configurations are understood.
    if not isinstance(data_config, dict):
        raise ValueError(f"Data config {data_config} not recognized.")

    return mix_datasets(
        data_config["dataset_mixer"],
        splits=splits,
        configs=configs,
        columns_to_keep=columns_to_keep,
        shuffle=shuffle,
        data_file=data_file
    )


def mix_datasets(
    dataset_mixer: dict,
    splits: Optional[List[str]] = None,
    configs: Optional[List[str]] = None,
    columns_to_keep: Optional[List[str]] = None,
    shuffle=True,
    data_file: str = None
) -> DatasetDict:
    """
    Loads and mixes datasets according to proportions specified in `dataset_mixer`.

    Args:
        dataset_mixer (`dict`):
            Dictionary containing the dataset names and their training proportions.
            Proportions must be non-negative and are expected to be at most 1.
            Test splits are never subsampled.
        splits (Optional[List[str]], *optional*, defaults to `None`):
            Dataset splits to load and mix; `None` means `["train", "test"]`.
            A split whose name contains "train" goes into the training mix, one
            whose name contains "test" into the test mix.
        configs (Optional[List[str]], *optional*, defaults to `None`):
            List of dataset config names. If given must be the same length as 'dataset_mixer' keys.
        columns_to_keep (Optional[List[str]], *optional*, defaults to `None`):
            Column names to keep in the dataset; every other column is removed.
            With `None` no column is kept.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.
        data_file (`str`, *optional*, defaults to `None`):
            When equal to `"json"`, each split is read from the local file
            `<dataset name>/<split>.jsonl`. Otherwise the dataset is loaded with
            `load_dataset`, falling back to `load_from_disk` on
            `DatasetGenerationError`.

    Returns:
        [`DatasetDict`]: Holds a "train" and/or "test" entry.

    Raises:
        ValueError: On a config/dataset count mismatch, a negative fraction, an
            unrecognized split name, or when nothing was loaded.
    """
    splits = ["train", "test"] if splits is None else splits
    configs = [None] * len(dataset_mixer) if not configs else configs
    columns_to_keep = [] if columns_to_keep is None else columns_to_keep

    if len(configs) != len(dataset_mixer):
        raise ValueError("The number of given dataset config names must be the same as the given number of datasets.")

    # Reject bad proportions before spending time on loading anything.
    if any(frac < 0 for frac in dataset_mixer.values()):
        raise ValueError("Dataset fractions cannot be negative.")

    # Each training dataset is stored with its own fraction, so the pairing
    # stays correct even when several train-like splits are requested.
    train_parts = []
    raw_val_datasets = []
    for (ds, frac), ds_config in zip(dataset_mixer.items(), configs):
        for split in splits:
            if data_file == "json":
                # Load dataset from local jsonl file
                jsonl_path = os.path.join(ds, f"{split}.jsonl")
                print(f"Loading {split} split of {ds} from {jsonl_path}")
                # The json loader returns a DatasetDict whose only split is "train".
                dataset = load_dataset("json", data_files=jsonl_path)["train"]
            else:
                try:
                    # Try first if dataset on a Hub repo
                    dataset = load_dataset(ds, ds_config, split=split)
                except DatasetGenerationError:
                    # If not, check local dataset
                    dataset = load_from_disk(os.path.join(ds, split))

            # Remove redundant columns to avoid schema conflicts on load
            dataset = dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep])
            if "train" in split:
                train_parts.append((dataset, frac))
            elif "test" in split:
                raw_val_datasets.append(dataset)
            else:
                raise ValueError(f"Split type {split} not recognized as one of test or train.")

    raw_datasets = DatasetDict()

    if train_parts:
        # Keep the leading `frac` share of every training dataset.
        train_subsets = [
            dataset.select(range(int(frac * len(dataset))))
            for dataset, frac in train_parts
        ]
        mixed_train = concatenate_datasets(train_subsets)
        raw_datasets["train"] = mixed_train.shuffle(seed=42) if shuffle else mixed_train

    # No subsampling for test datasets to enable fair comparison across models
    if raw_val_datasets:
        mixed_test = concatenate_datasets(raw_val_datasets)
        raw_datasets["test"] = mixed_test.shuffle(seed=42) if shuffle else mixed_test

    if len(raw_datasets) == 0:
        raise ValueError(
            f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted."
        )

    return raw_datasets