import random
from datasets import load_dataset
from b1.eval.gsm8k import GSM8KDataset

# from b1.data_utils import SYSTEM_PROMPT, SUDOKU_SYSTEM_PROMPT

# Baseline prompt: plain step-by-step reasoning, no per-step block markers.
# It is bound to its own name because MATH500_SYSTEM_PROMPT is assigned again
# further down in this module; sharing that name made this literal dead code
# (evaluated at import time, then immediately overwritten). To evaluate with
# the baseline, pass system_prompt=MATH500_BASELINE_SYSTEM_PROMPT explicitly.
# NOTE(review): the `"` right after `</answer>` is part of the prompt text the
# model sees. It looks like a leftover closing quote -- confirm before removing
# it, since any edit to this string changes the prompt.
MATH500_BASELINE_SYSTEM_PROMPT = """You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}.
Respond in the following format:
<reasoning>
Your reasoning here
</reasoning>
<answer>
\\boxed{...}
</answer>"
"""


# Dynamic-generation prompt. Every reasoning step must finish with the
# end-of-block marker (a backslash followed by "block") on the same line as
# the step's last sentence; a line must never begin with the marker:
#   Step 1, ... <marker>
#   Step 2, ... <marker>
#   ...
#   Step n, ... <marker>
# The prompt is assembled from one literal per line and joined with newlines;
# the empty final element yields the trailing newline.
MATH500_SYSTEM_PROMPT = "\n".join(
    (
        "You are a math expert. You will be given a question to solve. You should provide clear, logical reasoning step by step.",
        r"Append the tag \block directly to the end of the last sentence of each reasoning step without starting a new line.",
        r"After reasoning, wrap only the final answer in a \boxed{} tag. Respond exactly in the following format:",
        "<reasoning>",
        r"Step 1, ... \block",
        r"Step 2, ... \block",
        "...",
        r"Step n, ... \block",
        "</reasoning>",
        "<answer>",
        r"\boxed{...}",
        '</answer>"',
        "",
    )
)


class MATH500Dataset(GSM8KDataset):
    """MATH-500 evaluation dataset layered on top of ``GSM8KDataset``.

    Overrides the loading hooks and item access so that test problems come
    from ``HuggingFaceH4/MATH-500`` and few-shot examples come from the
    ``algebra`` configuration of ``EleutherAI/hendrycks_math``. Everything
    else (``create_prompt``, ``self.subsample``, ``self.num_examples``) is
    used here but not defined here; presumably it comes from the base class.
    """

    def __init__(
        self,
        tokenizer,
        num_examples=0,
        add_reasoning=True,
        system_prompt=MATH500_SYSTEM_PROMPT,
        subsample=-1,
    ):
        """Forward every argument, in order, to ``GSM8KDataset.__init__``.

        This override exists so that ``system_prompt`` defaults to
        ``MATH500_SYSTEM_PROMPT``; the meaning of each argument is defined
        by the base class.
        """
        super().__init__(
            tokenizer, num_examples, add_reasoning, system_prompt, subsample
        )

    def load_test_dataset(self):
        """Load the ``test`` split of MATH-500 into ``self.dataset``."""
        self.dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")

    def load_few_shot_examples(self):
        """Return ``self.num_examples`` randomly chosen few-shot examples.

        Examples are drawn without replacement from the ``train`` split of
        the ``algebra`` configuration of ``EleutherAI/hendrycks_math``.

        Returns:
            A list of ``{"question": ..., "answer": ...}`` dicts, filled
            from each sampled row's ``problem`` and ``solution`` fields.

        Raises:
            ValueError: propagated from ``random.sample`` when
                ``self.num_examples`` is negative or larger than the
                training split.
        """
        # Nothing to sample: skip fetching the training split altogether.
        # The result is the same empty list the sampling path would give.
        if self.num_examples == 0:
            return []

        train_data = load_dataset(
            "EleutherAI/hendrycks_math", "algebra", split="train"
        )
        indices = random.sample(range(len(train_data)), self.num_examples)
        # Materialise each sampled row once instead of once per field.
        rows = (train_data[index] for index in indices)
        return [
            {"question": row["problem"], "answer": row["solution"]}
            for row in rows
        ]

    def __getitem__(self, idx):
        """Return ``(prompt, question, answer)`` for position ``idx``.

        ``idx`` indexes ``self.subsample``, whose element (converted with
        ``.item()``) selects the row of ``self.dataset``. The prompt is
        built by ``self.create_prompt`` from the row's ``problem`` field;
        the answer is the row's ``answer`` field.
        """
        # Resolve the row once; both fields are read from the same record.
        row = self.dataset[self.subsample[idx].item()]
        question = row["problem"]
        answer = row["answer"]
        prompt = self.create_prompt(question)
        return prompt, question, answer
