from collections import Counter
import json
import re
import json5

def extract_json_from_markdown(markdown_text):
    """Return the first JSON value that can be decoded from markdown-ish text.

    Candidate snippets are tried in priority order, and within each regex the
    matches are tried left to right:

      1. fenced code blocks, optionally labelled ``json``;
      2. fenced code blocks with no language label;
      3. ``{...}`` objects wrapped in single backticks;
      4. a bare ``{ "key": ... }`` span (greedy: first ``{`` to last ``}``).

    If none of those decode, the entire text is decoded as a last resort.
    Decoding uses ``json5.loads``.

    Args:
        markdown_text: the text to search.

    Returns:
        The decoded value of the first candidate that decodes, or ``None`` if
        nothing does.
    """
    candidate_regexes = (
        # Fenced block, optional "json" label. The ``json\\n`` alternative
        # matches a literal backslash-n after the label (escaped text); a real
        # newline is still required before the captured body.
        r"```(?:json\\n|json\s*)?\n(.*?)\n```",
        # Fenced block without a language label.
        r"```\n(.*?)\n```",
        # Object in single backticks.
        r"`({.*?})`",
        # Bare object containing at least one "key": pair.
        r'({[\s\S]*"[^"]*"\s*:[\s\S]*})',
    )

    def iter_snippets():
        # Lazily yield candidates so we stop scanning at the first success.
        for regex in candidate_regexes:
            for found in re.finditer(regex, markdown_text, re.DOTALL):
                yield found.group(1)
        # Final fallback: the whole input.
        yield markdown_text

    for snippet in iter_snippets():
        try:
            return json5.loads(snippet)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError
            # subclasses; this snippet is not decodable, try the next one.
            continue
    return None

def validate_format(text: str, tool_call_tag: str, tool_response_tag: str) -> tuple[bool, str]:
    """Check that a rollout follows the think / tool-call / answer tag layout.

    Rules are checked in this order, and the first violation is reported:

      * ``<think>``/``</think>`` occur equally often and at least once;
      * ``<answer>`` and ``</answer>`` each occur exactly once;
      * tool-call open/close tags occur equally often and at least once;
      * every tool call closes, is not nested inside another call, and is
        followed by a complete tool response (``call < /call < resp < /resp``).
        Scanning resumes after that response, so several calls followed by
        several responses (parallel calls) are accepted;
      * ``<answer>`` comes before ``</answer>``.

    Args:
        text: the full generated text to validate.
        tool_call_tag: tag name (without angle brackets) wrapping tool calls.
        tool_response_tag: tag name (without angle brackets) wrapping tool
            responses.

    Returns:
        ``(True, "format is correct")`` if every rule holds, otherwise
        ``(False, reason)`` describing the first rule that failed.
    """
    call_open, call_close = f'<{tool_call_tag}>', f'</{tool_call_tag}>'
    resp_open, resp_close = f'<{tool_response_tag}>', f'</{tool_response_tag}>'

    # <think> ... </think> must be balanced and present.
    think_opens = text.count('<think>')
    if think_opens != text.count('</think>'):
        return False, "<think> </think> not paired"
    if think_opens == 0:
        return False, "<think> or </think> not found"

    # Exactly one answer span; distinguish "missing" from "duplicated".
    answer_opens, answer_closes = text.count('<answer>'), text.count('</answer>')
    if answer_opens == 0 or answer_closes == 0:
        return False, "<answer> or </answer> not found"
    if answer_opens != 1 or answer_closes != 1:
        return False, "<answer> or </answer> must appear exactly once"

    # Tool-call tags must be balanced and present.
    call_opens = text.count(call_open)
    if call_opens != text.count(call_close):
        return False, f"<{tool_call_tag}> or </{tool_call_tag}> not paired"
    if call_opens == 0:
        return False, f"<{tool_call_tag}> or </{tool_call_tag}> not found"

    # Walk the calls in order, pairing each with the next tool response.
    current_pos = 0
    while True:
        search_pos = text.find(call_open, current_pos)
        if search_pos == -1:
            break

        search_end_pos = text.find(call_close, search_pos)
        result_pos = text.find(resp_open, search_pos)
        # Only look for the response close once the response has opened.
        result_end_pos = text.find(resp_close, result_pos) if result_pos != -1 else -1

        if -1 in (result_pos, search_end_pos, result_end_pos):
            return False, "search/result tags are incomplete"

        if not (search_pos < search_end_pos < result_pos < result_end_pos):
            return False, "search/result tags are nested in the wrong order"

        # Another call opening inside this call's body means nested calls.
        if text.find(call_open, search_pos + len(call_open), search_end_pos) != -1:
            return False, "search/result tags are nested in the wrong order"

        current_pos = result_end_pos

    # Tag order of the single answer span (no content check is done here).
    if text.find('<answer>') > text.find('</answer>'):
        return False, "<answer> must be before </answer>"

    return True, "format is correct"


def search_query_repeat_nums(text: str, tool_call_tag: str) -> int:
    """Count how many tool-call queries in *text* repeat an earlier query.

    Each ``<tool_call_tag>...</tool_call_tag>`` span is extracted in order and
    its content stripped. When ``tool_call_tag == "tool_call"``, the content is
    parsed as JSON and ``["arguments"]["query"]`` is used as the query. If that
    lookup fails, the raw stripped content is used instead. Structured queries
    (JSON arrays/objects) are compared by a canonical JSON string, because the
    raw values are unhashable. An opening tag with no closing tag (for example,
    a truncated generation) is ignored.

    Args:
        text: the text to scan.
        tool_call_tag: tag name (without angle brackets) wrapping each call.

    Returns:
        The number of duplicate queries: ``sum(count - 1)`` over every query
        that occurs more than once.
    """
    open_tag = f'<{tool_call_tag}>'
    close_tag = f'</{tool_call_tag}>'

    search_queries = []
    current_pos = 0
    while True:
        start = text.find(open_tag, current_pos)
        if start == -1:
            break
        end = text.find(close_tag, start)
        if end == -1:
            # Unclosed call: there is no complete query to compare.
            break

        query = text[start + len(open_tag):end].strip()
        if tool_call_tag == 'tool_call':
            try:
                query = json.loads(query)['arguments']['query']
            except (ValueError, KeyError, TypeError):
                # Not JSON (ValueError), or no arguments.query (KeyError /
                # TypeError on non-dict JSON): keep the raw text as the query.
                pass
        if isinstance(query, (list, dict)):
            # Lists/dicts cannot be Counter keys; canonicalize so identical
            # structured queries still compare equal.
            query = json.dumps(query, sort_keys=True, ensure_ascii=False)

        search_queries.append(query)
        current_pos = end

    counts = Counter(search_queries)
    return sum(count - 1 for count in counts.values() if count > 1)


def get_tool_call_times(text: str, tool_call_tag: str):
    """Return the number of tool calls in *text*.

    The count is the smaller of the number of opening and closing
    ``tool_call_tag`` tags, so surplus tags of one kind are not counted.
    """
    opening = text.count(f"<{tool_call_tag}>")
    closing = text.count(f"</{tool_call_tag}>")
    return opening if opening < closing else closing
