# Rewards

import re
from awdpo.utils import *

def reasoning_reward(prompts, completions, answer, **kwargs) -> list:
    """Score each completion on the content of its ``<reasoning>`` section.

    The first ``<reasoning>...</reasoning>`` pair (tags matched
    case-insensitively, whitespace tolerated inside the tags) is extracted
    and scored additively:

    * +0.25 if the section contains at least 25 whitespace-separated words,
    * +0.5  if it contains at least one transition word or phrase,
    * +0.35 if it contains at least 50 words.

    Args:
        prompts: Unused; accepted to match the reward-function signature.
        completions: Iterable of completion strings.
        answer: Unused; accepted to match the reward-function signature.
        **kwargs: Ignored.

    Returns:
        One float per completion; 0.0 when no reasoning section is found.
    """
    transition_words = ["first", "next", "then", "because", "wait", "aha", "therefore", "finally", "in summary"]
    section_re = re.compile(r"<\s*reasoning\s*>(.*?)<\s*/\s*reasoning\s*>", re.DOTALL | re.IGNORECASE)
    # Match whole words only: a plain substring test would reward "then"
    # inside "authentic", "aha" inside "Bahamas", "wait" inside "waiter", ...
    transition_re = re.compile(
        r"\b(?:" + "|".join(re.escape(word) for word in transition_words) + r")\b",
        re.IGNORECASE,
    )

    rewards = []
    for comp in completions:
        match = section_re.search(comp)
        if match is None:
            rewards.append(0.0)
            continue

        reasoning_text = match.group(1).strip()
        word_count = len(reasoning_text.split())
        reward = 0.0
        # Base reward: at least 25 words between the <reasoning> tags.
        if word_count >= 25:
            reward += 0.25
        # Transition-word reward (case-insensitive, granted at most once).
        if transition_re.search(reasoning_text):
            reward += 0.5
        # Bonus reward: at least 50 words.
        if word_count >= 50:
            reward += 0.35
        rewards.append(reward)
    return rewards

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def accuracy_reward(prompts, completions, answer, num_generated_samples_to_view=False, q_num=None, **kwargs) -> list:
    """Give 2.0 to each completion whose extracted answer equals the reference.

    The assistant response is taken from each completion with
    ``get_assistant_response`` and reduced to an answer with
    ``extract_xml_answer`` (both presumably provided by ``awdpo.utils``).
    The reference is the text after ``####`` in the matching ``answer`` entry.

    Args:
        prompts: Prompts for the batch; only ``prompts[0]`` is read, and only
            when debug printing is enabled.
        completions: Iterable of completions to score.
        answer: Reference answers, paired positionally with ``completions``.
        num_generated_samples_to_view: When truthy, print the first
            question/answer/response triple for inspection.
        q_num: Label shown in the debug print header.
        **kwargs: Ignored.

    Returns:
        One float per (completion, answer) pair: 2.0 on an exact match after
        stripping whitespace, otherwise 0.0. A reference without a ``####``
        marker scores 0.0 rather than raising.
    """
    # Extract each assistant response once and reuse it for answer extraction.
    assistant_responses = [get_assistant_response(r) for r in completions]
    extracted_responses = [extract_xml_answer(resp) for resp in assistant_responses]

    if num_generated_samples_to_view:
        # The question text is only needed for the debug dump.
        user_question = get_user_prompt(prompts[0])
        print(f"{'='*15} Sample {q_num} {'='*15}\nQuestion:\n{user_question}\n\nAnswer:\n{extract_hash_answer(answer[0])}\n\nResponse:\n{assistant_responses[0]}\n\nExtracted:\n{extracted_responses[0]}\n{'='*18} End {'='*18}\n")

    rewards = []
    for extracted, reference in zip(extracted_responses, answer):
        gold = extract_hash_answer(reference)
        if gold is None:
            # No "####" marker in the reference: nothing to compare against.
            rewards.append(0.0)
        else:
            rewards.append(2.0 if extracted.strip() == gold.strip() else 0.0)
    return rewards

def soft_format_reward(completions, **kwargs) -> list:
    """Reward completions that contain the two tag pairs in the right order.

    A completion scores 0.5 when a ``<reasoning>...</reasoning>`` section is
    followed (optionally after whitespace) by an ``<answer>...</answer>``
    section anywhere in the text; tag contents may span lines. Every other
    completion scores 0.0.
    """
    ordered_tags = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)
    scores = []
    for text in completions:
        found = ordered_tags.search(text) is not None
        scores.append(0.5 if found else 0.0)
    return scores

def strict_format_reward(completions, **kwargs) -> list:
    """Reward completions that follow the exact newline-delimited layout.

    A completion scores 1.0 when it contains, anywhere in its text::

        <reasoning>
        ...
        </reasoning>
        <answer>
        ...
        </answer>

    with each tag on its own line. Every other completion scores 0.0.
    """
    pattern = r"<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>"
    # re.DOTALL lets the section bodies span several lines; without it any
    # multi-line reasoning could never match (soft_format_reward already
    # uses the same flag).
    strict_re = re.compile(pattern, re.DOTALL)
    return [1.0 if strict_re.search(comp) else 0.0 for comp in completions]

def xmlcount_reward(prompts, completions, answer, **kwargs) -> list:
    """Return half of ``count_xml(completion)`` for every completion.

    ``prompts`` and ``answer`` are accepted for signature compatibility and
    are not used. ``count_xml`` is not defined in this file (presumably it
    comes from ``awdpo.utils``).
    """
    scores = []
    for text in completions:
        scores.append(count_xml(text) * 0.5)
    return scores

def int_reward(completions, **kwargs) -> list:
    """Give 0.5 when the stripped assistant response consists only of digits.

    The response text is obtained with ``get_assistant_response``; any other
    response (including an empty one) scores 0.0.
    """
    # NOTE(review): this tests the whole assistant response, not an extracted
    # <answer> value -- confirm that is the intended target.
    scores = []
    for comp in completions:
        reply = get_assistant_response(comp).strip()
        scores.append(0.5 if reply.isdigit() else 0.0)
    return scores

def proper_termination_reward(completions, **kwargs) -> list:
    """Reward completions that end with a single ``<|im_end|>`` token.

    After stripping surrounding whitespace, a completion scores 0.5 when
    ``<|im_end|>`` is its final text and appears nowhere else; otherwise it
    scores 0.0.
    """
    end_token = "<|im_end|>"
    rewards = []
    for comp in completions:
        text = comp.strip()
        # Counting occurrences (instead of a ``.+`` regex without DOTALL)
        # also rejects an earlier end token that is followed by a newline.
        terminated = text.endswith(end_token) and text.count(end_token) == 1
        rewards.append(0.5 if terminated else 0.0)

    return rewards

def clean_answer_termination_reward(completions, **kwargs) -> list:
    """Reward completions that stop right after the closing answer tag.

    After stripping surrounding whitespace, a completion scores 0.5 when it
    ends with ``</answer>``, optionally followed by whitespace and a single
    ``<|im_end|>`` token; otherwise it scores 0.0.
    """
    tail_re = re.compile(r"</answer>\s*(?:<\|im_end\|>)?$")
    return [
        0.5 if tail_re.search(comp.strip()) is not None else 0.0
        for comp in completions
    ]

def coherency_reward(completions, min_length=4, threshold=3, max_reward=1.0, **kwargs) -> list:
    """Reward responses that avoid excessive repetition of characters or words.

    For each completion the assistant response (via
    ``get_assistant_response``) is checked in two stages:

    1. If any punctuation character from ``,.!?;:'"-`` occurs five or more
       times in a row, the reward is 0.0.
    2. Otherwise every window of ``min_length`` up to
       ``min(10, n_words // 3) - 1`` consecutive words is counted
       (non-overlapping, case-insensitive). When the most repeated window
       occurs more than ``threshold`` times, the reward is reduced by 10% of
       ``max_reward`` per extra repetition, never going below 0.0.

    Args:
        completions: Iterable of completions to score.
        min_length: Smallest window size, in words, that is checked.
        threshold: Number of repetitions tolerated without penalty.
        max_reward: Reward for a response with no excessive repetition.
        **kwargs: Ignored.

    Returns:
        One float per completion.
    """
    # The back-reference must repeat 4+ more times: 5+ identical characters.
    punct_pattern = r'([,.!?;:\'"-])\1{4,}'
    rewards = []

    for comp in completions:
        response = get_assistant_response(comp)
        reward_value = max_reward

        # Check for repetitive punctuation patterns first.
        if re.search(punct_pattern, response):
            reward_value = 0.0  # No reward for excessive punctuation repetition
        else:
            # Continue with the word/phrase repetition check.
            words = response.lower().split()

            if len(words) >= min_length * threshold:
                # Search in whitespace-normalised text so that a phrase is
                # still found when its words are separated by newlines or
                # multiple spaces in the raw response.
                normalized = ' '.join(words)
                max_repetitions = 0
                seen = set()
                for seq_length in range(min_length, min(10, len(words) // 3)):
                    for i in range(len(words) - seq_length + 1):
                        sequence = ' '.join(words[i:i + seq_length])
                        if sequence in seen:
                            continue  # already counted this phrase
                        seen.add(sequence)
                        # str.count tallies non-overlapping occurrences.
                        max_repetitions = max(max_repetitions, normalized.count(sequence))

                if max_repetitions > threshold:
                    reward_value = max(0.0, max_reward * (1.0 - (max_repetitions - threshold) / 10.0))

        rewards.append(reward_value)

    return rewards