"""Reward functions for GRPO training."""

import json
import math
import re
from typing import Dict
import difflib

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from loguru import logger
import numpy as np
from open_r1.reward_qa import reward_qa
import traceback

from .utils import is_e2b_available


if is_e2b_available():
    from dotenv import load_dotenv
    from e2b_code_interpreter import Sandbox

    load_dotenv()


def accuracy_reward(completions, solution, **kwargs):
    """Return 1.0 for each completion whose extracted answer verifies against its gold solution.

    Gold solutions are parsed with the default LaTeX extraction. Completions are
    parsed strictly (malformed operators are not repaired) with boxed answers
    tried first. When the gold solution yields nothing parseable the example
    gets a neutral 1.0 so it is effectively skipped; when ``verify`` raises, the
    completion scores 0.0.

    Args:
        completions: Chat-style completions; ``completion[0]["content"]`` holds the text.
        solution: Gold solution strings, aligned with ``completions``.

    Returns:
        A list of float rewards, one per (completion, solution) pair.
    """

    def _score(text, gold):
        gold_parsed = parse(
            gold,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) == 0:
            # Unparseable gold: reward 1 so this example does not penalize the policy.
            print("Failed to parse gold solution: ", gold)
            return 1.0

        strict_latex = LatexExtractionConfig(
            normalization_config=NormalizationConfig(
                nits=False,
                malformed_operators=False,
                basic_latex=True,
                equations=True,
                boxed="all",
                units=True,
            ),
            # Boxed matches are attempted before anything else.
            boxed_match_priority=0,
            try_extract_without_anchor=False,
        )
        answer_parsed = parse(
            text,
            extraction_config=[strict_latex],
            extraction_mode="first_match",
        )
        try:
            return float(verify(answer_parsed, gold_parsed))
        except Exception as exc:
            print(f"verify failed: {exc}, answer: {answer_parsed}, gold: {gold_parsed}")
            return 0.0

    texts = [completion[0]["content"] for completion in completions]
    return [_score(text, gold) for text, gold in zip(texts, solution)]


def match_metric_name(metric: str, sentence: str, strict: bool = True) -> bool:
    """Compare two metric names after normalization.

    Both strings are reduced to CJK ideographs (U+4E00-U+9FA5) and ASCII
    letters/digits, then lower-cased.

    Args:
        metric: The metric name to look for.
        sentence: The text to compare against.
        strict: When True the normalized strings must be equal; otherwise the
            normalized ``metric`` only needs to be a substring of ``sentence``.

    Returns:
        Whether the names match under the chosen mode.
    """

    def _normalize(text):
        return re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9]', '', text).lower()

    haystack = _normalize(sentence)
    needle = _normalize(metric)
    return needle == haystack if strict else needle in haystack

def try_fix_json(json_string: str, special_words=('question', 'answer', 'success', 'reference', ',', ':', '\n', '}', '{')):
    """Heuristically repair common LLM JSON mistakes so ``json.loads`` can succeed.

    Repairs applied, in order:

    * Stray double quotes inside string values are escaped. A quote is treated
      as a structural delimiter and left alone when it is already escaped (an
      odd number of backslashes precedes it), when it directly follows ``': '``,
      or when the text right after it starts with one of ``special_words``
      (known keys or JSON punctuation). Every other quote gets a backslash.
    * Python literals ``True``/``False`` become JSON ``true``/``false``. This is
      a plain substring replacement, so it also affects string contents.
    * A missing comma between a closing quote and an opening quote separated by
      whitespace containing a newline is inserted.

    Args:
        json_string: Raw, possibly malformed JSON text.
        special_words: Prefixes that, when found immediately after a quote,
            mark that quote as a delimiter. Any iterable of strings works; the
            default is a tuple so it cannot be mutated between calls.

    Returns:
        The repaired text. It is not guaranteed to be valid JSON.
    """
    fixed_json = list(json_string)
    for match in re.finditer(r'"', json_string):
        i = match.start()

        # An odd run of backslashes means the quote is already escaped;
        # escaping it again would turn it into an escaped backslash followed
        # by a string-terminating quote.
        j = i
        while j > 0 and json_string[j - 1] == '\\':
            j -= 1
        if (i - j) % 2 == 1:
            continue

        # Opening quote of a value (`"key": "value"`).
        if json_string.endswith(': ', 0, i):
            continue

        # Quote followed by a known key or punctuation is structural.
        if any(json_string.startswith(special, i + 1) for special in special_words):
            continue

        fixed_json[i] = r'\"'

    # Fix special words
    result = ''.join(fixed_json)
    result = result.replace('True', 'true').replace('False', 'false')

    # Fix delimiters: insert the comma missing between two quoted lines.
    result = re.sub(r'"\s*\n\s*"', '",\n"', result)

    return result

def escape_newlines_in_quotes(json_string):
    """Escape raw newline characters that appear inside double-quoted spans.

    A span starts at a quote not preceded by a backslash and runs to the next
    unescaped quote (backslash escapes inside the span are skipped over). Every
    newline inside such a span is replaced by the two-character escape
    backslash-n; text outside the spans is returned unchanged.

    Args:
        json_string: JSON-like text that may contain literal newlines in strings.

    Returns:
        The text with in-string newlines escaped.
    """
    quoted_span = re.compile(r'(?<!\\)"([^"\\]*(?:\\.[^"\\]*)*)"', re.DOTALL)
    return quoted_span.sub(lambda found: found.group(0).replace('\n', '\\n'), json_string)

def parse_llm_json(json_string: str, special_words=('question', 'answer', 'success', 'reference', ',', ':', '\n', '}', '{')):
    """Parse JSON emitted by an LLM, repairing it heuristically if needed.

    Markdown code fences (```json / ```) are stripped first. Valid JSON is
    returned directly. Otherwise the text is passed through ``try_fix_json``
    and ``escape_newlines_in_quotes`` and parsed again.

    Args:
        json_string: Raw model output containing a JSON document.
        special_words: Forwarded to ``try_fix_json``; prefixes that mark a quote
            as a structural delimiter. The default is a tuple so it cannot be
            mutated between calls.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the text is still invalid after repair.
    """
    json_string = json_string.replace('```json', '').replace('```', '')
    try:
        return json.loads(json_string)
    except json.JSONDecodeError:
        pass

    repaired = try_fix_json(json_string, special_words)
    repaired = escape_newlines_in_quotes(repaired)
    return json.loads(repaired)

def rca_accuracy_reward(completions, solution, metrics, groups, **kwargs):
    """RCA reward: mean reciprocal rank (MRR) of the ground-truth root cause.

    The answer JSON (inside ``<answer>...</answer>`` when present) must contain
    a ``rank_list``.

    * Metric level (``sol["level"]`` missing or ``"metric"``). Each predicted
      ``metric`` is mapped to a group. Names in ``sol["root_cause"]`` map to the
      ground-truth group; other names go through ``metrics``/``groups``.
      Groups are de-duplicated in order, and each unmatched metric occupies its
      own ``"Unknown"`` slot. The reward is ``1 / rank`` of the group that owns
      ``sol["rank_list"][0]["metric"]``.
    * Component level. Each predicted ``component`` is matched against
      ``groups``, with no match meaning ``"Unknown"``. Components are
      de-duplicated in order, and the reward is ``1 / rank`` of the first one
      contained in ``sol["root_cause"]``.

    Args:
        completions: Chat-style completions; ``completion[0]["content"]`` is the text.
        solution: JSON-encoded ground truth, one per completion.
        metrics: Per-sample list of metric names.
        groups: Per-sample list of group/component names (zipped with ``metrics``
            to build the metric -> group map at metric level).

    Returns:
        ``(rewards, feedback_prompts)``: one float in ``[0, 1]`` and one list of
        feedback strings per completion. Unparseable answers or solutions score
        0.0 with empty feedback.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    feedback_prompts = []

    for content, sol, metric, group in zip(contents, solution, metrics, groups):
        # Parse the model answer and the ground truth; either failing scores 0.
        try:
            if '<answer>' in content and '</answer>' in content:
                content = content.split('<answer>')[1].split('</answer>')[0]
            content = content.strip().replace('```json', '').replace('```', '')
            if '{' in content and '}' in content:
                # Slice from the FIRST '{' to the last '}'. rank_list entries
                # are nested objects, so starting at the last '{' would keep
                # only the final entry plus a dangling "]}" and never parse.
                content = content[content.find('{'):content.rfind('}') + 1]
            content = parse_llm_json(content, special_words=['metric', 'component', 'conclusion', ',', ':', '\n', '}', '{', 'upstream', 'type', 'description'])
            sol = json.loads(sol)
        except Exception:
            rewards.append(0.0)
            feedback_prompts.append([])
            continue

        cur_feedbacks = []

        try:
            if sol.get("level", "metric") == "metric":
                # Ground truth = group of the first metric in the solution's rank_list.
                metric_to_group = dict(zip(metric, group))
                gt_root_cause_group = metric_to_group[sol["rank_list"][0]["metric"]]

                group_rank = []
                for item in content["rank_list"]:
                    # A metric the solution names as a root cause counts as the
                    # ground-truth group directly.
                    if "root_cause" in sol and any(
                        match_metric_name(item["metric"], root_cause_metric)
                        for root_cause_metric in sol["root_cause"]
                    ):
                        if gt_root_cause_group not in group_rank:
                            group_rank.append(gt_root_cause_group)
                        continue
                    for m in metric:
                        if match_metric_name(item["metric"], m):
                            cur_group = metric_to_group[m]
                            if cur_group not in group_rank:
                                group_rank.append(cur_group)
                            break
                    else:
                        # Unrecognized metrics still take a rank position
                        # (not de-duplicated), pushing later guesses down.
                        group_rank.append("Unknown")

                mrr_reward = next(
                    (1.0 / rank for rank, cur_group in enumerate(group_rank, start=1)
                     if cur_group == gt_root_cause_group),
                    0.0,
                )
            else:
                gt_root_cause_group = sol["root_cause"]
                group_rank = []
                for item in content["rank_list"]:
                    ans_group = next(
                        (comp for comp in group if match_metric_name(item["component"], comp)),
                        "Unknown",
                    )
                    if ans_group not in group_rank:
                        group_rank.append(ans_group)

                mrr_reward = next(
                    (1.0 / rank for rank, cur_group in enumerate(group_rank, start=1)
                     if cur_group in sol["root_cause"]),
                    0.0,
                )

            if mrr_reward < 1.0:
                cur_feedbacks.append(f'Wrong root cause. Expected root cause metric in component {gt_root_cause_group}, got {group_rank[0] if len(group_rank) else None}.')

        except Exception as err:
            print(f"[rca_accuracy_reward] Failed to calculate MRR: {content} ({err})")
            mrr_reward = 0.0

        rewards.append(mrr_reward)
        feedback_prompts.append(cur_feedbacks)

    return rewards, feedback_prompts


def rca_thinking_format_reward(completions, **kwargs):
    """Score how many of the four ``## Step N`` headers appear in the ``<think>`` block.

    The reward is the fraction of ``## Step 1`` .. ``## Step 4`` found in the
    text between ``<think>`` and ``</think>``. Each missing header adds a
    feedback line, and a reward below 0.1 (no headers at all) adds a final
    reminder. Completions without a think block score 0.0 with no feedback.

    Returns:
        ``(rewards, feedback_prompts)`` with one entry each per completion.
    """
    contents = [completion[0]["content"] for completion in completions]
    step_headers = ('## Step 1', '## Step 2', '## Step 3', '## Step 4')
    rewards = []
    feedback_prompts = []

    for content in contents:
        try:
            thinking_text = re.search(r'<think>(.*?)</think>', content, re.DOTALL).group(1)
        except Exception:
            rewards.append(0.0)
            feedback_prompts.append([])
            continue

        absent = [header for header in step_headers if header not in thinking_text]
        notes = [f"Missing {header} in the thinking process." for header in absent]
        score = (len(step_headers) - len(absent)) / len(step_headers)

        if score < 0.1:
            notes.append("No thinking steps found! Please explicity show your thinking steps use format in: ## Step 1, ## Step 2, ## Step 3, ## Step 4.")

        rewards.append(score)
        feedback_prompts.append(notes)

    return rewards, feedback_prompts


def rca_propagation_reward(completions, solution, **kwargs):
    """RCA reward that checks the ``type`` and ``upstream`` fields of each metric.

    For every metric in the ground-truth ``rank_list``, the first predicted entry
    whose ``metric`` matches (via ``match_metric_name``) is scored as:

    * ``type``: 1.0 if equal to the ground truth (or the ground truth has none).
    * ``upstream``: 1.0 if the ground truth has none; otherwise 1.0 when the
      prediction matches any ancestor of the metric, or when the metric has no
      ancestors and the prediction is ``""``.

    A metric scores the mean of these two values, a missing metric scores 0,
    and the reward is the mean over all ground-truth metrics. Ancestors are the
    direct ``upstream`` plus that upstream's own ancestors, resolved only for
    upstreams listed earlier in the ground-truth ``rank_list``.

    Args:
        completions: Chat-style completions; ``completion[0]["content"]`` is the text.
        solution: JSON-encoded ground truth with a ``rank_list``, one per completion.

    Returns:
        ``(rewards, feedback_prompts)``: one float and one list of feedback strings
        per completion. Unparseable answers or solutions score 0.0 with empty
        feedback instead of aborting the whole batch.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    feedback_prompts = []

    for content, sol in zip(contents, solution):
        try:
            # Extract items between <answer> and </answer>
            if '<answer>' in content and '</answer>' in content:
                content = content.split('<answer>')[1].split('</answer>')[0]
            content = content.strip().replace('```json', '').replace('```', '')
            if '{' in content and '}' in content:
                # Slice from the FIRST '{' to the last '}'. rank_list entries
                # are nested objects, so starting at the last '{' would keep
                # only the final entry plus a dangling "]}" and never parse.
                content = content[content.find('{'):content.rfind('}') + 1]
            content = parse_llm_json(content, special_words=['metric', 'component', 'conclusion', ',', ':', '\n', '}', '{'])
            sol = json.loads(sol)
        except Exception:
            rewards.append(0.0)
            feedback_prompts.append([])
            continue

        # Reward feedback prompt
        cur_feedbacks = []

        try:
            total_metrics = len(sol["rank_list"])
            total_accuracy = 0.0

            # Ground-truth ancestors of each metric, built in rank_list order.
            upstream_dict = {}
            for item in sol["rank_list"]:
                metric_name = item["metric"]
                upstream_dict[metric_name] = set()
                if "upstream" in item and len(item["upstream"]) > 0:
                    upstream_dict[metric_name].add(item["upstream"])
                    if item["upstream"] in upstream_dict:
                        upstream_dict[metric_name].update(upstream_dict[item["upstream"]])

            for sol_metric in sol["rank_list"]:
                metric_name = sol_metric["metric"]

                content_metric = next(
                    (item for item in content["rank_list"]
                     if match_metric_name(metric_name, item["metric"])),
                    None,
                )
                if content_metric is None:
                    # Metric not found in content, accuracy is 0
                    continue

                if "type" not in sol_metric:
                    type_accuracy = 1.0
                else:
                    type_accuracy = 1.0 if sol_metric["type"] == content_metric["type"] else 0.0

                if "upstream" not in sol_metric:
                    upstream_accuracy = 1.0
                elif (len(upstream_dict[metric_name]) == 0 and content_metric["upstream"] == "") or any(
                    match_metric_name(upstream, content_metric["upstream"])
                    for upstream in upstream_dict[metric_name]
                ):
                    upstream_accuracy = 1.0
                else:
                    upstream_accuracy = 0.0

                if upstream_accuracy < 1.0:
                    # Explain why the predicted upstream is wrong when we can tell.
                    cur_info = ""
                    matched_upstream_answer = ""
                    matched_upstream_type = ""
                    for ans_metric in content["rank_list"]:
                        if match_metric_name(ans_metric["metric"], content_metric["upstream"]):
                            matched_upstream_answer = ans_metric["metric"]
                            matched_upstream_type = ans_metric['type']
                            break

                    if len(matched_upstream_answer) and matched_upstream_answer in upstream_dict and metric_name in upstream_dict[matched_upstream_answer]:
                        cur_info += f"However, {matched_upstream_answer} is actually a downstream metric of {metric_name}. "
                    elif len(matched_upstream_type) and matched_upstream_type in ["noise", "normal"]:
                        cur_info += f"However, {matched_upstream_answer} is actually a {matched_upstream_type} metric, not caused by the failure. "
                    elif sol_metric.get('type') in ["noise", "normal"]:
                        # "type" is optional in the solution (see above), so use .get.
                        cur_info += f"However, {metric_name} itself is actually a {sol_metric['type']} metric, not caused by the failure. So it will not have intervention upstream. "

                    if cur_info:
                        cur_feedbacks.append(f"Wrong upstream for metric {metric_name}. Expected {sol_metric['upstream'] if len(sol_metric['upstream']) else 'None'}, got {content_metric['upstream'] if len(content_metric['upstream']) else 'None'}. {cur_info}")

                total_accuracy += (type_accuracy + upstream_accuracy) / 2.0

            # Calculate average accuracy across all metrics
            propagation_reward = total_accuracy / total_metrics if total_metrics > 0 else 0.0

        except Exception as err:
            print(f"[rca_propagation_reward] Failed to calculate propagation reward: {err}")
            propagation_reward = 0.0

        rewards.append(propagation_reward)
        feedback_prompts.append(cur_feedbacks)

    return rewards, feedback_prompts

def format_reward(completions, **kwargs):
    """Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.

    The whole completion must consist of the think block followed by the answer
    block, with only whitespace allowed around or between them. Any other text
    before or after the blocks scores 0.0.

    Returns:
        One float per completion: 1.0 if the format matches, else 0.0.
    """
    pattern = r"^\s*<think>\n*.*?\n*</think>\s*<answer>\n*.*?\n*</answer>\s*$"
    completion_contents = [completion[0]["content"] for completion in completions]
    # DOTALL lets `.` span lines. MULTILINE is deliberately not used: it would
    # let `$` match at any line end and accept trailing text after </answer>.
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]

def rca_format_reward(completions, solution, **kwargs):
    """RCA format reward: the mean of three sub-scores in ``[0, 1]``.

    1. Basic: 1.0 once the ``<answer>...</answer>`` body parses as JSON.
    2. Keys: fraction of the ground truth's top-level keys present in the answer.
    3. Propagation: 1.0 if the ground truth has no ``rank_list``. Otherwise it
       is the fraction of ground-truth ``rank_list`` entries that have a
       name-matching entry in the answer's ``rank_list``. Names are compared on
       ``metric`` at metric level (``level`` missing or ``"metric"``) and on
       ``component`` otherwise.

    Completions without a parseable answer, or with a malformed ground truth,
    score 0.0 rather than aborting the batch.

    Args:
        completions: Chat-style completions; ``completion[0]["content"]`` is the text.
        solution: JSON-encoded ground truth, one per completion.

    Returns:
        A list of float rewards, one per completion.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        try:
            # Extract the text between <answer> and </answer> and decode both sides.
            content = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL).group(1)
            content = json.loads(content)
            sol = json.loads(sol)
        except Exception:
            rewards.append(0.0)
            continue

        # (Part 1) Basic reward
        basic_reward = 1.0

        # (Part 2) Key reward
        try:
            sol_keys = list(sol.keys())
            matched_keys = sum(1 for key in sol_keys if key in content)
            key_reward = matched_keys / len(sol_keys) if sol_keys else 0.0
        except Exception as err:
            print(f"[rca_format_reward] Failed to parse content: {content} ({err})")
            key_reward = 0.0

        # (Part 3) Propagation dict format result
        propagation_reward = 0.0
        try:
            if 'rank_list' not in sol:
                propagation_reward = 1.0
            elif 'rank_list' in content:
                total_metrics = len(sol['rank_list'])
                name_key = "metric" if sol.get("level", "metric") == "metric" else "component"
                for sol_metric in sol['rank_list']:
                    if any(match_metric_name(sol_metric[name_key], ans_metric[name_key])
                           for ans_metric in content['rank_list']):
                        propagation_reward += 1.0 / total_metrics
        except Exception as err:
            print(f"[rca_format_reward] Failed to parse content propagation: {content} ({err})")
            propagation_reward = 0.0

        rewards.append((basic_reward + key_reward + propagation_reward) / 3.0)

    return rewards

def rca_causal_review_reward(completions, **kwargs):
    """Reward explicit causal-review statements inside the ``<think>`` block.

    Review phrases and conclusion phrases are counted with a canonical pattern
    plus a looser alternative each. Each count is divided by 6 and capped at
    1.0, and the reward is ``0.6 * review_score + 0.4 * conclusion_score``.
    A reward below 0.1 adds a feedback message. Completions without a
    ``<think>...</think>`` block score 0.0 with no feedback.

    Returns:
        ``(rewards, feedback_prompts)`` with exactly one entry each per completion.
    """
    # Compiled once per call instead of once per completion.
    review_pattern = re.compile(r'causal\s*review:\s*reconsidering\s+(\w+)\s+causes\s+(\w+)', re.IGNORECASE)
    conclusion_pattern = re.compile(r'causal\s*review\s*conclusion:', re.IGNORECASE)
    # Alternative patterns for different phrasings
    alt_review_pattern = re.compile(r'(?:reconsidering|reviewing|reexamining)\s+(?:the\s+)?(?:causal|causality)(?:\s+between|relationship)?\s+(\w+)(?:\s+and|\s*->\s*|\s+causes\s+)(\w+)', re.IGNORECASE)
    alt_conclusion_pattern = re.compile(r'(?:causal|causality)(?:\s+review|\s+reconsideration)(?:\s+results|\s+outcome|\s+conclusion)', re.IGNORECASE)

    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    feedback_prompts = []
    for content in contents:
        try:
            # Extract the text between <think> and </think>
            thinking_text = re.search(r'<think>(.*?)</think>', content, re.DOTALL).group(1)
        except Exception:
            rewards.append(0.0)
            feedback_prompts.append([])
            continue

        cur_feedbacks = []

        try:
            review_count = len(review_pattern.findall(thinking_text)) + len(alt_review_pattern.findall(thinking_text))
            conclusion_count = len(conclusion_pattern.findall(thinking_text)) + len(alt_conclusion_pattern.findall(thinking_text))

            # Calculate normalized scores (cap at 6 occurrences for max score)
            review_score = min(review_count / 6, 1.0)
            conclusion_score = min(conclusion_count / 6, 1.0)

            # Combined weighted score
            thinking_reward = 0.6 * review_score + 0.4 * conclusion_score

            # If the thinking score is too low, prompt to add more causal review
            if thinking_reward < 0.1:
                cur_feedbacks.append("Causal review thinking is missing or insufficient. Please explicitly review the causes in your thiking steps following the format.")
        except Exception as err:
            print(f"[rca_causal_review_reward] Failed to parse content thinking: {content} ({err})")
            thinking_reward = 0.0

        rewards.append(thinking_reward)
        # Record feedback for every scored completion so both returned lists
        # stay index-aligned with `completions`.
        feedback_prompts.append(cur_feedbacks)

    return rewards, feedback_prompts

def rca_propagation_review_reward(completions, **kwargs):
    """Reward complete propagation-graph reviews inside the ``<think>`` block.

    Start markers and completion markers are counted with a canonical pattern
    plus a looser alternative each. The number of complete reviews is
    ``min(#starts, #completes)``, and the reward is that number divided by 6,
    capped at 1.0. A reward below 0.1 adds a feedback message. Completions
    without a ``<think>...</think>`` block score 0.0 with no feedback.

    Returns:
        ``(rewards, feedback_prompts)`` with exactly one entry each per completion.
    """
    # Compiled once per call instead of once per completion.
    graph_start_pattern = re.compile(r'propagation\s*graph\s*review:\s*start', re.IGNORECASE)
    graph_complete_pattern = re.compile(r'propagation\s*graph\s*review:\s*complete', re.IGNORECASE)
    # Alternative patterns for different phrasings
    alt_graph_start_pattern = re.compile(r'(?:reviewing|analyzing|examining)\s+(?:the\s+)?(?:propagation|causal)\s*(?:graph|chain|network)', re.IGNORECASE)
    alt_graph_complete_pattern = re.compile(r'(?:propagation|causal)(?:\s+graph|\s+chain)(?:\s+review|\s+analysis)(?:\s+complete|\s+finished|\s+done)', re.IGNORECASE)

    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    feedback_prompts = []
    for content in contents:
        try:
            # Extract the text between <think> and </think>
            thinking_text = re.search(r'<think>(.*?)</think>', content, re.DOTALL).group(1)
        except Exception:
            rewards.append(0.0)
            feedback_prompts.append([])
            continue

        cur_feedbacks = []

        try:
            starts = len(graph_start_pattern.findall(thinking_text)) + len(alt_graph_start_pattern.findall(thinking_text))
            completes = len(graph_complete_pattern.findall(thinking_text)) + len(alt_graph_complete_pattern.findall(thinking_text))

            # A review counts only when it has both a start and a completion marker.
            graph_review_count = min(starts, completes)

            # Calculate normalized score (cap at 6 complete reviews for max score)
            thinking_reward = min(graph_review_count / 6, 1.0)

            # If the thinking score is too low, prompt to add more propagation review
            if thinking_reward < 0.1:
                cur_feedbacks.append("Propagation review thinking is missing or insufficient. Please explicitly review the causes in your thiking steps following the format.")
        except Exception as err:
            print(f"[rca_propagation_review_reward] Failed to parse content thinking: {content} ({err})")
            thinking_reward = 0.0

        rewards.append(thinking_reward)
        # Record feedback for every scored completion so both returned lists
        # stay index-aligned with `completions`.
        feedback_prompts.append(cur_feedbacks)

    return rewards, feedback_prompts

def tag_count_reward(completions, **kwargs) -> list[float]:
    """Give 0.25 for each of the four ``format_reward()`` tags that occurs exactly once.

    The tags are checked together with their surrounding newlines:
    ``"<think>\\n"``, ``"\\n</think>\\n"``, ``"\\n<answer>\\n"`` and ``"\\n</answer>"``.

    Adapted from: https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb#file-grpo_demo-py-L90
    """
    expected_tags = ("<think>\n", "\n</think>\n", "\n<answer>\n", "\n</answer>")

    def count_tags(text: str) -> float:
        return sum((0.25 for tag in expected_tags if text.count(tag) == 1), 0.0)

    return [count_tags(completion[0]["content"]) for completion in completions]


def reasoning_steps_reward(completions, **kwargs):
    r"""Reward function that checks for clear step-by-step reasoning.

    Regex pattern (compiled with ``re.MULTILINE`` so ``^`` anchors at every line):
        Step \d+: - matches "Step 1:", "Step 2:", etc.
        ^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
        \n- - matches bullet points with hyphens
        \n\* - matches bullet points with asterisks
        First,|Second,|Next,|Finally, - matches transition words

    Returns:
        One float per completion: ``min(1.0, matches / 3)``.
    """
    # Without MULTILINE, `^\d+\.` would only match at the very start of the text,
    # not at the start of each numbered line as documented above.
    pattern = re.compile(r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)", re.MULTILINE)
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [len(pattern.findall(content)) for content in completion_contents]

    # Magic number 3 to encourage 3 steps and more, otherwise partial reward
    return [min(1.0, count / 3) for count in matches]


def rca_len_reward(completions, **kwargs):
    r"""Reward function for RCA thinking length.

    The reward is the length (in characters) of the text between ``<think>`` and
    ``</think>`` divided by 1200, capped at 1.0. Completions without a think
    block score 0.0.
    """
    contents = [completion[0]["content"] for completion in completions]

    def _length_score(text):
        # Pull out the reasoning between <think> and </think>.
        try:
            think_text = re.search(r'<think>(.*?)</think>', text, re.DOTALL).group(1)
        except Exception:
            return 0.0
        return min(1.0, len(think_text) / 1200)

    return [_length_score(text) for text in contents]


def len_reward(completions: list[list[Dict[str, str]]], solution: list[str], **kwargs) -> list[float]:
    """Compute length-based rewards to discourage overthinking and promote token efficiency.

    Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599

    Args:
        completions: List of model completions; ``completion[0]["content"]`` is the text.
        solution: List of ground truth solutions

    Returns:
        List of rewards where:
        - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
        - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
        An empty batch yields an empty list; a batch whose completions all have
        the same length yields all zeros.
    """
    contents = [completion[0]["content"] for completion in completions]
    if not contents:
        # min()/max() below would raise on an empty batch.
        return []

    # First check correctness of answers
    correctness = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) == 0:
            # Skip unparseable examples
            correctness.append(True)  # Treat as correct to avoid penalizing
            print("Failed to parse gold solution: ", sol)
            continue

        answer_parsed = parse(
            content,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed=True,
                        units=True,
                    ),
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
        # Guard verify like accuracy_reward does: a failure counts as incorrect
        # instead of aborting the whole batch.
        try:
            correctness.append(bool(verify(answer_parsed, gold_parsed)))
        except Exception as e:
            print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
            correctness.append(False)

    # Calculate lengths
    lengths = [len(content) for content in contents]
    min_len = min(lengths)
    max_len = max(lengths)

    # If all responses have the same length, return zero rewards
    if max_len == min_len:
        return [0.0] * len(completions)

    rewards = []
    for length, is_correct in zip(lengths, correctness):
        lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
        reward = lambda_val if is_correct else min(0, lambda_val)
        rewards.append(float(reward))

    return rewards


def get_cosine_scaled_reward(
    min_value_wrong: float = -1.0,
    max_value_wrong: float = -0.5,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    max_len: int = 1000,
):
    """Build a reward function that scales the correctness reward by completion length.

    Args:
        min_value_wrong: Reward for a wrong answer of length 0.
        max_value_wrong: Reward for a wrong answer of length >= ``max_len``.
        min_value_correct: Reward for a correct answer of length >= ``max_len``.
        max_value_correct: Reward for a correct answer of length 0.
        max_len: Character length at which the cosine schedule reaches its end point.
            Longer completions are treated as exactly ``max_len`` long.

    Returns:
        ``cosine_scaled_reward(completions, solution, **kwargs) -> list[float]``.

    Raises:
        ValueError: if ``max_len`` is not positive.
    """
    if max_len <= 0:
        raise ValueError(f"max_len {max_len} must be positive")

    def cosine_scaled_reward(completions, solution, **kwargs):
        """Reward function that scales based on completion length using a cosine schedule.

        Shorter correct solutions are rewarded more than longer ones.
        Longer incorrect solutions are penalized less than shorter ones.

        Args:
            completions: List of model completions
            solution: List of ground truth solutions

        Returns:
            One float per completion. A sample whose gold solution cannot be parsed gets 1.0.
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []

        for content, sol in zip(contents, solution):
            gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
            if len(gold_parsed) == 0:
                rewards.append(1.0)  # Skip unparseable examples
                print("Failed to parse gold solution: ", sol)
                continue

            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            # Same setting as accuracy_reward; the option takes "all"/"last"/"none".
                            boxed="all",
                            units=True,
                        ),
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )

            # verify can raise on exotic expressions; treat that as an incorrect answer
            # rather than aborting the whole batch (mirrors accuracy_reward).
            try:
                is_correct = verify(answer_parsed, gold_parsed)
            except Exception as e:
                print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
                is_correct = False
            gen_len = len(content)

            # Clamp so the cosine stays on [0, pi]; otherwise lengths beyond max_len
            # wrap around and long answers regain the short-answer reward.
            progress = min(gen_len / max_len, 1.0)
            cosine = math.cos(progress * math.pi)

            if is_correct:
                min_value = min_value_correct
                max_value = max_value_correct
            else:
                # Swap min/max for incorrect answers
                min_value = max_value_wrong
                max_value = min_value_wrong

            reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
            rewards.append(float(reward))

        return rewards

    return cosine_scaled_reward


def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
    """
    Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
    Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

    Args:
        ngram_size: size of the (lower-cased, whitespace-split) word n-grams; must be >= 1
        max_penalty: non-positive scale of the penalty; each completion receives
            ``max_penalty * (1 - unique_ngrams / total_ngrams)``

    Returns:
        ``repetition_penalty_reward(completions, **kwargs) -> list[float]``.

    Raises:
        ValueError: if ``max_penalty`` is positive or ``ngram_size`` is smaller than 1.
    """
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")
    if ngram_size < 1:
        # With no n-gram width every completion yields zero n-grams -> division by zero later.
        raise ValueError(f"ngram_size {ngram_size} should be at least 1")

    def zipngram(text: str, size: int):
        words = text.lower().split()
        return zip(*[words[i:] for i in range(size)])

    def repetition_penalty_reward(completions, **kwargs) -> list[float]:
        """
        Reward function that penalizes repetitions.
        ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

        Args:
            completions: List of model completions

        Returns:
            One penalty per completion; 0.0 for completions with fewer than ``ngram_size`` words.
        """

        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        for completion in contents:
            if completion == "":
                rewards.append(0.0)
                continue
            if len(completion.split()) < ngram_size:
                rewards.append(0.0)
                continue

            ngrams = set()
            total = 0
            for ng in zipngram(completion, ngram_size):
                ngrams.add(ng)
                total += 1

            scaling = 1 - len(ngrams) / total
            reward = scaling * max_penalty
            rewards.append(reward)
        return rewards

    return repetition_penalty_reward


def extract_code(completion: str) -> str:
    """Return the body of the last ```python fenced block in ``completion``.

    The body is everything between the ``python`` fence line and the next closing
    fence; an empty string is returned when no such block exists.
    """
    fenced_blocks = re.findall(r"```python\n(.*?)```", completion, re.DOTALL)
    if not fenced_blocks:
        return ""
    *_, last_block = fenced_blocks
    return last_block


def code_reward(completions, **kwargs) -> list[float]:
    """Reward function that evaluates code snippets using the E2B code interpreter.

    Assumes the dataset contains a `verification_info` column with one entry per
    completion; each entry is read for `test_cases` (a list of dicts with `input`
    and `output` strings) and `language` (forwarded to `Sandbox.run_code`).

    The last ```python fenced block of each completion is run with ``python3 -c`` once
    per test case inside the sandbox. The reward is the fraction of test cases whose
    stripped stdout equals the stripped expected output. A test case that exits
    non-zero or exceeds the 5-second per-case timeout counts as failed.

    Returns:
        One float per completion; if anything in the sandbox round-trip raises,
        every reward in the batch is 0.0.

    Raises:
        ImportError: if E2B is not available.
    """
    if not is_e2b_available():
        raise ImportError(
            "E2B is not available and required for this reward function. Please install E2B with "
            "`pip install e2b-code-interpreter` and add an API key to a `.env` file."
        )

    import textwrap

    rewards = []
    # TODO: add support for other languages in E2B: https://e2b.dev/docs/code-interpreting/supported-languages
    try:
        # Dedent so the generated script is valid top-level Python regardless of how
        # the kernel treats leading indentation.
        evaluation_script_template = textwrap.dedent(
            """
            import subprocess
            import json

            def evaluate_code(code, test_cases):
                passed = 0
                total = len(test_cases)
                exec_timeout = 5

                for case in test_cases:
                    try:
                        process = subprocess.run(
                            ["python3", "-c", code],
                            input=case["input"],
                            text=True,
                            capture_output=True,
                            timeout=exec_timeout
                        )
                    except subprocess.TimeoutExpired:  # A hanging solution fails this case only
                        continue

                    if process.returncode != 0:  # Error in execution
                        continue

                    output = process.stdout.strip()
                    if output.strip() == case["output"].strip():
                        passed += 1

                success_rate = (passed / total) if total else 0.0
                return success_rate

            code_snippet = {code}
            test_cases = json.loads({test_cases})

            evaluate_code(code_snippet, test_cases)
            """
        )
        code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
        verification_info = kwargs["verification_info"]
        # json.dumps turns the snippet into a quoted string literal; the test cases are
        # encoded twice so the script receives a JSON string that it json.loads itself.
        scripts = [
            evaluation_script_template.format(
                code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"]))
            )
            for code, info in zip(code_snippets, verification_info)
        ]
        with Sandbox(timeout=30, request_timeout=3) as sbx:
            # Each script runs with its own sample's language; indexing the whole
            # verification_info list with "language" raised TypeError for every batch.
            for script, info in zip(scripts, verification_info):
                execution = sbx.run_code(script, language=info["language"])
                try:
                    output = float(execution.text)
                except (TypeError, ValueError):
                    # No numeric result (e.g. the script raised): score this sample 0.
                    output = 0.0
                rewards.append(output)
    except Exception as e:
        print(f"Error from E2B executor: {e}")
        rewards = [0.0] * len(completions)
    return rewards


def get_code_format_reward(language: str = "python"):
    """Format reward function specifically for code responses.

    The returned function gives 1.0 to a completion that starts with
    ``<think>\\n...\\n</think>\\n<answer>\\n``, contains a fenced code block opened with
    ```` ```{language} ```` inside the answer, and has ``\\n</answer>`` followed by a
    line end (``re.MULTILINE`` lets ``$`` match at any line end). Every other
    completion gets 0.0.

    Args:
        language: Programming language supported by E2B https://e2b.dev/docs/code-interpreting/supported-languages.
            It is matched literally, so names containing regex metacharacters (e.g. ``c++``) are safe.

    Returns:
        ``code_format_reward(completions, **kwargs) -> list[float]``.
    """
    # re.escape: the language name is data, not regex syntax.
    pattern = rf"^<think>\n.*?\n</think>\n<answer>\n.*?```{re.escape(language)}.*?```.*?\n</answer>$"

    def code_format_reward(completions, **kwargs):
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
        return [1.0 if match else 0.0 for match in matches]

    return code_format_reward

# Accepted spellings for boolean answers, compared after lower-casing.
_TRUE_ANSWERS = frozenset({"true", "yes", "t", "是"})
_FALSE_ANSWERS = frozenset({"false", "no", "f", "否", "不是"})
# "start-end" integer ranges inside free text or JSON strings.
_INTERVAL_PATTERN = re.compile(r"(\d+)\s*-\s*(\d+)")


def _extract_answer_text(content):
    """Strip the <answer> wrapper, code fences and a trailing \\answer{...} wrapper from a completion."""
    if '<answer>' in content and '</answer>' in content:
        content = content.split('<answer>')[1].split('</answer>')[0]
    content = content.strip().replace('```json', '').replace('```', '')
    if '\\answer{' in content and '}' in content:
        content = content[content.rfind('\\answer{') + len('\\answer{'):content.rfind('}')]
    return content


def _default_rlvr_detail(cur_type, reward):
    """Detail record used when no ability-specific breakdown is available."""
    return {"ability_type": cur_type, "scores": {"other": [reward]}}


def _ordered_interval(a, b):
    """Convert both bounds to int and return them as an ascending ``(start, end)`` pair."""
    a, b = int(a), int(b)
    return (a, b) if a <= b else (b, a)


def _parse_intervals_from_json_or_text(s):
    """Parse anomaly intervals and an explicit "no anomaly" flag from a label or prediction.

    JSON objects (optionally JSON-encoded twice) are read from their ``anomaly_intervals``
    list, whose items may be ``[start, end]`` pairs or ``"start-end"`` strings, plus
    ``has_anomaly`` / ``answer`` for the no-anomaly flag. Anything else is scanned as free
    text for ``start-end`` ranges and "none"-style phrases.

    Returns:
        ``(intervals, has_none_flag)``; intervals are sorted, de-duplicated
        ``(start, end)`` tuples with ``start <= end``.
    """
    try:
        obj = json.loads(s)
        if isinstance(obj, str):
            # Maybe nested JSON string
            obj = json.loads(obj)
    except (TypeError, ValueError, RecursionError):
        obj = None

    intervals = []
    has_none_flag = False
    if isinstance(obj, dict):
        ai = obj.get("anomaly_intervals")
        if isinstance(ai, list):
            for it in ai:
                if isinstance(it, (list, tuple)) and len(it) == 2:
                    try:
                        intervals.append(_ordered_interval(it[0], it[1]))
                    except (TypeError, ValueError, OverflowError):
                        pass  # skip bounds that are not integers
                elif isinstance(it, str):
                    intervals.extend(_ordered_interval(a, b) for a, b in _INTERVAL_PATTERN.findall(it))
        has_anom = obj.get("has_anomaly")
        ans = obj.get("answer")
        if has_anom is False or (isinstance(ans, str) and ans.strip().lower() in ["none", "no", "no anomaly", "no anomalies"]):
            has_none_flag = True
    else:
        low = s.strip().lower()
        if any(tok in low for tok in ["none", "no anomaly", "no anomalies", "no anomal"]):
            has_none_flag = True
        intervals.extend(_ordered_interval(a, b) for a, b in _INTERVAL_PATTERN.findall(s))

    return sorted(set(intervals)), has_none_flag


def _merge_intervals(intervals):
    """Union of inclusive integer intervals as sorted, disjoint ``[start, end]`` lists."""
    merged = []
    for start, end in sorted(intervals):
        if end < start:
            continue  # empty interval covers no points
        if merged and start <= merged[-1][1] + 1:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return merged


def _covered_points(merged):
    """Number of integer points covered by disjoint inclusive intervals."""
    return sum(end - start + 1 for start, end in merged)


def _overlap_points(a, b):
    """Number of integer points shared by two sorted, disjoint inclusive interval lists."""
    i = j = 0
    shared = 0
    while i < len(a) and j < len(b):
        lo = max(a[i][0], b[j][0])
        hi = min(a[i][1], b[j][1])
        if lo <= hi:
            shared += hi - lo + 1
        if a[i][1] < b[j][1]:
            i += 1
        else:
            j += 1
    return shared


def _event_f1_with_tolerance(pred, label, tol=3):
    """Point-wise F1 between predicted and labelled inclusive integer intervals.

    A predicted interval whose start and end are each within ``tol`` of some label
    interval is first snapped onto the first such label interval. Coverage is then
    compared point by point. Counts come from merged interval lengths, so the cost
    depends on the number of intervals rather than their span (a prediction such as
    ``0-1000000000`` no longer allocates a billion-element list).

    Returns 1.0 when both lists are empty.
    """
    if not pred and not label:
        return 1.0

    aligned = [
        next(
            ((ls, le) for ls, le in label if abs(ps - ls) <= tol and abs(pe - le) <= tol),
            (ps, pe),
        )
        for ps, pe in pred
    ]

    label_union = _merge_intervals(label)
    pred_union = _merge_intervals(aligned)
    tp = _overlap_points(label_union, pred_union)
    fp = _covered_points(pred_union) - tp
    fn = _covered_points(label_union) - tp

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0


def rlvr_accuracy_reward(completions, solution, question_type, **kwargs):
    """Score each completion against its label according to its question type.

    Before scoring, the answer text is taken from inside ``<answer>...</answer>`` (when
    present), code fences are removed and a trailing ``\\answer{...}`` is unwrapped.

    Scoring per ``question_type``:
      * ``multiple_choice``: 1.0 on a case-insensitive exact match, else 0.0.
      * ``true_false``: 1.0 if the answer is an accepted spelling of the label's polarity
        (true/yes/t/是 or false/no/f/否/不是); an unrecognised label scores 0.0.
      * ``numerical``: if |label| < 0.5, 1.0 iff |answer| < 0.5; otherwise
        ``clip(1 - |answer - label| / |label|, 0, 1)``. Non-finite values score 0.0.
      * ``tsad``: point-wise F1 between predicted and labelled anomaly intervals, where a
        predicted interval within 3 of a label interval at both ends is snapped onto it.
        A label without intervals scores 1.0 iff the prediction reports none.
      * anything else: mean of the per-list means returned by ``reward_qa`` (0.0 if none).

    Args:
        completions: List of model completions (``completion[0]["content"]`` is read).
        solution: Ground-truth label per completion.
        question_type: Question type per completion.

    Returns:
        ``(rewards, None, rlvr_details)``; ``rewards`` and ``rlvr_details`` always have
        one entry per scored sample, so they stay index-aligned.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    rlvr_details = []

    for content, sol, cur_type in zip(contents, solution, question_type):
        try:
            answer = _extract_answer_text(content)
            sol = sol.strip()
        except (AttributeError, TypeError):
            # Non-string content or label: unscorable. Append a detail too so the two
            # returned lists stay aligned.
            rewards.append(0.0)
            rlvr_details.append(_default_rlvr_detail(cur_type, 0.0))
            continue

        cur_reward = 0.0
        cur_detail = None
        try:
            if cur_type == "multiple_choice":
                cur_reward = float(sol.lower().strip() == answer.lower().strip())
            elif cur_type == "true_false":
                expected = sol.lower()
                given = answer.strip().lower()
                if expected in _TRUE_ANSWERS:
                    cur_reward = float(given in _TRUE_ANSWERS)
                elif expected in _FALSE_ANSWERS:
                    cur_reward = float(given in _FALSE_ANSWERS)
                else:
                    raise ValueError(f"Unsupported true/false answer: {sol[-50:]}")
            elif cur_type == "numerical":
                target = float(sol)
                predicted = float(answer)
                if not (math.isfinite(target) and math.isfinite(predicted)):
                    # "nan" would otherwise slip through min()/max() clipping as 1.0.
                    cur_reward = 0.0
                elif abs(target) < 0.5:
                    # Relative error is meaningless near zero; require a near-zero answer.
                    cur_reward = 1.0 if abs(predicted) < 0.5 else 0.0
                else:
                    cur_reward = max(0.0, min(1.0, 1.0 - abs(predicted - target) / abs(target)))
            elif cur_type == "tsad":
                label_intervals, _ = _parse_intervals_from_json_or_text(sol)
                pred_intervals, pred_none = _parse_intervals_from_json_or_text(answer)
                if not label_intervals:
                    cur_reward = 1.0 if (pred_none or not pred_intervals) else 0.0
                else:
                    # An empty prediction (including an explicit "none") yields F1 = 0.0 here.
                    cur_reward = _event_f1_with_tolerance(pred_intervals, label_intervals, tol=3)
            else:
                reward_result = reward_qa(answer, sol)
                all_scores = []
                cur_detail = {}
                for ability_type, (cate_score, num_score, reason_score, _) in reward_result.items():
                    for score in (cate_score, num_score, reason_score):
                        if score is not None and len(score):
                            all_scores.append(np.mean(score))
                    # NOTE: overwritten per ability, so only the last ability is recorded.
                    cur_detail = {
                        "ability_type": ability_type,
                        "scores": {
                            "cate": cate_score,
                            "num": num_score,
                            "reason": reason_score
                        }
                    }
                cur_reward = float(np.mean(all_scores)) if len(all_scores) > 0 else 0.0
        except Exception:
            # Deliberate best-effort boundary: one bad sample must not abort the batch.
            logger.warning(f"[rlvr_accuracy_reward] Failed to calculate RLVR: {answer[-50:]}")
            traceback.print_exc()
            cur_reward = 0.0
            cur_detail = None

        if cur_detail is None:
            cur_detail = _default_rlvr_detail(cur_type, cur_reward)
        rewards.append(cur_reward)
        rlvr_details.append(cur_detail)

    return rewards, None, rlvr_details

def tool_count_reward(completions, solution, question_type, timeseries, **kwargs):
    """Limit the tool count according to the question type.

    Counts ``<tool_call>`` occurrences in each raw completion and compares them with a
    per-sample target (capped at 8):
      * multiple_choice / true_false / numerical: ``1 + num_ts``.
      * tsad: ``1 + len(anomaly_intervals)`` from the JSON label; at least 1 call required.
      * otherwise the JSON label's ``ability_type`` decides: noise/season/"correlation" -> 2,
        trend -> 0, local -> ``1 + len(attribute)`` (>= 1 call required), anything else
        (including "cluster" types) -> ``num_ts + 1`` (>= 1 call required).

    Reward is 1.0 up to the target, decays linearly to 0.0 over the next 3 calls, and is
    0.0 when fewer calls than the required minimum were made or the label is unusable.

    Returns:
        One float per completion.
    """
    MAX_COUNT = 8
    THRESHOLD = 3

    contents = [completion[0]["content"] for completion in completions]
    # NOTE(review): `timeseries` is not zipped per sample like `solution`/`question_type`;
    # if it is a per-batch column, len() is the batch size rather than this sample's
    # number of series — confirm against the calling trainer.
    num_ts = len(timeseries) if isinstance(timeseries, list) else 0

    rewards = []
    for content, sol, cur_type in zip(contents, solution, question_type):
        if not isinstance(content, str) or not isinstance(sol, str):
            rewards.append(0.0)
            continue
        sol = sol.strip()
        # Reset for every sample: the minimum must not leak from one sample to the next.
        min_count = 0

        try:
            num_tool_call = content.count('<tool_call>')

            if cur_type in ("multiple_choice", "true_false", "numerical"):
                target_cnt = min(MAX_COUNT, 1 + num_ts)
            elif cur_type == "tsad":
                # Label without the key counts as one interval.
                num_anomaly_intervals = len(json.loads(sol).get('anomaly_intervals', [0]))
                target_cnt = min(MAX_COUNT, 1 + num_anomaly_intervals)
                min_count = 1
            else:
                label = json.loads(sol)
                ability_type = label['ability_type']
                if ability_type in ('noise', 'season'):
                    target_cnt = min(MAX_COUNT, 2)
                elif ability_type == 'trend':
                    target_cnt = 0
                elif 'correlation' in ability_type:
                    target_cnt = min(MAX_COUNT, 2)
                elif ability_type == 'local':
                    target_cnt = min(MAX_COUNT, 1 + len(label['attribute']))
                    min_count = 1
                else:
                    # "cluster" types and any other ability share this budget.
                    target_cnt = min(MAX_COUNT, num_ts + 1)
                    min_count = 1

            if num_tool_call <= target_cnt:
                cur_reward = 1.0
            elif num_tool_call <= target_cnt + THRESHOLD:
                cur_reward = max(0.0, 1.0 - (num_tool_call - target_cnt) / THRESHOLD)
            else:
                cur_reward = 0.0

            if num_tool_call < min_count:
                cur_reward = 0.0
        except Exception:
            # Best-effort boundary: malformed labels score 0.0 instead of aborting the batch.
            logger.warning("[tool_count_reward] Failed to calculate tool count target")
            traceback.print_exc()
            cur_reward = 0.0

        rewards.append(cur_reward)

    return rewards
