import re
import math
import string
import random

def map_m_to_range(m, n):
    """Rescale a raw tool-call count ``m`` relative to the optimal count ``n``.

    For ``n != 0`` the result is ``2 * n * (m / (m + n))``.

    * It is 0 when ``m == 0``.
    * It is exactly ``n`` when ``m == n``.
    * For non-negative ``m`` and positive ``n`` it stays below ``2 * n``.

    For ``n == 0`` the count is passed through unchanged (``0`` when both are 0).

    Inputs are not validated: ``m == -n`` with ``n != 0`` raises
    ``ZeroDivisionError``.
    """
    if n == 0:
        # Both zero -> hand back n (i.e. 0); otherwise keep the raw count.
        return n if m == 0 else m

    share = m / (m + n)  # fraction of the combined count contributed by m
    return 2 * n * share

def reward_function(norm_m, n, type="sine"):
    """Score a normalized tool-call count against the optimal count ``n``.

    ``norm_m`` is the rescaled count; in this file ``compute_tool_reward``
    obtains it from ``map_m_to_range(m, n)``.

    Cases:

    * ``norm_m == n == 0``: returns ``1``.
    * ``n == 0``, ``norm_m != 0``: returns ``cos(pi * norm_m / (2 * norm_m + 4))``.
      For ``norm_m > 0`` this lies strictly between 0 and 1 and shrinks toward 0
      as ``norm_m`` grows. ``type`` is not consulted here.
    * ``type == "sine"`` (default): ``sin(pi * norm_m / (2 * n))``. This is 1
      when ``norm_m == n``.
    * ``type == "cosine"``: ``cos(pi * norm_m / (2 * n))``. This is 1 when
      ``norm_m == 0`` and 0 when ``norm_m == n``.
    * ``type == "linear"``: ``norm_m`` if ``norm_m <= n``, else ``-0.5 * norm_m``.

    NOTE(review): the "linear" branch returns ``norm_m`` unscaled (up to ``n``),
    so it is not on the same scale as the sine/cosine variants. Confirm whether
    ``norm_m / n`` was intended.

    Raises:
        ValueError: if ``type`` is not one of the names above and ``n != 0``.
    """
    # `type` shadows the builtin but is part of the public keyword interface.
    if n == 0:
        if norm_m == 0:
            return 1
        # Tools were not needed but some were called: same decaying curve as
        # reward_function_wo_optimal.
        return math.cos(math.pi * (norm_m / (2 * norm_m + 4)))

    if type == "linear":
        return norm_m if norm_m <= n else -(0.5) * norm_m

    if type in ("sine", "cosine"):
        wave = math.sin if type == "sine" else math.cos
        return wave(math.pi * (norm_m / (2 * n)))

    raise ValueError("Invalid reward function for tool use.")

def reward_function_wo_optimal(m, type="basic"):
    """Score a raw tool-call count when no optimal count is known.

    Returns ``1`` for zero calls. Otherwise (``type == "basic"``) it returns
    ``cos(pi * m / (2 * m + 4))``. For ``m >= 0`` the fraction ``m / (2m + 4)``
    lies in ``[0, 1/2)``, so the cosine argument stays in ``[0, pi/2)``. The
    reward therefore decreases from 1 toward (never reaching) 0 as ``m`` grows,
    which favours fewer tool calls.

    Raises:
        ValueError: if ``type`` is not ``"basic"`` and ``m != 0``.
    """
    if m == 0:
        return 1

    if type != "basic":
        raise ValueError("Invalid reward function for tool use.")

    angle = math.pi * (m / (2 * m + 4))
    return math.cos(angle)

def extract_search(solution_str, tool_num=1):
    """Count the ``<search>...</search>`` tool calls in a rollout string.

    Before counting, the text is preprocessed:

    * The instruction sentence ``If I want to search, I should put the query
      between <search> and </search>.`` is removed, so its literal tags are
      not counted as a call.
    * If more than two ``<answer>...</answer>`` spans are present, everything
      after the last ``</answer>`` is dropped. Searches emitted after the final
      answer are then ignored.

    Args:
        solution_str: Full decoded text to inspect.
        tool_num: Unused; kept only for backward compatibility with callers.

    Returns:
        ``max(k - 1, 0)`` as an ``int``, where ``k`` is the number of
        ``<search>...</search>`` matches left after preprocessing. So 0 or 1
        match yields 0, 2 matches yield 1, and so on.
    """
    # Drop the action explanation so its example tags are not counted.
    solution_str = solution_str.replace(
        'If I want to search, I should put the query between <search> and </search>.', ''
    )

    answers = list(re.finditer(r'<answer>(.*?)</answer>', solution_str, re.DOTALL))
    # NOTE(review): the "> 2" threshold here and the "- 1" below presumably
    # discount <answer>/<search> examples that appear in the prompt text (this
    # function does not strip the prefix before the assistant turn); confirm
    # against the prompt template before changing either.
    if len(answers) > 2:
        # Handles e.g. "<answer> a </answer> <search> q </search>": a trailing
        # search after the last answer must not count.
        solution_str = solution_str[:answers[-1].end()]

    searches = re.findall(r'<search>(.*?)</search>', solution_str, re.DOTALL)
    return max(len(searches) - 1, 0)

def compute_tool_reward(current_tool_calls, optimal_calls=None, alpha=0.5):
    """Compute the alpha-scaled tool-use reward for one rollout.

    Args:
        current_tool_calls: Number of search calls made. Presumably the count
            produced by ``extract_search``; confirm at the call site.
        optimal_calls: Reference number of calls, or ``None`` when unknown.
        alpha: Multiplier applied to the raw reward.

    Returns:
        A tuple ``(alpha * raw_reward, {"search": current_tool_calls})``.
        ``raw_reward`` depends on ``optimal_calls``:

        * If it is ``None``, ``reward_function_wo_optimal(current_tool_calls)``
          is used. That curve peaks at zero calls.
        * Otherwise the count is rescaled with ``map_m_to_range`` and scored
          with ``reward_function``.
    """
    # Only the search tool is tracked.
    tool_stat = {"search": current_tool_calls}

    if optimal_calls is not None:
        normalized_tool_call = map_m_to_range(current_tool_calls, optimal_calls)
        raw_reward = reward_function(normalized_tool_call, optimal_calls)
    else:
        raw_reward = reward_function_wo_optimal(current_tool_calls)

    return alpha * raw_reward, tool_stat