import re
import os
import json 
import torch
import argparse
from tqdm import tqdm
from datasets import load_dataset
from openai import OpenAI
from prompts import *
from bing_search import (
    jina_web_search,
    extract_snippet_with_context,
    fetch_page_content
)
from typing import List
from utils import Generator  # Import Generator

# OpenAI-compatible client pointed at a local vLLM API server; "EMPTY" is a
# placeholder key. NOTE(review): `client` is not referenced anywhere in this
# file (generation goes through `Generator`) — confirm it is still needed.
openai_api_key = "EMPTY"
openai_api_base = "http://0.0.0.0:8000/v1"
client = OpenAI(base_url=openai_api_base, api_key=openai_api_key)

# Read at import time: raises KeyError when JINA_API_KEY is not set.
jina_api_key = os.environ["JINA_API_KEY"]

# Special tokens delimiting the model's search-query and search-result blocks.
BEGIN_SEARCH_QUERY, END_SEARCH_QUERY = "<|begin_search_queries|>", "<|end_search_queries|>"
BEGIN_SEARCH_RESULT, END_SEARCH_RESULT = "<|begin_search_results|>", "<|end_search_results|>"

# Per-item search budget, and the iteration index at which the main loop stops.
MAX_SEARCH_LIMIT = 10
MAX_ITERATION = 5

def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "yes"/"no" or "1"/"0".

    argparse's ``type=bool`` treats every non-empty string (including "False")
    as True, so an explicit parser is needed for boolean flags that take a value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Parse the command-line options for an inference run.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly like a plain ``parse_args()`` call.

    Returns:
        argparse.Namespace with the dataset, model, generation, retrieval and
        output options. ``start_idx`` / ``end_idx`` are None when omitted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", "-d", default="NQParaQ")
    parser.add_argument("--dataset_split", default="train")
    parser.add_argument("--model_name", "-m", default="Qwen/Qwen3-32B")
    # `type=bool` would map "--use_jina False" to True; parse the text instead.
    parser.add_argument("--use_jina", type=_str2bool, default=True)
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--max_doc_len", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top_k", type=int, default=3, help="Retrieve the top k documents")
    parser.add_argument("--output_dir", "-o", type=str, default="../results/inference")
    parser.add_argument("--start_idx", "-s", type=int)
    parser.add_argument("--end_idx", "-e", type=int)
    return parser.parse_args(argv)


# ---------------------- Batch Generation Function ----------------------
    
def generate_webpage_to_reasonchain_batch(
        args,
        generator,  # Use Generator instance
        original_questions: List[str],
        prev_reasonings: List[str],
        search_queries: List[List[str]],  # list of queries per sequence
        documents: List[List[str]],       # list of formatted document strings per query per sequence
        dataset_name: str,
        max_tokens: int = 512,
        coherent: bool = False,
    ) -> List[List[str]]:  # returns list of list of outputs per query per sequence
    """Analyze retrieved documents for every (sequence, query) pair in one batched LLM call.

    One prompt is built per query from ``prompt_for_webpage_to_reasonchain_instruction``
    (filled with the sequence's previous reasoning, the query and its document string).
    All prompts are generated together, post-processed with
    ``extract_answer(..., mode='infogen')``, and regrouped per sequence.

    Args:
        args, original_questions, dataset_name, coherent: accepted for interface
            compatibility; not used by this function.
        generator: object exposing ``generate(prompts, max_tokens=..., apply_chat=...,
            enable_thinking=..., repetition_penalty=...)``.
        prev_reasonings: one (truncated) reasoning trace per sequence.
        search_queries: per sequence, the queries that were executed.
        documents: per sequence, one formatted document string per query.
        max_tokens: generation budget per prompt.

    Returns:
        Per sequence, the extracted analyses in query order. A sequence gets one
        entry per (query, document) pair, i.e. ``min(len(queries), len(docs))``.
    """
    all_user_prompts = []
    seq_prompt_counts = []  # prompts actually built for each sequence
    for prev_reasoning, seq_queries, seq_docs in zip(prev_reasonings, search_queries, documents):
        # zip() stops at the shorter list, so count the prompts really built rather
        # than len(seq_queries); otherwise a query/document length mismatch would
        # shift every later sequence's outputs onto the wrong sequence.
        count = 0
        for query, doc in zip(seq_queries, seq_docs):
            all_user_prompts.append(
                prompt_for_webpage_to_reasonchain_instruction.format(
                    prev_reasoning=prev_reasoning,
                    search_query=query,
                    document=doc,
                )
            )
            count += 1
        seq_prompt_counts.append(count)

    # Nothing to analyze (e.g. every query hit the limit or was a repeat): skip the LLM call.
    if not all_user_prompts:
        return [[] for _ in seq_prompt_counts]

    raw_outputs = generator.generate(
        all_user_prompts,
        max_tokens=max_tokens,
        apply_chat=True,
        enable_thinking=False,
        repetition_penalty=1.05,
    )
    extracted_infos = [extract_answer(raw, mode='infogen') for raw in raw_outputs]

    # Regroup the flat outputs back into per-sequence lists.
    grouped_outputs = []
    start = 0
    for count in seq_prompt_counts:
        grouped_outputs.append(extracted_infos[start:start + count])
        start += count
    return grouped_outputs


def _braced_arguments(text, command):
    """Return the argument of every ``command{...}`` in *text*, matching nested braces.

    A backslash escapes the character after it, so ``\\{`` / ``\\}`` do not change
    the nesting depth. Unterminated occurrences are skipped.
    """
    opener = command + "{"
    found = []
    start = text.find(opener)
    while start != -1:
        body_start = start + len(opener)
        depth = 1
        pos = body_start
        while pos < len(text) and depth:
            ch = text[pos]
            if ch == "\\":
                pos += 2  # skip the escaped character
                continue
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            found.append(text[body_start:pos - 1])  # exclude the closing brace
            start = text.find(opener, pos)
        else:
            # Unterminated: keep looking for a later, complete occurrence.
            start = text.find(opener, body_start)
    return found


def extract_answer(output, mode='gen'):
    """Extract the final answer / information from a model output.

    Modes:
        'codegen': contents of the last fenced ```python block (stripped).
        'infogen': the whole output with every "**Final Information**" marker
            removed and surrounding whitespace stripped; returns
            "No helpful information found." when nothing remains.
        'gen' (default) and any other mode: argument of the last ``\\boxed{...}``
            (braces matched, may span lines). For 'choose' and 'qa', the last
            ``\\text{...}`` inside it is used when present, and surrounding
            parentheses are stripped.

    Returns:
        The extracted string, or '' when nothing matched.
    """
    extracted_text = ''
    if mode == 'codegen':
        # Extract the code between ```python and ```
        pattern = r'```python\s*(.*?)\s*```'
        matches = re.findall(pattern, output, re.DOTALL | re.IGNORECASE)
        if matches:
            extracted_text = matches[-1].strip()  # Take the last match
    elif mode == 'infogen':
        # Drop the "**Final Information**" marker; text before it is kept as well.
        extracted_text = output.replace("**Final Information**", "").strip()
        if extracted_text == "":
            extracted_text = "No helpful information found."
    else:
        # A greedy regex would merge two boxes on one line ("1} ... \boxed{2");
        # match braces instead so the last box really is the last one.
        boxed = _braced_arguments(output, "\\boxed")
        if boxed:
            extracted_text = boxed[-1]
            if mode in ['choose', 'qa']:
                texts = _braced_arguments(extracted_text, "\\text")
                if texts:
                    extracted_text = texts[-1]
                extracted_text = extracted_text.strip("()")
    return extracted_text


# Function to extract text between two tags
def extract_between(text: str) -> List[str]:
    
    # Set pattern 
    pattern_begin = "<\|begin_search_queries\|>"
    pattern = "<\|begin_search_queries\|>([\s\S]*?)<\|end_search_queries\|>"
    
    if len(re.findall(pattern_begin, text)) > 1:
        return
    output = re.findall(pattern, text)
    if len(output) == 0:
        return
    
    # Split output
    output = output[0].strip().strip("\n")
    output = [i.strip().strip(";") for i in output.split("\n")]
    
    # Return output when 
    output = [o for o in output if len(o) > 0]
    return output


def replace_recent_steps(origin_str, replace_str):
    """
    Merge "Step N: ..." formatted replacement steps into the original steps.

    A step starts at a line beginning with "Step <number>:" and runs until the
    next such line; text before the first step header is ignored. Each step in
    `replace_str` overwrites (or adds) the step with the same number, unless its
    content contains "DELETE THIS STEP", in which case that step is dropped.

    Parameters:
    - origin_str (str): The original reasoning steps.
    - replace_str (str): Steps to replace, add, or delete.

    Returns:
    - str: The surviving step contents ordered by step number and joined with
      blank lines. The "Step N:" labels themselves are not re-emitted.
    """
    header = re.compile(r"Step\s+(\d+):\s*")

    def _collect(text):
        # Map each step number to its (stripped) multi-line content.
        collected = {}
        number, buffer = None, []
        for row in text.splitlines():
            hit = header.match(row)
            if hit is None:
                if number is not None:
                    buffer.append(row)
                continue
            if number is not None:
                collected[number] = "\n".join(buffer).strip()
            number = int(hit.group(1))
            remainder = row[hit.end():].strip()
            buffer = [remainder] if remainder else []
        if number is not None:
            collected[number] = "\n".join(buffer).strip()
        return collected

    merged = _collect(origin_str)
    for number, body in _collect(replace_str).items():
        if "DELETE THIS STEP" in body:
            merged.pop(number, None)
        else:
            merged[number] = body

    return "\n\n".join(merged[n] for n in sorted(merged))


def _truncate_prev_reasoning(output):
    """Number each line of *output* as "Step i: ..." and elide the middle of long traces.

    When there are more than four numbered steps, keep the first step, the last
    steps, and any step containing BEGIN_SEARCH_QUERY or BEGIN_SEARCH_RESULT; each
    run of other steps collapses into a single "...".
    """
    lines = output.replace("\n\n", "\n").split("\n")
    numbered = "".join(f"Step {i + 1}: {step}\n\n" for i, step in enumerate(lines))
    # The trailing "\n\n" leaves an empty last element, which counts as one step here.
    prev_steps = numbered.split("\n\n")
    if len(prev_steps) <= 5:
        truncated = "\n\n".join(prev_steps)
    else:
        truncated = ""
        for i, step in enumerate(prev_steps):
            if (
                i == 0
                or i >= len(prev_steps) - 4
                or BEGIN_SEARCH_QUERY in step
                or BEGIN_SEARCH_RESULT in step
            ):
                truncated += step + "\n\n"
            elif not truncated.endswith("\n\n...\n\n"):
                truncated += "...\n\n"
    return truncated.strip("\n")


def _format_documents(relevant_info, url_cache, max_doc_len):
    """Render each query's search hits as "**Web Page i:**" + JSON, one string per query.

    Each hit dict is mutated in place: ``snippet`` has <b>/</b> tags removed (or
    becomes None when empty/missing) and ``context`` is set to the snippet's
    surrounding text from the fetched page, falling back to the first
    ``2 * max_doc_len`` characters of the page.
    """
    formatted_per_query = []
    for hits in relevant_info:
        parts = []
        for i, doc_info in enumerate(hits):
            raw_context = url_cache.get(doc_info.get("url", ""), "")
            snippet = doc_info.get("snippet")
            snippet = snippet.replace("<b>", "").replace("</b>", "") if snippet else None
            doc_info["snippet"] = snippet

            success, filtered_context = extract_snippet_with_context(
                raw_context, snippet, context_chars=max_doc_len
            )
            doc_info["context"] = filtered_context if success else raw_context[:max_doc_len * 2]

            parts.append(f"**Web Page {i + 1}:**\n")
            parts.append(json.dumps(doc_info, ensure_ascii=False, indent=2) + "\n")
        formatted_per_query.append("".join(parts))
    return formatted_per_query


def main():
    """Run search-augmented generation over a dataset and save the results as JSON.

    Each iteration continues generation for every unfinished item, stopping at
    END_SEARCH_QUERY. An item whose response has no parseable search-query block
    is marked finished and gets ``generated_answer = extract_answer(response)``.
    Otherwise:
    - its queries are run through ``jina_web_search`` (subject to MAX_SEARCH_LIMIT
      and repeat detection);
    - result pages are fetched with ``fetch_page_content``;
    - the pages are analyzed by ``generate_webpage_to_reasonchain_batch``, and the
      analysis is appended to the prompt for the next iteration.

    The loop ends when no item needs another round or after iteration index
    MAX_ITERATION. Items still unfinished at that point have no ``generated_answer``.
    """
    args = get_args()
    print("Model:", args.model_name)

    # Load the requested split; an omitted --start_idx/--end_idx leaves that side open.
    dataset = list(load_dataset(args.dataset_name)[args.dataset_split])
    dataset = dataset[args.start_idx:args.end_idx]

    # Use Generator for LLM and tokenizer
    generator = Generator(model=args.model_name)

    def get_prompt(item):
        prompt = prompt_for_response_generatinon.format(question=item["final_question"])
        prompt = generator.apply_chat_template(prompt, enable_thinking=False)
        # Open a thinking section unless the chat template already ends with one.
        if "<think>" not in prompt[-10:]:
            prompt = prompt.strip() + "\n\n<think>\n"
        return prompt

    for item in dataset:
        item["finished"] = False
        item["prompt"] = get_prompt(item)
        item["output"] = ""
        item["output_history"] = []
        item["search_count"] = []  # number of executed searches per iteration
        item["executed_search_queries"] = []
        item["related_info_analysis"] = []

    # In-memory caches shared across items and iterations (not persisted to disk).
    search_cache = {}  # query -> search results
    url_cache = {}     # url -> fetched page content

    iteration = -1
    while True:
        iteration += 1
        print(f"\n---------------- Iteration {iteration} ----------------\n")
        print("Start iteration...")
        print("Among", len(dataset), "items,", end=" ")

        items_remained = [item for item in dataset if not item["finished"]]
        if not items_remained:
            break
        print(len(items_remained), "are left...")

        # Per-iteration state for items that issued a search-query block.
        batch_relevant_info = []
        batch_original_questions = []
        batch_prev_reasonings = []
        batch_search_queries = []
        batch_sequences = []
        all_urls_to_fetch = set()

        # NOTE(review): extract_between() needs the closing END_SEARCH_QUERY tag in the
        # response; this assumes Generator.generate keeps the stop string — confirm.
        prompts = [item["prompt"] for item in items_remained]
        responses = generator.generate(
            prompts,
            max_tokens=args.max_tokens,
            stop=[END_SEARCH_QUERY],
            apply_chat=False,
            enable_thinking=True,
        )

        for item, response in tqdm(zip(items_remained, responses), total=len(items_remained)):
            search_queries = extract_between(response)
            # Close an unterminated thinking section before the search block.
            if search_queries and "</think>" not in response:
                response = response.replace(
                    "\n\n<|begin_search_queries|>",
                    "\n</think>\n\n<|begin_search_queries|>",
                )
            item["prompt"] += response
            item["output"] += response
            item["output_history"].append(response)
            if search_queries is None:
                item["finished"] = True
                item["generated_answer"] = extract_answer(response)
                continue

            relevant_info = []        # top-k results, one list per executed query
            executed_this_round = []  # queries actually searched this round
            prior_searches = sum(item["search_count"])

            for search_query in search_queries:
                # Include searches executed earlier in this round so the limit also
                # holds within a single multi-query block.
                if prior_searches + len(executed_this_round) >= MAX_SEARCH_LIMIT:
                    limit_message = (
                        f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. "
                        "You are not allowed to search.\n"
                        f"{END_SEARCH_RESULT}\n"
                    )
                    item["prompt"] += limit_message
                    item["output"] += limit_message
                    print(f"Search limit reached for query: \"{search_query}\"")
                    break

                if search_query in item["executed_search_queries"]:
                    repeat_msg = (
                        f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. "
                        "Please refer to previous results.\n"
                        f"{END_SEARCH_RESULT}\n"
                    )
                    item["prompt"] += repeat_msg
                    item["output"] += repeat_msg
                    print(f"Repeated search for query: \"{search_query}\"")
                    continue

                if search_query in search_cache:
                    top_results = search_cache[search_query][:args.top_k]
                    print(f"Using cached search results for query: \"{search_query}\"")
                else:
                    try:
                        results = jina_web_search(search_query, jina_api_key, args.top_k)
                        top_results = results[:args.top_k]
                    except Exception as e:
                        # Best effort: a failed search yields no documents instead of aborting.
                        print(f"Error during search query '{search_query}': {e}")
                        results, top_results = [], []
                    search_cache[search_query] = results

                # Appended for cache hits too, so relevant_info stays aligned with
                # executed_this_round (one entry per executed query).
                relevant_info.append(top_results)
                item["executed_search_queries"].append(search_query)
                executed_this_round.append(search_query)

            item["search_count"].append(len(executed_this_round))
            item["relevant_info"] = relevant_info

            # Queue result URLs whose content is not cached yet.
            for results_for_query in relevant_info:
                for info in results_for_query:
                    if "url" in info and info["url"] not in url_cache:
                        all_urls_to_fetch.add(info["url"])

            batch_relevant_info.append(relevant_info)
            # Fall back to the field used for the prompt when there is no "question".
            batch_original_questions.append(item.get("question", item["final_question"]))
            batch_prev_reasonings.append(_truncate_prev_reasoning(item["output"]))
            batch_search_queries.append(executed_this_round)  # important: one list of queries per item
            batch_sequences.append(item)

        # Fetch all uncached URLs of this iteration in one batch.
        print(len(all_urls_to_fetch), "urls to fetch...")
        if all_urls_to_fetch:
            print(f"Fetching {len(all_urls_to_fetch)} URLs...")
            try:
                fetched_contents = fetch_page_content(
                    list(all_urls_to_fetch),
                    use_jina=args.use_jina,
                    jina_api_key=jina_api_key,
                )
                print(f"Fetched {len(fetched_contents)} URLs successfully.")
            except Exception as e:
                # Best effort: cache the error text as the page content.
                print(f"Error during batch URL fetching: {e}")
                fetched_contents = {url: f"Error fetching URL: {e}" for url in all_urls_to_fetch}
            url_cache.update(fetched_contents)

        # [ [ formatted_doc_string for each search query ] for each item ]
        batch_documents = [
            _format_documents(relevant_info, url_cache, args.max_doc_len)
            for relevant_info in batch_relevant_info
        ]

        if batch_sequences:
            print(f"Batch processing {len(batch_sequences)} sequences with generate_webpage_to_reasonchain_batch...")
            webpage_analyses = generate_webpage_to_reasonchain_batch(
                args,
                generator,
                original_questions=batch_original_questions,
                prev_reasonings=batch_prev_reasonings,
                search_queries=batch_search_queries,
                documents=batch_documents,
                dataset_name=args.dataset_name,
            )
            print("Batch generation completed, assigning outputs to sequences...")

            for seq, queries, analyses in zip(batch_sequences, batch_search_queries, webpage_analyses):
                combined_analysis = ""
                for query, analysis in zip(queries, analyses):
                    combined_analysis += f"{query}: {analysis}\n"
                    seq["related_info_analysis"].append(analysis)

                append_text = f"\n\n{BEGIN_SEARCH_RESULT}\n{combined_analysis.strip()}\n{END_SEARCH_RESULT}\n\n<think>\n"
                seq["prompt"] += append_text
                seq["output"] += append_text

        # Stop when no item needs another round or the iteration budget is used up.
        if not batch_sequences or iteration >= MAX_ITERATION:
            break

    dataset_name = args.dataset_name.split("/")[-1]
    model_name = args.model_name.split("/")[-1]
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = f"{args.output_dir}/{dataset_name}_{model_name}_from_{args.start_idx}_to_{args.end_idx}.json"
    print(f"Saving results... {output_path}...")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=2)
    print("Done!")


if __name__ == "__main__":
    main()
    
    
