"""
Search tool implementation for simulating internet searches
"""

import asyncio
import time
import random
from typing import Dict, List, Any, Optional
import os

from langchain_text_splitters import RecursiveCharacterTextSplitter
import requests

from meta_researcher.tool.tools.search_engine import Retriever, SearchEngine
from meta_researcher.tool.tools.search_engine.base_search import SearchConfig
from meta_researcher.tool.tools.search_engine.website_crawler.simple_crawler import SimpleContentFetch
from meta_researcher.tool.base import BaseTool

# from txtai.embeddings import Embeddings
import json

# Search configuration, read once from the process environment at import time.
# Each environment-backed value is None when the variable is not set.

# Search backend name handed to Retriever as `engine` (e.g. "GoogleSerperSearch").
SEARCH_ENGINES = os.environ.get('SEARCH_ENGINES')
# Base URL handed to Retriever as `base_url` for the "GoogleBingOptionalSearch" engine.
SEARCH_BASE_URL = os.environ.get('SEARCH_BASE_URL')
# Handed to Retriever as `optional_search_engine` for the same engine.
OPTIONAL_SEARCH_ENGINE = os.environ.get('OPTIONAL_SEARCH_ENGINE')
# Upper bound on queries per tool call, as advertised in the tool's parameter description.
MAX_QUERIES_PER_STEP = 5

class WebSearchTool(BaseTool):
    """Web search tool: retrieves pages for a batch of queries and formats them.

    Queries are sent to the configured ``Retriever``. For every hit that carries
    page content, the content is split into chunks and an HTTP rerank service
    selects the chunk(s) most relevant to the query; hits without content fall
    back to their summary.
    """

    name = "web_search"
    description = """Web Search Tool. It can provide relevant web - page information by inputting search keywords. Note that the web search tool needs to be called to obtain more knowledge in the following situations:
    1. The question asked exceeds your knowledge reserve and scope, and you do not have the specific knowledge to accurately answer the user's question.
    2. When it comes to the latest data, dynamic information, knowledge beyond the cut - off time of your training data, or real - time updated knowledge, the search engine needs to be called to obtain relevant information.
    3. More detailed and novel knowledge can be obtained through the Internet, such as online buzzwords, real - time information, product information, etc.
    4. Fabrication is prohibited. When encountering unfamiliar nouns, things, or concepts, you need to search the Internet to supplement your knowledge.
    """
    parameters = {
        "type": "object",
        "properties": {
            "queries": {
                "type": "array",
                "items": {
                    "type": "string"
                },
                "description": f"Always prioritize the use of a single search query. Add another query only when the original question covers multiple aspects or elements and a single search request is absolutely insufficient. Each query should focus on a specific aspect of the original question, minimize the mutual - information between each query, and the queries should be in a parallel relationship. You can call the web_search tool multiple times. The maximum number of search requests is {MAX_QUERIES_PER_STEP}.",
            },
        },
        "required": ["queries"]
    }

    def __init__(
        self,
        *,
        model_rerank_name: str = "bge-reranker-v2-m3",
        model_rerank_api_key: str = "abc123",
        model_rerank_base_url: str = "http://localhost:59000/v1",
    ):
        """
        Build the retriever and store the rerank service settings.

        Args:
            model_rerank_name: Model name sent to the rerank endpoint.
            model_rerank_api_key: Bearer token sent to the rerank endpoint.
            model_rerank_base_url: Base URL of the rerank service; "/rerank"
                is appended when issuing requests.
        """
        super().__init__()
        print("[DEBUG] EMBEDDINGS LOADING")

        # Init the search engine. Only the "GoogleBingOptionalSearch" engine
        # receives the extra base-URL / optional-engine / search-method settings.
        if SEARCH_ENGINES == "GoogleBingOptionalSearch":
            self.retriever = Retriever(
                engine=SEARCH_ENGINES,
                base_url=SEARCH_BASE_URL,
                optional_search_engine=OPTIONAL_SEARCH_ENGINE,
                search_method="SimpleLLMSearchAgent",
            )
        else:
            self.retriever = Retriever(engine=SEARCH_ENGINES)

        self.model_rerank_name = model_rerank_name
        self.model_rerank_api_key = model_rerank_api_key
        self.model_rerank_base_url = model_rerank_base_url

        print("[DEBUG] EMBEDDINGS LOADING END")

    def execute(self, args: Dict) -> Dict[str, Any]:
        """
        Execute a single tool call by delegating to ``batch_execute``.

        Args:
            args: Tool parameters, containing:
                - "queries": list of search query strings

        Returns:
            A dict with "content" (formatted results, or the error message)
            and "success" (bool), i.e. the same shape as one ``batch_execute``
            entry.
        """
        # NOTE(review): the dict return shape mirrors batch_execute; confirm it
        # matches what BaseTool callers expect from execute().
        outcomes = self.batch_execute([args])
        if not outcomes:
            return {"content": "", "success": False}
        return outcomes[0]

    @staticmethod
    def _get_event_loop() -> asyncio.AbstractEventLoop:
        """Return a usable event loop for this thread, creating one if needed."""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop (e.g. worker thread, or newer Python versions).
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop

    def batch_execute(self, args_list: List[Dict]) -> List[Dict[str, Any]]:
        """
        Run the searches for a batch of tool calls.

        Args:
            args_list: One parameter dict per tool call, each containing a
                "queries" list.

        Returns:
            One dict per retriever result with "content" (the JSON string
            produced by ``_format_results``) and "success" set to True. If
            anything fails, one dict per input is returned instead, carrying
            the error message and "success" set to False.
        """
        if not args_list:
            return []
        try:
            queries = [call_args["queries"] for call_args in args_list]
            loop = self._get_event_loop()
            webpages = loop.run_until_complete(self.retriever.retrieve(queries))
            if SEARCH_ENGINES != "GoogleBingOptionalSearch":
                # Every other engine needs a second pass to fetch page contents.
                webpages = loop.run_until_complete(
                    self.retriever.batch_fetch_web_contents(webpages)
                )
            return [
                {"content": str(self._format_results(pages)), "success": True}
                for pages in webpages
            ]
        except Exception as e:
            # Best-effort boundary: a failure is reported for every call in the
            # batch rather than propagated to the caller.
            return [{"content": str(e), "success": False} for _ in args_list]

    def _format_results(self, results: List) -> str:
        """
        Format search results for better readability

        Args:
            results: List of search results; each entry is indexed by
                "content", "query", "title", "url" and (when the content is
                empty) "summ".

        Returns:
            JSON string of the form {"results": [<one text line per result>]}
        """
        results_list = []

        for i, result in enumerate(results):
            if result["content"]:
                # Keep only the chunk(s) of the page most relevant to the query.
                chunks = self.async_similarity_search(result["query"], result["content"])
                content = "".join(chunks)
            else:
                # No page content available: fall back to the summary.
                content = result["summ"]
            results_list.append(
                f"Doc{i}: Title: {result['title']}, Link: {result['url']}, Content: {content}\n\n"
            )

        return json.dumps({"results": results_list}, ensure_ascii=False)

    def async_similarity_search(self, query: str, document: str, k=1, chunk_size=512, chunk_overlap=10):
        """
        Select the chunks of ``document`` most relevant to ``query``.

        Despite its name this method is synchronous: it blocks on the rerank
        HTTP request.

        Args:
            query: Query the chunks are ranked against.
            document: Text to split and rank.
            k: Number of chunks to return.
            chunk_size: Maximum chunk size passed to the text splitter.
            chunk_overlap: Overlap between chunks passed to the text splitter.

        Returns:
            A list of chunk strings: all chunks when there are at most ``k``,
            the ``k`` best-scored chunks when reranking succeeds, otherwise a
            one-element list holding the last chunk.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", " ", ",", ";", ".", "!", "?", "—", "-"],
            keep_separator=True,
        )
        split_text = splitter.split_text(document)
        if len(split_text) <= k:
            return split_text

        rerank_result = self.process_query(query, split_text)
        if rerank_result:
            ranked = sorted(rerank_result, key=lambda idx: rerank_result[idx], reverse=True)
            # Ignore indices that do not refer to one of our chunks, so a
            # malformed service response cannot raise IndexError here.
            valid = [
                idx for idx in ranked
                if isinstance(idx, int) and 0 <= idx < len(split_text)
            ]
            if valid:
                return [split_text[idx] for idx in valid[:k]]

        # Rerank unavailable or unusable: fall back to the last chunk, wrapped
        # in a list so every path returns the same type.
        return [split_text[-1]]

    def process_query(self, query, split_text, timeout=60.0):
        """
        Score chunks against a query through the rerank HTTP service.

        Args:
            query: Query text.
            split_text: Chunks to score.
            timeout: Seconds to wait for the service before giving up.

        Returns:
            Dict mapping chunk index to relevance score, or None when the
            request fails, returns a non-200 status, or the response cannot be
            parsed.
        """
        payload = {
            "model": self.model_rerank_name,
            "query": query,
            "documents": list(split_text),
        }
        headers = {"Authorization": f"Bearer {self.model_rerank_api_key}"}
        try:
            response = requests.post(
                url=self.model_rerank_base_url.rstrip("/") + "/rerank",
                headers=headers,
                json=payload,
                timeout=timeout,  # without it an unresponsive service blocks forever
            )
            if response.status_code != 200:
                return None

            response_data = response.json()
            return {r["index"]: r["relevance_score"] for r in response_data["results"]}
        except Exception:
            # Reranking is best-effort: the caller falls back on None. Unlike a
            # bare except, this lets KeyboardInterrupt/SystemExit propagate.
            return None

    def calculate_reward(self, args: Dict, result: str) -> float:
        """
        Calculate reward for search action

        Args:
            args: Tool parameters
            result: Tool execution result: either the formatted result string,
                or a {"content", "success"} dict as produced by
                ``batch_execute``.

        Returns:
            0.1 when the result contains "results" (a valid tool call),
            otherwise 0.0
        """
        if isinstance(result, dict) and "results" not in result:
            # Result dict from batch_execute: a failed call earns nothing, a
            # successful one is judged by its content.
            if not result.get("success", True):
                return 0.0
            result = result.get("content")
        try:
            return 0.1 if "results" in result else 0.0
        except TypeError:
            # None or another non-container result is not a valid tool call.
            return 0.0