import re
import torch
from datasets import load_dataset, Dataset, concatenate_datasets
import random
import math
import os
import numpy as np
import sys


# Parent of the directory that contains this file; used as the base for
# repo-relative data paths (see create_sft_data).
PATH_TO_REPO = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))





# Few-shot exemplars made of short arithmetic word problems.
# Each entry is a dict with three string fields:
#   "question"     - the problem statement,
#   "cot_answer"   - a worked, step-by-step solution,
#   "short_answer" - the final answer only.
EXAMPLARS = [
    {
        "question": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
        "cot_answer": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. So the answer is 6.",
        "short_answer": "6"
    },
    {
        "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
        "cot_answer": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
        "short_answer": "5"
    },
    {
        "question": "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
        "cot_answer": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.",
        "short_answer": "39"
    },
    {
        "question": "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
        "cot_answer": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8.",
        "short_answer": "8"
    },
    {
        "question": "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
        "cot_answer": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9.",
        "short_answer": "9"
    },
    {
        "question": "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
        "cot_answer": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29.",
        "short_answer": "29"
    },
    {
        "question": "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
        "cot_answer": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls.",
        "short_answer": "33"
    },
    {
        "question": "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
        "cot_answer": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8.",
        "short_answer": "8"
    },
    {
        "question": "What is the largest single-digit prime number?",
        "cot_answer": "9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
        "short_answer": "7"
    }
]

# Few-shot exemplars for competition-style math problems written in LaTeX.
# Each entry has "question", "cot_answer" (worked solution) and
# "short_answer" (final answer only).
#
# Every string is a raw literal on purpose: in a normal literal the LaTeX
# commands \text, \frac, \boxed, \binom and \right would be read as the escape
# sequences \t, \f, \b and \r and silently turned into control characters.
EXAMPLARS_DEEPSCALER = [
    {
        "question": r"Alicia earns 20 dollars per hour, of which $1.45\%$ is deducted to pay local taxes. How many cents per hour of Alicia's wages are used to pay local taxes?",
        "cot_answer": r"1. **Convert Alicia's hourly wage to cents**: Alicia earns $20$ dollars per hour. Since there are $100$ cents in a dollar, her hourly wage in cents is: \[ 20 \text{ dollars} \times 100 \text{ cents/dollar} = 2000 \text{ cents} \]  2. **Calculate the tax deduction in cents**: The local tax rate is $1.45\%$. To find the amount deducted for taxes in cents, we calculate $1.45\%$ of $2000$ cents: \[ 1.45\% \text{ of } 2000 \text{ cents} = 0.0145 \times 2000 = 29 \text{ cents} \]  3. **Conclusion**: Therefore, $29$ cents per hour of Alicia's wages are used to pay local taxes.  \[ \boxed{\mathrm{(E)}\ 29} \]",
        "short_answer": r"29"
    },
    {
        "question": r"A number $x$ is $2$ more than the product of its reciprocal and its additive inverse. In which interval does the number lie?",
        "cot_answer": r"Let's analyze the given problem step by step.  1. **Understanding the terms:** - The **reciprocal** of a number $x$ is $\frac{1}{x}$. - The **additive inverse** of a number $x$ is $-x$.  2. **Setting up the equation:** The problem states that $x$ is $2$ more than the product of its reciprocal and its additive inverse. This can be mathematically expressed as: \[ x = \left(\frac{1}{x}\right) \cdot (-x) + 2 \]  3. **Simplifying the equation:** \[ x = \left(\frac{1}{x}\right) \cdot (-x) + 2 = -1 + 2 = 1 \] However, there is a mistake in the simplification above. The correct simplification should be: \[ x = \left(\frac{1}{x}\right) \cdot (-x) + 2 = -1 + 2 \] This simplification is incorrect because the product $\left(\frac{1}{x}\right) \cdot (-x)$ simplifies to $-1$, not $-1 + 2$. The correct equation should be: \[ x = -1 + 2 = 1 \]  4. **Solving the equation:** The equation simplifies directly to $x = 1$. This means that $x$ equals $1$, which is a specific value, not an interval.  5. **Determining the interval:** Since $x = 1$, we need to find which interval option includes the number $1$. Looking at the options: - $\textbf{(A)}\ -4\le x\le -2$ - $\textbf{(B)}\ -2 < x\le 0$ - $\textbf{(C)}\ 0 < x \le 2$ - $\textbf{(D)}\ 2 < x\le 4$ - $\textbf{(E)}\ 4 < x\le 6$  The number $1$ falls within the interval $0 < x \le 2$.  Therefore, the correct answer is $\boxed{\textbf{(C)}\ 0 < x \le 2}$.",
        "short_answer": r"0 < x \le 2"
    },
    {
        "question": r"When Cheenu was a boy, he could run $15$ miles in $3$ hours and $30$ minutes. As an old man, he can now walk $10$ miles in $4$ hours. How many minutes longer does it take for him to walk a mile now compared to when he was a boy?",
        "cot_answer": r"1. **Calculate the time Cheenu took per mile as a boy:** - Cheenu ran 15 miles in 3 hours and 30 minutes. - Convert hours to minutes: \(3 \text{ hours} = 3 \times 60 = 180 \text{ minutes}\). - Add the extra 30 minutes: \(180 \text{ minutes} + 30 \text{ minutes} = 210 \text{ minutes}\). - Calculate the time per mile: \(\frac{210 \text{ minutes}}{15 \text{ miles}} = 14 \text{ minutes per mile}\).  2. **Calculate the time Cheenu takes per mile now as an old man:** - Cheenu walks 10 miles in 4 hours. - Convert hours to minutes: \(4 \text{ hours} = 4 \times 60 = 240 \text{ minutes}\). - Calculate the time per mile: \(\frac{240 \text{ minutes}}{10 \text{ miles}} = 24 \text{ minutes per mile}\).  3. **Determine the difference in time per mile:** - Difference: \(24 \text{ minutes per mile} - 14 \text{ minutes per mile} = 10 \text{ minutes per mile}\).  4. **Conclusion:** - It takes Cheenu 10 minutes longer to walk a mile now compared to when he was a boy.  Thus, the answer is \(\boxed{\textbf{(B)}\ 10}\).",
        "short_answer": r"10"
    },
    {
        "question": r"A straight line joins the points $(-1,1)$ and $(3,9)$. Its $x$-intercept is:",
        "cot_answer": r"1. **Identify the coordinates of the given points and calculate the slope ($m$) of the line.** The points given are $(-1,1)$ and $(3,9)$. The slope formula is: \[ m = \frac{\Delta y}{\Delta x} = \frac{y_2 - y_1}{x_2 - x_1} = \frac{9 - 1}{3 - (-1)} = \frac{8}{4} = 2. \]  2. **Write the equation of the line using the point-slope form.** Using the point $(-1,1)$ and the slope $m = 2$, the point-slope form of the line equation is: \[ y - y_1 = m(x - x_1) \implies y - 1 = 2(x + 1). \] Simplifying this, we get: \[ y - 1 = 2x + 2 \implies y = 2x + 3. \]  3. **Find the $x$-intercept of the line.** The $x$-intercept occurs where $y = 0$. Substituting $0$ for $y$ in the line equation: \[ 0 = 2x + 3. \] Solving for $x$, we have: \[ 2x = -3 \implies x = -\frac{3}{2}. \]  4. **Conclude with the correct answer.** The $x$-intercept of the line is $-\frac{3}{2}$. Therefore, the correct choice is: \[ \boxed{\textbf{(A)}\ -\frac{3}{2}} \]",
        "short_answer": r"-\frac{3}{2}"
    },
    {
        "question": r"A rising number, such as $34689$, is a positive integer each digit of which is larger than each of the digits to its left. There are $\binom{9}{5} = 126$ five-digit rising numbers. When these numbers are arranged from smallest to largest, the $97^{\text{th}}$ number in the list does not contain the digit",
        # Two adjacent raw literals are concatenated so the text stays a single
        # line without any embedded line-break character.
        "cot_answer": (
            r"1. **Understanding the Problem**: We need to find the $97^{\text{th}}$ five-digit rising number and identify which digit from the given options it does not contain.  2. **Counting Rising Numbers Starting with '1'**: - A five-digit rising number starting with '1' can be formed by choosing 4 more digits from 2 to 9. - The number of ways to choose 4 digits from 8 remaining digits (2 through 9) is $\binom{8}{4} = 70$. - Therefore, there are 70 five-digit rising numbers starting with '1'.  3. **Finding the First Number Not Starting with '1'**: - The $71^{\text{st}}$ number is the smallest five-digit rising number not starting with '1'. - The smallest digit available after '1' is '2', so the $71^{\text{st}}$ number starts with '2'. - The smallest five-digit rising number starting with '2' is $23456$.  4. **Counting Rising Numbers Starting with '23'**: - A five-digit rising number starting with '23' can be formed by choosing 3 more digits from 4 to 9. - The number of ways to choose 3 digits from 6 remaining digits (4 through 9) is $\binom{6}{3} = 20$. - Therefore, there are 20 five-digit rising numbers starting with '23'. - The $91^{\text{st}}$ number is the first number starting with '23', which is $23456$.  5. **Identifying the $97^{\text{th}}$ Number**: - We need to find the $97^{\text{th}}$ number, which is 6 numbers after the $91^{\text{st}}$ number. - The sequence of numbers starting from $23456$ and adding one more digit each time from the available set {4, 5, 6, 7, 8, 9} while maintaining the rising property gives us: - $23457, 23458, 23459, 23467, 23468, 23469, 23478, 23479, 23489, 23567, 23568, 23569, 23578, 23579, 23589, 23678, 23679, 23689, 23789, 24567, 24568, 24569, 24578, 24579, 24589, 24678$. - The $97^{\text{th}}$ number in this sequence is $24678$.  6. **Checking Which Digit is Missing**: - The number $24678$ contains the digits 2, 4, 6, 7, and 8. - The digit missing from the options given is '5'.  "
            r"Therefore, the $97^{\text{th}}$ rising number does not contain the digit $\boxed{\textbf{(B)} \ 5}$."
        ),
        "short_answer": r"5"
    }

]
# Prompt templates and answer-format strings used by the prompt builders below.

# Instruction asking for the <reasoning>/<answer> XML layout.
SYSTEM_PROMPT2 = """
Respond in the following format:

<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Instruction asking for a step-by-step solution with the result in \boxed{}.
SYSTEM_PROMPT = "Please reason step by step, and put your final answer within \\boxed{}."
# User-turn template for logic puzzles; filled with str.format(question=...).
SYSTEM_PROMPT_KK = """Your task is to solve a logical reasoning problem. You are given set of statements from which
you must logically deduce the identity of a set of characters.
You must infer the identity of each character. At the end of your answer, you must clearly
state the identity of each character by following the format:
<answer>
(1) ...
(2) ...
(3) ...
</answer>
### Question: {question}
### Answer:"""
# Template that reveals the label via str.format(label). It stops at "within":
# the caller appends " \\boxed{}." after formatting, since a literal "{}" here
# would be treated as a second format field.
SYSTEM_PROMPT_WITH_LABEL = "Given the answer is {}, please reason step by step, and put your final answer within"

# Assistant-turn layout in XML form; fields: reasoning, answer.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

# Assistant-turn layout with the answer wrapped in \boxed{...}; the tripled
# braces produce one literal brace pair around the formatted answer.
XML_COT_FORMAT_DEEPSCALER = """\
{reasoning}
\\boxed{{{answer}}}
"""

def create_sft_data(mode):
    """Load the training parquet for ``mode`` and print its first record.

    Args:
        mode: Name of the data variant. Only ``"uniform"`` is supported.

    Raises:
        ValueError: If ``mode`` is not a supported variant.
    """
    if mode != "uniform":
        # Without this guard ``data`` would be unbound at the print below and
        # the call would die with an opaque UnboundLocalError.
        raise ValueError(f"Unsupported mode {mode!r}; expected 'uniform'.")

    data = load_dataset(
        "parquet",
        data_files=f"{PATH_TO_REPO}/UncertainReasoning/data/deepscaler_uniform_train.parquet",
    )["train"]

    print(data[0])

def make_cot_prompt(x, num_shots=0):
    """Build a chat-message list for one problem, with optional few-shot demos.

    The list starts with ``SYSTEM_PROMPT`` as the system message, followed by
    ``num_shots`` randomly ordered (user, assistant) pairs drawn from
    ``EXAMPLARS``, and ends with ``x['problem']`` as the final user message.

    Args:
        x: Mapping with a ``'problem'`` entry holding the question text.
        num_shots: Number of exemplars to prepend; must not exceed
            ``len(EXAMPLARS)``.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts.
    """
    eids = list(range(len(EXAMPLARS)))
    random.shuffle(eids)
    prompts = [{'role': 'system', 'content': SYSTEM_PROMPT}]
    for i in range(num_shots):
        example = EXAMPLARS[eids[i]]
        # Exemplars store their text under "question"; they have no "problem"
        # key, so reading "problem" here would raise KeyError.
        prompts.append({'role': 'user', 'content': example["question"]})
        prompts.append({'role': 'assistant', 'content': XML_COT_FORMAT_DEEPSCALER.format(
               reasoning=example["cot_answer"],
               answer=example["short_answer"]
            )})
    prompts.append({'role': 'user', 'content': x['problem']})
    return prompts

def make_cot_prompt_deepscaler(args, x, num_shots=0):
    """Assemble the chat messages for one row that carries a ``prompt`` list.

    Layout: system message (``SYSTEM_PROMPT``), then ``args.num_shots``
    shuffled exemplar pairs from ``EXAMPLARS_DEEPSCALER``, then the row's
    first prompt message with ``SYSTEM_PROMPT`` appended on a new line.

    Note: the ``num_shots`` parameter is not read; ``args.num_shots`` decides
    how many exemplars are used.
    """
    order = list(range(len(EXAMPLARS_DEEPSCALER)))
    random.shuffle(order)

    messages = [{'role': 'system', 'content': SYSTEM_PROMPT}]
    for shot in range(args.num_shots):
        exemplar = EXAMPLARS_DEEPSCALER[order[shot]]
        demonstration = XML_COT_FORMAT_DEEPSCALER.format(
            reasoning=exemplar["cot_answer"],
            answer=exemplar["short_answer"],
        )
        messages += [
            {'role': 'user', 'content': exemplar["question"]},
            {'role': 'assistant', 'content': demonstration},
        ]

    question = x['prompt'][0]['content'].strip()
    messages.append({'role': 'user', 'content': question + "\n" + SYSTEM_PROMPT})
    return messages

def make_cot_prompt_limr(args, x, num_shots=0):
    """Assemble the chat messages for one row that carries a ``problem`` string.

    Layout: system message (``SYSTEM_PROMPT``), then ``args.num_shots``
    shuffled exemplar pairs from ``EXAMPLARS_DEEPSCALER``, then the stripped
    problem text with ``SYSTEM_PROMPT`` appended on a new line.

    Note: the ``num_shots`` parameter is not read; ``args.num_shots`` decides
    how many exemplars are used.
    """
    order = list(range(len(EXAMPLARS_DEEPSCALER)))
    random.shuffle(order)

    messages = [{'role': 'system', 'content': SYSTEM_PROMPT}]
    for shot in range(args.num_shots):
        exemplar = EXAMPLARS_DEEPSCALER[order[shot]]
        demonstration = XML_COT_FORMAT_DEEPSCALER.format(
            reasoning=exemplar["cot_answer"],
            answer=exemplar["short_answer"],
        )
        messages += [
            {'role': 'user', 'content': exemplar["question"]},
            {'role': 'assistant', 'content': demonstration},
        ]

    question = x['problem'].strip()
    messages.append({'role': 'user', 'content': question + "\n" + SYSTEM_PROMPT})
    return messages

def make_cot_prompt_kk(args, x, num_shots=0):
    """Assemble the chat messages for one row that carries a ``quiz`` string.

    Layout: system message (``SYSTEM_PROMPT``), then ``args.num_shots``
    shuffled exemplar pairs from ``EXAMPLARS_DEEPSCALER``, then
    ``SYSTEM_PROMPT_KK`` filled with the quiz text as the final user message.

    Note: the ``num_shots`` parameter is not read; ``args.num_shots`` decides
    how many exemplars are used.
    """
    order = list(range(len(EXAMPLARS_DEEPSCALER)))
    random.shuffle(order)

    messages = [{'role': 'system', 'content': SYSTEM_PROMPT}]
    for shot in range(args.num_shots):
        exemplar = EXAMPLARS_DEEPSCALER[order[shot]]
        demonstration = XML_COT_FORMAT_DEEPSCALER.format(
            reasoning=exemplar["cot_answer"],
            answer=exemplar["short_answer"],
        )
        messages += [
            {'role': 'user', 'content': exemplar["question"]},
            {'role': 'assistant', 'content': demonstration},
        ]

    user_turn = SYSTEM_PROMPT_KK.format(question=x['quiz'])
    messages.append({'role': 'user', 'content': user_turn})
    return messages

def make_cot_prompt_livecodebench(args, x, num_shots=0):
    """Build a two-message prompt (system + user) from ``x['question_content']``.

    No exemplars are inserted; ``args`` and ``num_shots`` are not read.
    """
    # The shuffled index list is never consumed, but the shuffle still
    # advances the global ``random`` state, so it is kept as-is.
    unused_order = list(range(len(EXAMPLARS_DEEPSCALER)))
    random.shuffle(unused_order)

    # NOTE(review): the user turn is built from SYSTEM_PROMPT_KK, the logic
    # puzzle template - confirm that is intended for this dataset.
    user_turn = SYSTEM_PROMPT_KK.format(question=x['question_content'])
    return [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': user_turn},
    ]


def extract_xml_answer(text: str) -> str:
    """Return the stripped text between the last ``<answer>`` and the next ``</answer>``.

    A missing opening tag means the search starts at the beginning of the
    text; a missing closing tag means it runs to the end.
    """
    after_open = text.rpartition("<answer>")[2]
    before_close = after_open.partition("</answer>")[0]
    return before_close.strip()

def extract_xml_reasoning(text: str) -> str:
    """Return the stripped text between the last ``<reasoning>`` and the next ``</reasoning>``.

    A missing opening tag means the search starts at the beginning of the
    text; a missing closing tag means it runs to the end.
    """
    after_open = text.rpartition("<reasoning>")[2]
    before_close = after_open.partition("</reasoning>")[0]
    return before_close.strip()

def extract_hash_answer(text: str):
    """Return the stripped segment following the first ``####`` marker.

    If the marker occurs more than once, only the text between the first and
    second occurrence is returned. Returns ``None`` when there is no marker.
    """
    segments = text.split("####")
    return segments[1].strip() if len(segments) > 1 else None

# Few-shot exemplars are controlled by `num_shots` (0 means zero-shot prompting).
def get_gsm8k_questions(split = "train", num_shots=0, real=False) -> Dataset:
    """Load the GSM8K difficulty dataset and attach ``prompt`` / ``answer`` columns.

    ``prompt`` comes from ``make_cot_prompt`` and ``answer`` is the text after
    ``####`` in each row's ``ground_truth``.

    Note: ``split`` and ``real`` do not change the result; the ``'train'``
    split of the hub dataset is always the one loaded, and both values of
    ``split`` apply the same mapping.
    """
    def _to_example(row):
        return {
            'prompt': make_cot_prompt(row, num_shots=num_shots),
            'answer': extract_hash_answer(row['ground_truth']),
        }

    dataset = load_dataset("lime-nlp/GSM8K_Difficulty", "Difficulty Score")['train']
    return dataset.map(_to_example)  # type: ignore

def process_deepscaler_data(example):
    """Remove the trailing "Let's think step by step ..." instruction from a row.

    Edits ``example['prompt'][0]['content']`` in place (also stripping
    trailing whitespace) and returns the same ``example`` object.
    """
    instruction_suffix = r"Let's think step by step and output the final answer within \\boxed\{\}\.?\s*$"
    first_message = example['prompt'][0]
    first_message['content'] = re.sub(instruction_suffix, "", first_message['content']).rstrip()
    return example

def process_limr_data(example):
    """Remove the trailing "Let's think step by step ..." instruction from ``example['problem']``.

    The field is rewritten in place (also stripping trailing whitespace) and
    the same ``example`` object is returned.
    """
    instruction_suffix = r"Let's think step by step and output the final answer within \\boxed\{\}\.?\s*$"
    example['problem'] = re.sub(instruction_suffix, "", example['problem']).rstrip()
    return example

def process_kk_data(example):
    """Clean ``example['quiz']`` and reset ``example['extra_info']``.

    The trailing "Let's think step by step ..." instruction and trailing
    whitespace are removed from the quiz text, ``extra_info`` is set to an
    empty dict, and the same ``example`` object is returned.
    """
    instruction_suffix = r"Let's think step by step and output the final answer within \\boxed\{\}\.?\s*$"
    example['quiz'] = re.sub(instruction_suffix, "", example['quiz']).rstrip()
    # Start every row with an empty "extra_info" mapping.
    example['extra_info'] = {}
    return example

def process_livecodebench_data(example):
    """Clean ``example['question_content']`` and reset ``example['extra_info']``.

    The trailing "Let's think step by step ..." instruction and trailing
    whitespace are removed from the question text, ``extra_info`` is set to an
    empty dict, and the same ``example`` object is returned.
    """
    instruction_suffix = r"Let's think step by step and output the final answer within \\boxed\{\}\.?\s*$"
    example['question_content'] = re.sub(instruction_suffix, "", example['question_content']).rstrip()
    # Start every row with an empty "extra_info" mapping.
    example['extra_info'] = {}
    return example



def get_deepscaler_questions(args, split="train", mode="uniform", num_shots=0) -> Dataset:
    """Load the parquet at ``args.dataset_path`` and build ``prompt`` / ``answer`` columns.

    Each row first has the trailing instruction removed from its prompt, then
    gets a chat prompt from ``make_cot_prompt_deepscaler`` and an answer taken
    from ``row['reward_model']['ground_truth']``.

    Note: ``mode`` and ``num_shots`` are not read; ``args.num_shots`` is used.
    """
    def _attach_prompt(row):
        return {
            'prompt': make_cot_prompt_deepscaler(args, row, num_shots=args.num_shots),
            'answer': row['reward_model']['ground_truth'],
        }

    raw = load_dataset('parquet', data_files=f"{args.dataset_path}")[split]  # type: ignore
    cleaned = raw.map(process_deepscaler_data)
    return cleaned.map(_attach_prompt)  # type: ignore

def get_limr_questions(args, split="train", mode="uniform", num_shots=0) -> Dataset:
    """Load ``nlile/hendrycks-MATH-benchmark`` and build ``prompt`` / ``answer`` columns.

    Each row's ``problem`` is cleaned by ``process_limr_data``; the prompt
    comes from ``make_cot_prompt_limr`` and the answer from ``row['answer']``.

    Note: ``mode`` and ``num_shots`` are not read; ``args.num_shots`` is used.
    """
    def _attach_prompt(row):
        return {
            'prompt': make_cot_prompt_limr(args, row, num_shots=args.num_shots),
            'answer': row['answer'],
        }

    raw = load_dataset("nlile/hendrycks-MATH-benchmark")[split]
    cleaned = raw.map(process_limr_data)
    return cleaned.map(_attach_prompt)  # type: ignore

def get_kk_questions(args, split="train", num_shots=0) -> Dataset:
    """Load every split of the knights-and-knaves ``train`` config as one dataset.

    All splits are concatenated, each row is cleaned by ``process_kk_data``,
    and ``prompt`` / ``answer`` columns are added (the answer is
    ``row['solution_text_format']``).

    Note: ``split`` and ``num_shots`` are not read; ``args.num_shots`` is used.
    """
    def _attach_prompt(row):
        return {
            'prompt': make_cot_prompt_kk(args, row, num_shots=args.num_shots),
            'answer': row['solution_text_format'],
        }

    splits = load_dataset("K-and-K/knights-and-knaves", "train")
    merged = concatenate_datasets([splits[name] for name in splits.keys()])
    cleaned = merged.map(process_kk_data)
    return cleaned.map(_attach_prompt)  # type: ignore

def get_livecodebench_questions(args, split="train", num_shots=0) -> Dataset:
    """Load the LiveCodeBench ``release_v5`` test split (work in progress).

    As written, this function prints the field names of the first record and
    then terminates the interpreter via ``quit()``; it never returns.
    The ``split`` and ``num_shots`` parameters are not read.
    """
    data = load_dataset("livecodebench/code_generation_lite", version_tag="release_v5", trust_remote_code=True)["test"]
    # NOTE(review): debugging leftovers - quit() raises SystemExit, so any
    # caller exits here. Remove once the mapping below is finished.
    print(data[0].keys())
    quit()
    # Everything below is unreachable.
    # NOTE(review): this tail looks copied from get_kk_questions - `data` is
    # already the "test" split, process_kk_data reads the 'quiz' field, and
    # 'solution_text_format' is presumably not a LiveCodeBench field. Confirm
    # the right preprocessing function and answer field before enabling.
    data = concatenate_datasets([data[k] for k in data.keys()])
    data = data.map(process_kk_data)
    data = data.map(lambda x: { # type: ignore
        'prompt': make_cot_prompt_livecodebench(args, x, num_shots=args.num_shots),
        'answer': x['solution_text_format']
    }) # type: ignore
    return data


def process_deepscaler_data_distill(example):
    """Remove the trailing "Let's think step by step ..." instruction from a row.

    Edits ``example['prompt'][0]['content']`` in place (also stripping
    trailing whitespace) and returns the same ``example`` object.
    """
    instruction_suffix = r"Let's think step by step and output the final answer within \\boxed\{\}\.?\s*$"
    first_message = example['prompt'][0]
    first_message['content'] = re.sub(instruction_suffix, "", first_message['content']).rstrip()
    return example

def get_deepscaler_questions_distill(args, split="train", mode="uniform", num_shots=0) -> Dataset:
    """Load the parquet at ``args.dataset_path`` and build label-revealing prompts.

    Each row is cleaned by ``process_deepscaler_data_distill``; the prompt
    comes from ``make_cot_prompt_deepscaler_distill``, which receives the
    ground truth as its label, and the answer is that same ground truth.

    Note: ``mode`` and ``num_shots`` are not read; ``args.num_shots`` is used.
    """
    def _attach_prompt(row):
        return {
            'prompt': make_cot_prompt_deepscaler_distill(
                args, row, row['reward_model']['ground_truth'], num_shots=args.num_shots
            ),
            'answer': row['reward_model']['ground_truth'],
        }

    raw = load_dataset('parquet', data_files=f"{args.dataset_path}")[split]  # type: ignore
    cleaned = raw.map(process_deepscaler_data_distill)
    return cleaned.map(_attach_prompt)  # type: ignore

def make_cot_prompt_deepscaler_distill(args, x, label, num_shots=0):
    """Assemble chat messages whose final user turn reveals the known answer.

    Layout: system message (``SYSTEM_PROMPT``), then ``args.num_shots``
    shuffled exemplar pairs from ``EXAMPLARS_DEEPSCALER``, then the row's
    first prompt message followed by ``SYSTEM_PROMPT_WITH_LABEL`` filled with
    ``label`` and the closing `` \\boxed{}.`` text.

    Note: the ``num_shots`` parameter is not read; ``args.num_shots`` decides
    how many exemplars are used.
    """
    order = list(range(len(EXAMPLARS_DEEPSCALER)))
    random.shuffle(order)

    messages = [{'role': 'system', 'content': SYSTEM_PROMPT}]
    for shot in range(args.num_shots):
        exemplar = EXAMPLARS_DEEPSCALER[order[shot]]
        demonstration = XML_COT_FORMAT_DEEPSCALER.format(
            reasoning=exemplar["cot_answer"],
            answer=exemplar["short_answer"],
        )
        messages += [
            {'role': 'user', 'content': exemplar["question"]},
            {'role': 'assistant', 'content': demonstration},
        ]

    question = x['prompt'][0]['content'].strip()
    # The boxed suffix is appended after formatting, not placed in the template.
    instruction = SYSTEM_PROMPT_WITH_LABEL.format(label) + " \\boxed{}."
    messages.append({'role': 'user', 'content': question + "\n" + instruction})
    return messages

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Score each completion by exact match of its ``<answer>`` text with the label.

    Args:
        prompts: Prompt message lists (not used for scoring).
        completions: One message list per sample; the first message's
            ``'content'`` is the model response.
        answer: Reference answers, aligned with ``completions``.
        **kwargs: Ignored.

    Returns:
        ``1.0`` where the extracted answer equals the reference, else ``0.0``.
    """
    # No debug print/quit() here: a quit() call would terminate the whole
    # process on the first reward evaluation.
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [1.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def extract_boxed_answer(text: str) -> str:
    r"""Return the stripped content of the first ``\boxed{...}`` in ``text``.

    Nested braces are kept. If there is no ``\boxed{``, an empty string is
    returned; if the opening brace is never closed, everything after it is
    returned.
    """
    marker = r'\boxed{'
    begin = text.find(marker)
    if begin == -1:
        return ""
    begin += len(marker)

    depth = 1
    for pos in range(begin, len(text)):
        char = text[pos]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                # Matching close brace found; it is not part of the content.
                return text[begin:pos].strip()

    return text[begin:].strip()


def correctness_reward_bypass_template_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Score completions by comparing the last number in the text with the label.

    No answer template is required: the last number-like token of each
    response is taken, ``.`` and ``,`` are removed from it (so ``1,000``
    becomes ``1000`` and a sentence-final ``42.`` becomes ``42``), and the
    result is compared to the reference string.

    Args:
        prompts: Prompt message lists (not used for scoring).
        completions: One message list per sample; the first message's
            ``'content'`` is the model response.
        answer: Reference answers, aligned with ``completions``.
        **kwargs: Ignored.

    Returns:
        ``1.0`` where the extracted number equals the reference, else ``0.0``.
    """
    # A number counts when followed by whitespace OR by the end of the text,
    # so a final answer with nothing after it is still found.
    number_pattern = re.compile(r'\d+[.,]?\d*(?:\s|$)')

    extracted_responses = []
    for completion in completions:
        response = completion[0]['content']
        all_numbers = number_pattern.findall(response)
        if all_numbers:
            # strip() drops whatever whitespace the pattern captured (space,
            # tab or newline); otherwise "42 " could never equal "42".
            cleaned = all_numbers[-1].replace('.', '').replace(',', '').strip()
            extracted_responses.append(cleaned)
        else:
            # Sentinel that can never equal a real answer -> zero reward.
            extracted_responses.append("-1uiekc7")
            print(f"all_numbers = {all_numbers}, check this response : {response}")

    return [1.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

    
def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to each completion whose ``<answer>`` text consists only of digits."""
    rewards = []
    for completion in completions:
        extracted = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if extracted.isdigit() else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions that follow the exact XML layout, else 0.0.

    The whole response must be ``<reasoning>``, a newline, the reasoning, a
    newline, ``</reasoning>``, a newline, ``<answer>``, a newline, the answer,
    a newline, ``</answer>`` and a final newline.

    Args:
        completions: One message list per sample; the first message's
            ``"content"`` is the model response.
        **kwargs: Ignored.
    """
    # re.DOTALL lets "." span newlines; without it any multi-line reasoning
    # or answer would fail to match and receive no reward.
    pattern = re.compile(
        r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$",
        re.DOTALL,
    )
    return [
        0.5 if pattern.match(completion[0]["content"]) else 0.0
        for completion in completions
    ]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions that start with a reasoning block followed by an answer block.

    The response must begin with ``<reasoning>...</reasoning>``, optionally
    followed by whitespace, then ``<answer>...</answer>``; text after the
    closing answer tag is allowed.

    Args:
        completions: One message list per sample; the first message's
            ``"content"`` is the model response.
        **kwargs: Ignored.
    """
    # re.DOTALL lets "." span newlines; without it any multi-line reasoning
    # or answer would fail to match and receive no reward.
    pattern = re.compile(
        r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>",
        re.DOTALL,
    )
    return [
        0.5 if pattern.match(completion[0]["content"]) else 0.0
        for completion in completions
    ]

def count_xml(text) -> float:
    """Give partial credit (0.125 each) for four XML markers that occur exactly once.

    The two answer markers also subtract 0.001 per character found after the
    closing answer tag, which penalizes trailing text.
    """
    score = 0.0

    # Reasoning markers: plain credit, no penalty.
    for marker in ("<reasoning>\n", "\n</reasoning>\n"):
        if text.count(marker) == 1:
            score += 0.125

    if text.count("\n<answer>\n") == 1:
        trailing = text.split("\n</answer>\n")[-1]
        score += 0.125
        score -= len(trailing)*0.001

    if text.count("\n</answer>") == 1:
        trailing = text.split("\n</answer>")[-1]
        score += 0.125
        # One character (presumably the final newline) is not penalized.
        score -= (len(trailing) - 1)*0.001

    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply ``count_xml`` to the first message's content of every completion."""
    return [count_xml(completion[0]["content"]) for completion in completions]

def get_cosine_scaled_reward(
    min_value_wrong: float = -1.0,
    max_value_wrong: float = -0.5,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    max_len: int = 1000,
):
    """Create a reward function that scales correctness rewards by response length.

    Args:
        min_value_wrong: Minimum reward for wrong answers.
        max_value_wrong: Maximum reward for wrong answers.
        min_value_correct: Minimum reward for correct answers.
        max_value_correct: Maximum reward for correct answers.
        max_len: Length (in characters of the response) used for scaling.

    Returns:
        A callable ``cosine_scaled_reward(prompts, completions, answer, **kwargs)``.
    """
    def cosine_scaled_reward(prompts, completions, answer, **kwargs):
        """Return one length-scaled reward per completion.

        A response is correct when the content of its first ``\\boxed{...}``
        equals the reference answer. Shorter correct responses earn more than
        longer ones; shorter incorrect responses are penalized more than
        longer ones.

        Args:
            prompts: Prompt message lists; only the last message of the first
                prompt is read, and its value is not used.
            completions: One message list per sample.
            answer: Reference answers, aligned with ``completions``.
        """
        texts = [completion[0]['content'] for completion in completions]
        question = prompts[0][-1]['content']  # read but not used below
        predictions = [extract_boxed_answer(text) for text in texts]

        def _score(text, predicted, truth):
            # Cosine schedule over the response length in characters.
            progress = len(text) / max_len
            cosine = math.cos(progress * math.pi)

            if predicted == truth:
                low, high = min_value_correct, max_value_correct
            else:
                # Bounds are swapped for incorrect answers.
                low, high = max_value_wrong, min_value_wrong

            return float(low + 0.5 * (high - low) * (1.0 + cosine))

        return [
            _score(text, predicted, truth)
            for text, predicted, truth in zip(texts, predictions, answer)
        ]

    return cosine_scaled_reward

# Extract final answers for Deepscaler dataset - https://github.com/limenlp/verl/blob/main/verl/utils/reward_score/math.py
def compute_score(solution_str, ground_truth) -> float:
    """Return 1.0 if the last boxed answer in ``solution_str`` matches ``ground_truth``.

    Any exception raised while extracting or comparing is printed and the
    score falls back to 0.0.
    """
    try:
        boxed = last_boxed_only_string(solution_str)
        if boxed is not None and is_equiv(remove_boxed(boxed), ground_truth):
            return 1.
    except Exception as e:
        print(e)

    return 0.


# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
    """Compare two answer strings after normalization with ``strip_string``.

    Two ``None`` values count as equal (with a printed warning); exactly one
    ``None`` is unequal. If normalization raises, the raw strings are compared.
    """
    if str1 is None or str2 is None:
        if str1 is None and str2 is None:
            print("WARNING: Both None")
            return True
        return False

    try:
        normalized_1 = strip_string(str1)
        normalized_2 = strip_string(str2)
        if verbose:
            print(normalized_1, normalized_2)
        return normalized_1 == normalized_2
    except Exception:
        # Normalization failed; fall back to comparing the inputs unchanged.
        return str1 == str2


def remove_boxed(s):
    """Strip the ``\\boxed`` wrapper from a string produced by ``last_boxed_only_string``.

    Handles both ``\\boxed <value>`` and ``\\boxed{<value>}``.

    Raises:
        AssertionError: If ``s`` does not start with the wrapper, or the
            braced form does not end with ``}``.
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        assert s.startswith(spaced_prefix)
        return s[len(spaced_prefix):]

    braced_prefix = "\\boxed{"

    assert s.startswith(braced_prefix)
    assert s[-1] == "}"

    return s[len(braced_prefix):-1]


def last_boxed_only_string(string):
    """Return the last ``\\boxed...`` (or ``\\fbox...``) expression in ``string``.

    If the space form ``\\boxed <value>`` appears anywhere, the text after its
    last occurrence up to the next ``$`` is returned. Otherwise the braced
    expression starting at the last ``\\boxed`` (or, failing that, the last
    ``\\fbox``) is returned with its braces. Returns ``None`` when nothing is
    found or the braces never balance.
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in string:
        tail = string.split(spaced_prefix)[-1]
        return spaced_prefix + tail.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                # Outermost brace closed: include it in the result.
                return string[start:pos + 1]

    return None


def fix_fracs(string):
    """Add the braces that shorthand ``\\frac`` forms leave out.

    ``\\frac12`` becomes ``\\frac{1}{2}`` and ``\\frac1{72}`` becomes
    ``\\frac{1}{72}``; already-braced fractions are left alone. If a
    ``\\frac`` is followed by a single character only, the input is returned
    unchanged.
    """
    head, *tails = string.split("\\frac")
    pieces = [head]
    for tail in tails:
        pieces.append("\\frac")
        if tail[0] == "{":
            # Numerator already braced; keep the rest as written.
            pieces.append(tail)
            continue
        if len(tail) < 2:
            # Not enough characters for numerator and denominator.
            return string
        numerator, second, rest = tail[0], tail[1], tail[2:]
        if second == "{":
            # Only the numerator lacks braces.
            pieces.append("{" + numerator + "}" + second + rest)
        else:
            pieces.append("{" + numerator + "}{" + second + "}" + rest)
    return "".join(pieces)


def fix_a_slash_b(string):
    """Rewrite a plain integer fraction ``a/b`` as ``\\frac{a}{b}``.

    Only strings that are exactly two canonical integers joined by one slash
    are rewritten; anything else (more than one slash, non-integer parts such
    as ``x/2``, leading zeros, embedded spaces) is returned unchanged.

    Args:
        string: Candidate answer string.

    Returns:
        The ``\\frac`` form, or ``string`` itself when it is not a plain
        integer fraction.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a = int(parts[0])
        b = int(parts[1])
    except ValueError:
        # Non-integer operands are a normal "leave unchanged" case; letting
        # the error escape would abort the whole normalization of the answer.
        return string
    # Round-trip check rejects inputs like "01/2" or " 1/2" that int() accepts.
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"


def remove_right_units(string):
    """Drop a trailing unit written as ``\\text{ ...}`` and everything after it.

    Raises:
        AssertionError: If ``\\text{ `` occurs more than once.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    unit_marker = "\\text{ "
    occurrences = string.count(unit_marker)
    if occurrences == 0:
        return string
    assert occurrences == 1
    return string.split(unit_marker)[0]


def fix_sqrt(string):
    """Brace the single character after an unbraced ``\\sqrt`` (``\\sqrt3`` -> ``\\sqrt{3}``).

    Already-braced roots and strings without ``\\sqrt`` are returned unchanged.
    """
    head, *tails = string.split("\\sqrt")
    rebuilt = [head]
    for tail in tails:
        if tail[0] == "{":
            rebuilt.append("\\sqrt" + tail)
        else:
            # Wrap only the first character; the rest follows untouched.
            rebuilt.append("\\sqrt{" + tail[0] + "}" + tail[1:])
    return "".join(rebuilt)


def strip_string(string):
    """Normalize a LaTeX answer string so equivalent answers compare equal.

    The steps run in a fixed order: remove line breaks, thin-space commands,
    ``\\left``/``\\right``, degree marks, escaped dollar and percent signs and
    right-hand units; normalize ``tfrac``/``dfrac``; add leading zeros to bare
    decimals; drop a short ``x =`` prefix; brace ``\\sqrt`` and ``\\frac``
    arguments; remove spaces; and rewrite ``0.5`` and ``a/b``.
    """
    # Plain textual substitutions, applied in this order.
    substitutions = (
        ("\n", ""),            # linebreaks
        ("\\!", ""),           # inverse spaces
        ("\\\\", "\\"),        # replace \\ with \
        ("tfrac", "frac"),     # tfrac and dfrac become frac
        ("dfrac", "frac"),
        ("\\left", ""),        # \left and \right
        ("\\right", ""),
        ("^{\\circ}", ""),     # degrees
        ("^\\circ", ""),
        ("\\$", ""),           # dollar signs
    )
    for old, new in substitutions:
        string = string.replace(old, new)

    # remove units (on the right)
    string = remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{."
    string = string.replace(" .", " 0.").replace("{.", "{0.")
    # if empty, return empty string
    if not string:
        return string
    # add "0" if "." is the start of the string
    if string.startswith("."):
        string = "0" + string

    # get rid of a short left-hand side such as "k = " or "q = "
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc.
    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # X/Y --> \frac{X}{Y} for simple cases where the model output is X/Y
    return fix_a_slash_b(string)


def get_qwen_embeddings(prompts, tokenizer, model, batch_size=8, device="cuda"):
    """Embed prompts by mean-pooling the model's last hidden layer.

    Prompts are processed in batches; for each batch the last entry of
    ``hidden_states`` is averaged over the tokens selected by the attention
    mask, so padding does not affect the result.

    Args:
        prompts: Sequence of prompt strings.
        tokenizer: Callable invoked as ``tokenizer(batch, return_tensors="pt",
            padding=True, truncation=True)``; its result must support
            ``.to(device)``, ``**`` unpacking and ``["attention_mask"]``.
        model: Callable invoked as ``model(**enc, output_hidden_states=True)``
            whose result exposes ``hidden_states``.
        batch_size: Number of prompts per forward pass.
        device: Device the tokenized batch is moved to. Defaults to
            ``"cuda"``; pass e.g. ``"cpu"`` to run without a GPU. The model
            itself is not moved and is assumed to live on this device.

    Returns:
        A NumPy array with one pooled embedding row per prompt.

    Raises:
        ValueError: From ``np.vstack`` when ``prompts`` is empty.
    """
    embs = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i : i + batch_size]
        enc = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
        # ask the model to return hidden states
        with torch.no_grad():
            out = model(**enc, output_hidden_states=True)
        # out.hidden_states is a tuple (layer0, layer1, ..., layerL)
        last_hidden = out.hidden_states[-1]              # shape: (B, T, D)
        mask = enc["attention_mask"].unsqueeze(-1)       # (B, T, 1)
        # mean-pool over the tokens kept by the mask
        summed = (last_hidden * mask).sum(1)             # (B, D)
        counts = mask.sum(1)                             # (B, 1)
        pooled = summed / counts                         # (B, D)
        embs.append(pooled.cpu().numpy())
    return np.vstack(embs)


def add_difficulty(train_dataset):
    """Return the dataset with an ``extra_info`` column holding per-row difficulty.

    Difficulty is ``100 - solved_percentage`` and is stored as
    ``{"difficulty": value}`` for each row.
    """
    extra_info = []
    for solved in train_dataset['solved_percentage']:
        extra_info.append({"difficulty": 100 - solved})
    return train_dataset.add_column("extra_info", extra_info)