import re


def extract_xml_answer(text: str) -> str:
    """Return the stripped text between the last ``<answer>`` and the following ``</answer>``.

    If ``<answer>`` never occurs, extraction starts at the beginning of
    ``text``; if no ``</answer>`` follows, it runs to the end of ``text``.
    """
    _, _, after_open = text.rpartition("<answer>")
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()

def extract_xml_reasoning(text: str) -> str:
    """Return the stripped text between the last ``<think>`` and the following ``</think>``.

    If ``<think>`` never occurs, extraction starts at the beginning of
    ``text``; if no ``</think>`` follows, it runs to the end of ``text``.
    """
    _, _, after_open = text.rpartition("<think>")
    inner, _, _ = after_open.partition("</think>")
    return inner.strip()

def correctness_reward_func(completions, label, **kwargs) -> list[float]:
    """Reward 2.0 for each completion whose ``<answer>`` text equals its label, else 0.0.

    Args:
        completions: Batch of completions; each is indexed as
            ``completion[0]['content']``.
        label: Expected answers, aligned position-wise with ``completions``
            and compared with ``==`` against the extracted answer string.
        **kwargs: Ignored; accepted for reward-function API compatibility.

    Returns:
        One reward per (completion, label) pair; ``zip`` stops at the shorter input.
    """
    predicted_answers = [
        extract_xml_answer(completion[0]['content']) for completion in completions
    ]
    rewards = []
    for predicted, expected in zip(predicted_answers, label):
        rewards.append(2.0 if predicted == expected else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that are exactly ``<think>...</think><answer>...</answer>``.

    The whole content must start with ``<think>`` and end with ``</answer>``
    (Python's ``$`` also tolerates a single trailing newline), with
    ``</think>`` immediately followed by ``<answer>``. ``re.DOTALL`` lets the
    tag contents span multiple lines; without it, any multi-line reasoning
    would never match.

    Args:
        completions: Batch of completions; each is indexed as
            ``completion[0]["content"]``.
        **kwargs: Ignored; accepted for reward-function API compatibility.

    Returns:
        0.5 for each matching completion, 0.0 otherwise.
    """
    pattern = re.compile(r"^<think>.*?</think><answer>.*?</answer>$", re.DOTALL)
    completion_contents = [completion[0]["content"] for completion in completions]
    return [0.5 if pattern.match(content) else 0.0 for content in completion_contents]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that begin with a ``<think>`` block followed by an ``<answer>`` block.

    More lenient than the strict format check: whitespace (including
    newlines) may separate ``</think>`` and ``<answer>``, and trailing text
    after ``</answer>`` is allowed. ``re.match`` still anchors the match at the
    start of the content. ``re.DOTALL`` lets the tag contents span multiple
    lines; without it, any multi-line reasoning would never match.

    Args:
        completions: Batch of completions; each is indexed as
            ``completion[0]["content"]``.
        **kwargs: Ignored; accepted for reward-function API compatibility.

    Returns:
        0.5 for each matching completion, 0.0 otherwise.
    """
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    completion_contents = [completion[0]["content"] for completion in completions]
    return [0.5 if pattern.match(content) else 0.0 for content in completion_contents]

def contrastive_reasoning_reward_func(completions, processor=None, vllm_client=None, **kwargs) -> list[float]:
    """Score each completion's ``<think>`` reasoning with an LLM judge, normalized to [0, 1].

    The reasoning is pulled out with ``extract_xml_reasoning``, inserted into a
    grading prompt, rendered with ``processor.apply_chat_template`` (as text,
    with the generation prompt appended) and sent to ``vllm_client.generate``
    with greedy decoding. The first integer in each decoded reply is clamped
    to [1, 10] and mapped linearly onto [0, 1]; a reply with no integer scores 0.0.

    Args:
        completions: Batch of completions; each is indexed as
            ``completion[0]["content"]``.
        processor: Object providing ``apply_chat_template`` and ``decode``
            (presumably the judge model's HF processor/tokenizer — confirm).
        vllm_client: Object whose ``generate(prompts, max_tokens=..., temperature=...)``
            returns one sequence of token ids per prompt, in prompt order
            (presumably TRL's ``VLLMClient`` — confirm).
        **kwargs: Ignored; accepted for reward-function API compatibility.

    Returns:
        One float in [0, 1] per completion.

    Raises:
        ValueError: If ``processor`` or ``vllm_client`` is not provided.
    """
    prompt = """Evaluate the student's reasoning for classifying an image based on the following criteria and provide a score from 1 to 10:
- Identification of similarities between the given image and the dataset images of the chosen class.
- Identification of differences between the given image and the dataset images not belonging to the chosen class.
- Identification of relevant details in both the given image and the dataset images.
- The student claims the class only after clearly highlighting the similarities and differences.

The student's reasoning:
{reasoning}

Output only the score (1-10) without formattings and explanations."""
    if processor is None or vllm_client is None:
        raise ValueError(
            "contrastive_reasoning_reward_func requires both `processor` and `vllm_client`."
        )
    if not completions:
        # Nothing to grade; avoid a pointless remote generate call.
        return []

    prompts_text = []
    for completion in completions:
        reasoning = extract_xml_reasoning(completion[0]["content"])
        messages = [{"role": "user", "content": prompt.format(reasoning=reasoning)}]
        # These options belong to the chat template, not to str.format: we need
        # a text prompt (not token ids) that ends with the assistant header so
        # the judge replies instead of continuing the user turn.
        prompts_text.append(
            processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        )

    scores = []
    for token_ids in vllm_client.generate(prompts_text, max_tokens=10, temperature=0.0):
        numbers = re.findall(r"\d+", processor.decode(token_ids, skip_special_tokens=True))
        if not numbers:
            scores.append(0.0)
            continue
        # Clamp so an out-of-range judge reply cannot push the reward outside [0, 1].
        raw_score = min(max(float(numbers[0]), 1.0), 10.0)
        scores.append((raw_score - 1.0) / 9.0)
    return scores
