import re
import os
import time
import json 
import torch
import argparse
from tqdm import tqdm
from pytz import timezone
from datetime import datetime
from datasets import load_dataset
from openai import OpenAI
from prompts import *
from bing_search import (
    jina_web_search,
    extract_snippet_with_context,
    fetch_page_content
)
from typing import List
from utils import Generator  # Import Generator
from typing import List, Dict, Any, Optional
import concurrent.futures


# Task instruction shared by every question; main() appends the answer-format
# line and the question itself before applying the chat template.
# The example separator is written as ";\\n" so the model sees the two literal
# characters `\n` instead of a real line break splitting the example.
PROMPT = """**Task Instruction:**

You are a reasoning assistant equipped with web search capabilities to accurately answer the user's questions.

Follow these steps:
1. **Clearly identify** the specific information you need to answer the user's question.
2. **Perform a web search** for the required information by writing your queries as follows:
```
<|begin_search_queries|>
Your search queries here (multiple queries can be placed together separated by ";\\n")
<|end_search_queries|>
```
3. Review the provided search results.
4. If additional information is still required, repeat step 2 with new queries.
5. Once all relevant information has been gathered, use your reasoning abilities to synthesize a clear, concise, and accurate answer.

**Remember:**
* Clearly separate each search query.
* Combine multiple queries into a single search action when they can be run simultaneously.
"""

# Chat-completions client used by get_response_from_llm for the webpage-analysis
# step. To target a local vLLM OpenAI-compatible server instead, build the client
# with base_url="http://0.0.0.0:8000/v1" and api_key="EMPTY".
print("Using OpenRouter")
# Credentials are read from the environment (like JINA_API_KEY below) so no secret
# lives in source control.
# NOTE(review): the OpenRouter key that used to be hard-coded here is still in VCS
# history and should be revoked.
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ["OPENROUTER_API_KEY"],
)
jina_api_key = os.environ["JINA_API_KEY"]

def _str2bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"no" to a bool.

    argparse's ``type=bool`` would turn every non-empty string (including
    "False") into True, so boolean-valued options use this converter instead.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Parse the command-line options for a search-augmented reasoning run.

    Returns:
        argparse.Namespace with dataset selection, generation settings, search
        settings and output location.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", "-d", required=True, choices=["musique", "browse_comp", "hle", "gpqa", "fanoutqa", "cwq", "hotpotqa", "med_browse_comp", "multihopqa", "frames"])
    parser.add_argument("--dataset_split", "-sp", default="train")
    parser.add_argument("--model_id_or_path", "-m", default="Qwen3-8B-sft")
    parser.add_argument("--apply_chat", type=_str2bool, default=True,
                        help="Whether to apply the chat template (true/false)")
    parser.add_argument("--use_jina", action="store_true", help="Whether to use jina for extracting text from urls")
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--max_doc_len", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=10, help="Retrieve the top k documents")
    parser.add_argument("--output_dir", type=str, default="./results")
    parser.add_argument("--start_idx", "-s", type=int)
    parser.add_argument("--end_idx", "-e", type=int)
    parser.add_argument("--api_model", "-am", default="qwen/qwen3-32b")
    # 0 disables the per-item search limit / turn limit.
    parser.add_argument("--max_search_limit", "-msl", type=int, default=0)
    parser.add_argument("--max_turn_limit", "-mtl", type=int, default=0)
    parser.add_argument("--max_iteration", type=int, default=10)
    return parser.parse_args()



def get_response_from_llm(
        messages: List[Dict[str, Any]],
        client: "OpenAI",
        model: str,
        stream: Optional[bool] = False,
        temperature: Optional[float] = 0.7,
        top_p: Optional[float] = 0.8,
        depth: int = 0,
        query_for_the_prompt: Optional[str] = None,
        max_tokens: int = 2048,
):
    """Send one chat-completions request, retrying failed calls.

    Args:
        messages: Chat messages passed verbatim to ``client.chat.completions.create``.
        client: OpenAI-compatible client.
        model: Model identifier sent with the request.
        stream: Unused; kept for signature compatibility (requests are not streamed).
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling parameter forwarded to the API.
        depth: Number of attempts already made; no further retry happens once
            it reaches 512.
        query_for_the_prompt: When truthy, the result is returned as
            ``(query_for_the_prompt, text)`` so concurrent callers can match
            responses to their queries.
        max_tokens: Completion length cap (default matches the former
            hard-coded 2048).

    Returns:
        The stripped response text, paired with ``query_for_the_prompt`` when
        that argument is truthy. The text is "" when the API returned no
        content, or when the error message mentions
        "Input data may contain inappropriate content" or "Error code: 400".

    Raises:
        Exception: the last API error once the retry budget is exhausted.
    """
    def _result(text):
        # Keep the return shape identical on every path so callers that unpack
        # ``(query, text)`` never receive a bare string.
        return (query_for_the_prompt, text) if query_for_the_prompt else text

    attempt = depth
    while True:
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                extra_body={"top_k": 20,
                            "chat_template_kwargs": {"enable_thinking": False}},
            )
            # Empty/None content used to leave `content` unbound and trigger
            # hundreds of pointless retries; treat it as an empty answer.
            content = getattr(response.choices[0].message, "content", None) or ""
            return _result(content.strip())
        except Exception as e:
            print(f"LLM API error: {e}")
            message = str(e)
            # Non-retryable rejections.
            if "Input data may contain inappropriate content" in message or "Error code: 400" in message:
                return _result("")
            if attempt >= 512:
                raise
            attempt += 1
            time.sleep(0.5)

def generate_webpage_to_reasonchain_batch(
        args,
        generator,  # unused; kept for call-site compatibility
        original_questions: List[str],
        prev_reasonings: List[str],
        search_queries: List[List[str]],  # list of queries per sequence
        documents: Dict[str, str],       # { search_query: formatted_doc_string }
        dataset_name: str,
        max_tokens: int = 2048,
        coherent: bool = False,
    ) -> Dict[str, str]:  # { search_query: extracted information }
    """Condense each query's fetched pages into information for the reasoning chain.

    For every query of every sequence, one prompt is built from
    ``prompt_for_webpage_to_reasonchain_instruction`` (previous reasoning,
    query, formatted documents) and sent concurrently to ``args.api_model`` via
    the module-level ``client``. Each reply is reduced with
    ``extract_answer(..., mode='infogen')``.

    Prompts are keyed by query string: if two sequences issue the same query,
    only the last sequence's previous reasoning is used and both receive the
    same analysis.

    ``generator``, ``original_questions``, ``dataset_name``, ``max_tokens`` and
    ``coherent`` are accepted for interface compatibility but not used here.

    Returns:
        Mapping from each search query to its extracted information.
    """
    def get_reasonchain_prompt(prev_reasoning, query, document):
        prompt = prompt_for_webpage_to_reasonchain_instruction.format(
            prev_reasoning=prev_reasoning,
            search_query=query,
            document=document
        )
        return [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}]

    all_user_prompts = {}
    for prev_reasoning, seq_queries in zip(prev_reasonings, search_queries):
        for query in seq_queries:
            all_user_prompts[query] = get_reasonchain_prompt(prev_reasoning, query, documents[query])

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Track which query each future belongs to, so the result does not have
        # to carry the query back (the callee's error paths may return a bare string).
        future_to_query = {
            executor.submit(get_response_from_llm, prompt, client, args.api_model, temperature=0.7, top_p=0.8): query
            for query, prompt in all_user_prompts.items()
        }
        outputs = {}
        for future in tqdm(concurrent.futures.as_completed(future_to_query), total=len(future_to_query)):
            outputs[future_to_query[future]] = future.result()

    return {query: extract_answer(output, mode='infogen') for query, output in outputs.items()}

def _last_braced_argument(text: str, command: str) -> Optional[str]:
    """Return the argument of the last complete ``command{...}`` in *text*.

    Braces are matched by depth, so nested groups such as ``\\frac{1}{2}`` stay
    intact, and a backslash-escaped character (e.g. ``\\{``) does not change
    the depth. Nested occurrences inside an already-matched argument are not
    reported separately. Returns None when no complete group is present.
    """
    opener = command + "{"
    last = None
    start = text.find(opener)
    while start != -1:
        body_start = start + len(opener)
        depth = 1
        pos = body_start
        while pos < len(text):
            ch = text[pos]
            if ch == "\\":
                pos += 2  # skip the escaped character
                continue
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    break
            pos += 1
        if depth == 0:
            last = text[body_start:pos]
            start = text.find(opener, pos + 1)
        else:
            # Unbalanced group: look for a later, complete one.
            start = text.find(opener, start + 1)
    return last


def extract_answer(output, mode='gen'):
    """Extract the useful part of a model response.

    Modes:
        'codegen': body of the last ```python fenced block.
        'infogen': text after **Final Information** (lines joined with single
            spaces) or after **Modified Reasoning Steps**; otherwise
            "No helpful information found.".
        other ('gen', 'choose', 'qa'): argument of the last ``\\boxed{...}``
            (brace-balanced, may span lines); for 'choose'/'qa' the innermost
            ``\\text{...}`` is preferred and surrounding parentheses removed.

    Returns:
        The extracted text, or '' when nothing matches (except 'infogen').
    """
    extracted_text = ''
    if mode == 'codegen':
        pattern = r'```python\s*(.*?)\s*```'
        matches = re.findall(pattern, output, re.DOTALL | re.IGNORECASE)
        if matches:
            extracted_text = matches[-1].strip()  # Take the last match
    elif mode == 'infogen':
        pattern_info = "**Final Information**"
        pattern_step = "**Modified Reasoning Steps**"
        if pattern_info in output:
            tail = output.split(pattern_info)[-1]
            # Join lines with a space (not "") so words on adjacent lines do not fuse.
            lines = [line.strip() for line in tail.splitlines()]
            extracted_text = " ".join(line for line in lines if line).strip("`").strip()
        elif pattern_step in output:
            extracted_text = output.split(pattern_step)[-1].strip("```").strip()
        else:
            extracted_text = "No helpful information found."
    else:
        boxed = _last_braced_argument(output, r"\boxed")
        if boxed is not None:
            extracted_text = boxed
            if mode in ['choose', 'qa']:
                inner = _last_braced_argument(extracted_text, r"\text")
                if inner is not None:
                    extracted_text = inner
                extracted_text = extracted_text.strip("()")
    return extracted_text


def extract_between(text: str, begin_search_query_token: str, end_search_query_token: str) -> Optional[List[str]]:
    """Extract the search queries enclosed by a begin/end token pair.

    Args:
        text: Model output to scan.
        begin_search_query_token: Literal opening token.
        end_search_query_token: Literal closing token.

    Returns:
        None if the opening token appears more than once (ambiguous output) or
        no complete begin...end span exists. Otherwise the non-empty queries of
        the first span: one per line (trailing ";" removed), or split on ";"
        when the span is a single line. May be an empty list.
    """
    # re.escape makes the tokens match literally whatever characters they contain.
    begin = re.escape(begin_search_query_token)
    end = re.escape(end_search_query_token)

    if len(re.findall(begin, text)) > 1:
        return None
    spans = re.findall(begin + r"([\s\S]*?)" + end, text)
    if not spans:
        return None

    queries = [line.strip().strip(";") for line in spans[0].strip().split("\n")]
    if len(queries) == 1 and ";" in queries[0]:
        queries = [part.strip() for part in queries[0].split(";")]
    return [query for query in queries if query]


def replace_recent_steps(origin_str, replace_str):
    """
    Merge replacement reasoning steps into an original sequence of steps.

    Both inputs are parsed into ``{step_number: text}``. A step starts on a line
    beginning with ``Step <n>:``; following lines belong to it until the next
    step header, and text before the first header is ignored. Each replacement
    step overwrites (or adds) the step with the same number, unless its text
    contains "DELETE THIS STEP", in which case that step is dropped.

    Parameters:
    - origin_str (str): The original reasoning steps.
    - replace_str (str): The steps to replace or delete.

    Returns:
    - str: The surviving step texts in ascending step order, separated by blank
      lines (the ``Step <n>:`` prefixes are not re-emitted).
    """
    header = re.compile(r"Step\s+(\d+):\s*")

    def collect(text):
        # Each entry is (step_number, list_of_lines); later duplicates win.
        blocks = []
        for line in text.splitlines():
            found = header.match(line)
            if found:
                head = line[found.end():].strip()
                blocks.append((int(found.group(1)), [head] if head else []))
            elif blocks:
                blocks[-1][1].append(line)
        return {number: "\n".join(lines).strip() for number, lines in blocks}

    merged = collect(origin_str)
    for number, body in collect(replace_str).items():
        if "DELETE THIS STEP" in body:
            merged.pop(number, None)
        else:
            merged[number] = body

    return "\n\n".join(body for _, body in sorted(merged.items()))


# Local dataset files keyed by the ``--dataset_name`` choice.
_DATASET_FILES = {
    "browse_comp": "reasoning_rag/datasets/BROWSE_COMP/test.json",
    "med_browse_comp": "reasoning_rag/datasets/MEDBROWSECOMP/test.json",
    "hle": "reasoning_rag/datasets/HLE/test.json",
    "gpqa": "reasoning_rag/datasets/GPQA/diamond.json",
    "musique": "reasoning_rag/datasets/MUSIQUE/dev_sub.json",
    "fanoutqa": "reasoning_rag/datasets/FANOUTQA/dev.json",
    "multihopqa": "reasoning_rag/datasets/MULTIHOPRAG/test.json",
    "cwq": "reasoning_rag/datasets/CWQ/dev.json",
    "hotpotqa": "reasoning_rag/datasets/HOTPOTQA/dev_sub.json",
    "frames": "reasoning_rag/datasets/FRAMES/test.json",
}


def _load_dataset(args):
    """Load the items for ``args.dataset_name``, sliced to [start_idx, end_idx) when both are given."""
    path = _DATASET_FILES.get(args.dataset_name)
    if path is not None:
        with open(path, "r") as f:
            dataset = json.load(f)
    else:
        # Fallback for names without a local file (currently unreachable because
        # --dataset_name is restricted to the keys above).
        dataset = [item for item in load_dataset(args.dataset_name)[args.dataset_split]]
    if args.start_idx is not None and args.end_idx is not None:
        dataset = dataset[args.start_idx:args.end_idx]
    return dataset


def _load_json_cache(path):
    """Return the JSON object stored at *path*, or an empty dict if the file does not exist."""
    if os.path.exists(path):
        with open(path, "r") as f:
            return json.load(f)
    return {}


def _truncate_prev_reasoning(output, begin_search_query, begin_search_result):
    """Number each line of *output* as ``Step i: ...`` and elide the middle of long traces.

    When the numbered text splits into more than five ``\\n\\n`` chunks, only the
    first chunk, the last four, and chunks containing a search query/result tag
    are kept; each run of dropped chunks becomes a single ``...``.
    """
    lines = output.replace('\n\n', '\n').split("\n")
    numbered = "".join(f"Step {i + 1}: {step}\n\n" for i, step in enumerate(lines))
    prev_steps = numbered.split('\n\n')
    if len(prev_steps) <= 5:
        truncated = '\n\n'.join(prev_steps)
    else:
        truncated = ''
        for i, step in enumerate(prev_steps):
            if i == 0 or i >= len(prev_steps) - 4 or begin_search_query in step or begin_search_result in step:
                truncated += step + '\n\n'
            elif truncated[-len('\n\n...\n\n'):] != '\n\n...\n\n':
                truncated += '...\n\n'
    return truncated.strip('\n')


def _format_documents(relevant_info, executed_queries, url_cache, max_doc_len):
    """Render each executed query's search hits as the document text for the analysis prompt.

    A string entry in *relevant_info* (the "already searched" notice) is passed
    through unchanged. For hit lists, each hit's ``snippet`` is cleaned of
    ``<b>`` tags and a ``context`` excerpt is taken from the cached page text.
    The hit dicts are mutated in place; they are the same objects held by the
    search cache and the item's ``relevant_info``.
    """
    formatted_documents = {}
    for sub_query in executed_queries:
        sub_relevant_info = relevant_info[sub_query]
        if isinstance(sub_relevant_info, str):
            formatted_documents[sub_query] = sub_relevant_info  # Already searched
            continue
        doc_str = ""
        for i, doc_info in enumerate(sub_relevant_info):
            url = doc_info.get('url', "")
            raw_context = url_cache.get(url, "")
            # Hits may lack a snippet (see the 'snippet' in info check when collecting URLs).
            raw_snippet = doc_info.get("snippet")
            snippet = raw_snippet.replace('<b>', '').replace('</b>', '') if raw_snippet else None
            doc_info['snippet'] = snippet
            success, filtered_context = extract_snippet_with_context(raw_context, snippet, context_chars=max_doc_len)
            doc_info['context'] = filtered_context if success else raw_context[:max_doc_len * 2]
            doc_str += f"**Web Page {i + 1}:**\n"
            doc_str += json.dumps(doc_info, ensure_ascii=False, indent=2) + "\n"
        formatted_documents[sub_query] = doc_str
    return formatted_documents


def _strip_trailing_tags(prompt, output, tags):
    """Remove any trailing run of *tags* (and surrounding whitespace) from prompt and output.

    Uses exact suffix removal: ``str.strip(tag)`` treats the tag as a set of
    characters and would also eat unrelated characters at both ends of the text.
    """
    prompt, output = prompt.strip(), output.strip()
    while True:
        tag = next((t for t in tags if t and prompt.endswith(t)), None)
        if tag is None:
            return prompt, output
        prompt = prompt[:-len(tag)].strip()
        if output.endswith(tag):
            output = output[:-len(tag)].strip()


def _close_forced_box(response):
    """Return the text before the ``}`` closing the ``\\boxed{`` the forced prompt ended with.

    Nested ``{...}`` groups in the answer are kept. If no closing brace is found,
    the response is returned with surrounding ``}`` characters removed.
    """
    depth = 0
    for idx, ch in enumerate(response):
        if ch == "{":
            depth += 1
        elif ch == "}":
            if depth == 0:
                return response[:idx].strip()
            depth -= 1
    return response.strip("}")


def main():
    """Run iterative search-augmented generation over a dataset and save the results.

    Each iteration continues every unfinished item with the local generator,
    stopping at the end-of-search-queries token. Items without a well-formed
    search block are finished with the ``\\boxed{}`` answer extracted from their
    response. Otherwise the queries are searched (with caching), result pages are
    fetched, an API model condenses the pages per query, and the condensed text
    is appended to the item's prompt. If items are still open at
    ``--max_iteration``, an answer is forced by appending a ``\\boxed{`` prefix.
    Results, per-step latency records and both caches are saved next to the
    output file.
    """
    args = get_args()

    BEGIN_SEARCH_QUERY = "<|begin_search_queries|>"
    END_SEARCH_QUERY = "<|end_search_queries|>"
    BEGIN_SEARCH_RESULT = "<|begin_search_results|>"
    END_SEARCH_RESULT = "<|end_search_results|>"
    FINAL_ANSWER_PREFIX = "\n\n**Final Answer:**\n\\boxed{"

    output_dir = f"{args.output_dir}/{args.dataset_name}"
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/{args.start_idx}_{args.end_idx}.json"
    print(f"******************************\nOutput file: {output_path}\n******************************")

    dataset = _load_dataset(args)

    # Use Generator for LLM and tokenizer
    generator = Generator(model_id_or_path=args.model_id_or_path)

    def get_prompt(item):
        prompt = PROMPT
        # Plain string (not a str.format template), so the braces are written single.
        prompt += "\nYou should provide your final answer in the format \\boxed{YOUR_ANSWER}."
        prompt += "\n\nPlease answer the question: " + item["question"]
        return generator.apply_chat_template(prompt, enable_thinking=True)

    # Normalise the question field and initialise per-item state.
    for item in dataset:
        if "question" not in item:
            if "Question" in item:
                item["question"] = item["Question"]
            elif "query" in item:
                item["question"] = item["query"]
        item["finished"] = False
        item["prompt"] = get_prompt(item)
        item["output"] = ""
        item["output_history"] = []
        item["search_count"] = []
        item["executed_search_queries"] = []
        item["related_info_analysis"] = []
        item["relevant_info"] = {}

    # Caches are read from the output root but saved per shard (see the end of main).
    search_cache = _load_json_cache(f"{args.output_dir}/search_cache.json")
    url_cache = _load_json_cache(f"{args.output_dir}/url_cache.json")

    iteration = -1
    step_latency_records = []

    while True:
        iteration += 1
        step_search_count = 0
        step_llm_tokens = 0

        print(f"\n---------------- Iteration {iteration} ----------------\n")
        print("Start iteration...")
        print("Among", len(dataset), "items,", end=" ")

        # Select only ongoing items
        items_remained = [item for item in dataset if not item["finished"]]
        if not items_remained:
            break
        print(len(items_remained), "are left...")

        # Per-iteration batch state for the webpage-analysis step.
        batch_relevant_info = []
        batch_original_questions = []
        batch_prev_reasonings = []
        batch_search_queries = []
        batch_documents = {}
        batch_sequences = []

        # URLs to fetch across all sequences, with their search snippets.
        all_urls_to_fetch = set()
        url_snippets = {}

        llm_start = time.time()
        prompts = [item["prompt"] for item in items_remained]
        responses = generator.generate(
            prompts,
            max_tokens=args.max_tokens,
            stop=[END_SEARCH_QUERY],
            apply_chat=False,
            enable_thinking=True,
        )
        step_llm_time = time.time() - llm_start

        search_start = time.time()

        for item, (response, num_tokens) in tqdm(zip(items_remained, responses), total=len(items_remained)):
            search_queries = extract_between(response, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
            if search_queries:
                step_search_count += len(search_queries)
                # Close the thinking section before the search block if the model did not.
                if "</think>" not in response:
                    response = response.replace(f"\n\n{BEGIN_SEARCH_QUERY}", f"\n</think>\n\n{BEGIN_SEARCH_QUERY}")
            item["prompt"] += response
            item["output"] += response
            item["output_history"].append(response)
            step_llm_tokens += num_tokens

            if search_queries is None:
                item["finished"] = True
                item["generated_answer"] = extract_answer(response)
                continue

            relevant_info = {}
            executed_this_round = []
            previously_executed = set(item['executed_search_queries'])

            for search_query in search_queries:
                # Count this round's queries too, so a multi-query round cannot overshoot the limit.
                searches_so_far = len(item['executed_search_queries']) + len(executed_this_round)
                if args.max_search_limit > 0 and searches_so_far >= args.max_search_limit:
                    limit_message = (
                        f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. "
                        "You are not allowed to search.\n"
                        f"{END_SEARCH_RESULT}\n"
                    )
                    item['prompt'] += limit_message
                    item['output'] += limit_message
                    print(f"Search limit reached for query: \"{search_query}\"")
                    break

                executed_this_round.append(search_query)

                if search_query in previously_executed:
                    relevant_info[search_query] = (
                        f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. "
                        "Please refer to previous results.\n"
                        f"{END_SEARCH_RESULT}\n"
                    )
                    continue

                # Execute search, use cache if available
                if search_query in search_cache:
                    relevant_info[search_query] = search_cache[search_query][:args.top_k]
                else:
                    try:
                        results = jina_web_search(search_query, jina_api_key, args.top_k)
                        search_cache[search_query] = results
                        relevant_info[search_query] = results[:args.top_k]
                    except Exception as e:
                        print(f"Error during search query '{search_query}': {e}")
                        search_cache[search_query] = []
                        relevant_info[search_query] = []

            if not relevant_info:
                item['finished'] = True
                item['generated_answer'] = extract_answer(response)
                continue

            item['relevant_info'].update(relevant_info)
            item['search_count'].append(len(executed_this_round))
            item['executed_search_queries'].extend(executed_this_round)

            # Collect result URLs (and snippets) not yet in the page cache.
            urls_to_fetch = []
            snippets = {}
            for results in relevant_info.values():
                if isinstance(results, str):
                    continue  # "already searched" notice, no hits
                for info in results:
                    if 'url' in info:
                        urls_to_fetch.append(info['url'])
                        if 'snippet' in info:
                            snippets[info['url']] = info['snippet']
            for url in urls_to_fetch:
                if url not in url_cache:
                    all_urls_to_fetch.add(url)
                    url_snippets[url] = snippets.get(url, "")

            batch_relevant_info.append(relevant_info)
            batch_original_questions.append(item['question'])
            batch_prev_reasonings.append(
                _truncate_prev_reasoning(item['output'], BEGIN_SEARCH_QUERY, BEGIN_SEARCH_RESULT)
            )
            batch_search_queries.append(executed_this_round)
            batch_sequences.append(item)

        # Batch fetch all URLs at once to optimize speed
        print(len(all_urls_to_fetch), "urls to fetch...")
        if all_urls_to_fetch:
            print(f"Fetching {len(all_urls_to_fetch)} URLs...")
            try:
                fetched_contents = fetch_page_content(
                    list(all_urls_to_fetch),
                    max_workers=16,
                    use_jina=args.use_jina,
                    jina_api_key=jina_api_key,
                    snippets=url_snippets,
                )
                print(f"Fetched {len(fetched_contents)} URLs successfully.")
            except Exception as e:
                print(f"Error during batch URL fetching: {e}")
                fetched_contents = {url: f"Error fetching URL: {e}" for url in all_urls_to_fetch}
            for url, content in fetched_contents.items():
                url_cache[url] = content

        for relevant_info, executed_queries in zip(batch_relevant_info, batch_search_queries):
            batch_documents.update(
                _format_documents(relevant_info, executed_queries, url_cache, args.max_doc_len)
            )

        if batch_sequences:
            print(f"Batch processing {len(batch_sequences)} sequences with generate_webpage_to_reasonchain_batch...")
            webpage_analyses = generate_webpage_to_reasonchain_batch(
                args,
                generator,
                original_questions=batch_original_questions,
                prev_reasonings=batch_prev_reasonings,
                search_queries=batch_search_queries,
                documents=batch_documents,
                dataset_name=args.dataset_name,
                max_tokens=args.max_tokens,
            )
            print("Batch generation completed, assigning outputs to sequences...")

            for seq, queries in zip(batch_sequences, batch_search_queries):
                combined_analysis = ""
                for query in queries:
                    analysis = webpage_analyses[query]
                    combined_analysis += f"{query}: {analysis}\n"
                    seq["related_info_analysis"].append(analysis)

                append_text = f"\n\n{BEGIN_SEARCH_RESULT}\n{combined_analysis.strip()}\n{END_SEARCH_RESULT}\n\n<think>\n"
                seq['prompt'] += append_text
                seq['output'] += append_text

        step_search_time = time.time() - search_start
        step_latency_records.append({
            "iteration": iteration,
            "step_llm_tokens": step_llm_tokens,
            "step_llm_time": step_llm_time,
            "step_search_time": step_search_time,
            "step_search_count": step_search_count,
            "items_remained": len(items_remained),
        })

        items_remained = [item for item in items_remained if not item['finished']]
        if items_remained and iteration < args.max_iteration:
            continue

        if items_remained and iteration == args.max_iteration:
            # Out of iterations: force an answer from every open item, and from
            # finished items that ended without a \boxed{} answer.
            pending = [
                item for item in dataset
                if not item["finished"] or item.get("generated_answer", "") == ""
            ]
            search_tags = (END_SEARCH_QUERY, BEGIN_SEARCH_QUERY, END_SEARCH_RESULT, BEGIN_SEARCH_RESULT)
            for item in pending:
                prompt, output = _strip_trailing_tags(item["prompt"], item["output"], search_tags)
                # Always open the answer box; close the thinking section first unless it already is.
                suffix = FINAL_ANSWER_PREFIX if prompt.endswith("</think>") else "\n</think>" + FINAL_ANSWER_PREFIX
                item["prompt"] = prompt + suffix
                item["output"] = output + suffix
                item["output_history"].append(suffix)
            responses = generator.generate(
                [item["prompt"] for item in pending],
                max_tokens=args.max_tokens,
                apply_chat=False,
                enable_thinking=False,
            )
            for item, (response, num_tokens) in tqdm(zip(pending, responses), total=len(pending)):
                item["prompt"] += response
                item["output"] += response
                item["output_history"][-1] += response
                item["finished"] = True
                # Keep only the boxed answer, not any text generated after it.
                item["generated_answer"] = _close_forced_box(response)
        break

    latency_path = output_path.replace(".json", "_step_latency_records.json")
    print(f"Saving step latency records... {latency_path}")
    with open(latency_path, "w") as f:
        json.dump(step_latency_records, f, indent=2)
    url_cache_path = output_path.replace(".json", "_url_cache.json")
    print(f"Saving url cache... {url_cache_path}")
    with open(url_cache_path, "w") as f:
        json.dump(url_cache, f, indent=2)
    search_cache_path = output_path.replace(".json", "_search_cache.json")
    print(f"Saving search cache... {search_cache_path}")
    with open(search_cache_path, "w") as f:
        json.dump(search_cache, f, indent=2)
    print(f"Saving results... {output_path}...")
    with open(output_path, "w") as f:
        json.dump(dataset, f, indent=2)
    print(f"******************************\nOutput file: {output_path}\n******************************")
    print("Done!")


if __name__ == "__main__":
    main()
    
    
