import os
import sys
import json
import time  
import httpx
import hydra
import logging
import requests
from typing import Optional
from urllib.parse import urlparse
from hydra import initialize, compose
from cachetools import TTLCache, cached
from omegaconf import DictConfig, OmegaConf

# Configure the root logger at import time (INFO level, timestamped records)
# and create this module's logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# The project checkout is referenced by a path relative to the current working
# directory; it is placed first on sys.path so the import below can resolve.
project_root = "./futuremind"
sys.path.insert(0, project_root)
logger.info(f"Added to PYTHONPATH: {project_root}")

# Import the project helpers; on failure the error is logged and re-raised so
# the module fails fast instead of running without them.
try:
    from futuremind.tool.tools.search.google_utils import SearchSnippet, TokenTracker
except ImportError as e:
    logger.error(f"Import failed: {e}")
    raise

class SearchHandler:
    """Client for the BochaAI and GoogleSerperSearch retrieval endpoints.

    The handler keeps the composed configuration in ``cfg`` and counts, in
    ``search_count``, how many times ``google_search`` actually ran (cache
    hits do not increment it).

    NOTE(review): the batch helpers fan work out with ``process_map``; the
    counter is then incremented on the copies of this object living in the
    worker processes, so the parent's ``search_count`` is presumably left
    untouched by batch calls -- confirm before relying on it.
    """

    # Client-side timeout (seconds) applied to every outbound HTTP request.
    REQUEST_TIMEOUT = 30

    def __init__(self, cfg: DictConfig):
        """Store the configuration and reset the search counter.

        Args:
            cfg: Composed configuration; ``cfg.Retriever`` is read by
                ``bochaai_search``.
        """
        self.cfg = cfg
        self.search_count = 0

    def bochaai_search(self, query):
        """Query the BochaAI endpoint and render the hits as plain text.

        Args:
            query: Search string sent to the service.

        Returns:
            One line per document (at most ``cfg.Retriever.topk`` lines) in the
            form ``Doc <i>: (Title: ...) DateLastCrawled: ... Summary: ...``.

        Raises:
            requests.RequestException: On timeout, connection failure or a
                non-2xx HTTP status.
            KeyError: If the response lacks the expected fields.
        """
        retriever_cfg = self.cfg.Retriever
        payload = json.dumps({
            "query": query,
            "summary": retriever_cfg.summary,
            "count": retriever_cfg.topk,
            "page": retriever_cfg.page
        })
        headers = {
            'Authorization': f'Bearer {retriever_cfg.apikey}',
            'Content-Type': 'application/json'
        }

        # Without an explicit timeout ``requests`` may block indefinitely.
        response = requests.post(
            retriever_cfg.url,
            headers=headers,
            data=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        # Surface HTTP errors directly instead of as a KeyError further down.
        response.raise_for_status()
        results = response.json()["data"]["webPages"]["value"]

        # Clamp at zero so a negative topk yields no documents rather than
        # being interpreted as a from-the-end slice.
        limit = max(int(retriever_cfg.topk), 0)
        return "".join(
            f"Doc {i}: (Title: {doc['name']}) "
            f"DateLastCrawled: {doc['dateLastCrawled']} "
            f"Summary: {doc['summary']}\n"
            for i, doc in enumerate(results[:limit])
        )

    def _call_google_api(self, query: str):
        """POST ``query`` to the ``generate_prompt`` endpoints until one answers.

        The configured base URLs are tried in random order; the decoded JSON
        body of the first successful response is returned.

        Args:
            query: Search string.

        Returns:
            The decoded JSON response.

        Raises:
            RuntimeError: If every endpoint failed; the last request error is
                attached as ``__cause__``.
        """
        # ``random`` is not imported at module level, so pull it in here.
        import random

        api_key = "..."
        base_urls = [ "..."]

        topk = 10
        need_search = 1
        context_size = 32000
        return_full_search_content = True

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        data = {
            "query": query,
            "context_size": context_size,
            "need_search": need_search,
            "topk": topk,
            "search_engine": "GoogleSerperSearch",
            "return_full_search_content": return_full_search_content,
            # Value forwarded to the service in the request body; the
            # client-side timeout is REQUEST_TIMEOUT below.
            "timeout": 100,
        }

        last_error = None
        for base_url in random.sample(base_urls, k=len(base_urls)):
            url = f"{base_url}/generate_prompt"
            try:
                resp = requests.post(url, headers=headers, json=data, timeout=self.REQUEST_TIMEOUT)
                resp.raise_for_status()
                return resp.json()
            except requests.RequestException as e:
                last_error = e
                logger.warning("Endpoint %s failed: %s; trying next...", base_url, e)
                time.sleep(2)

        raise RuntimeError("All endpoints failed") from last_error

    def _parse_response(self, response) -> list[SearchSnippet]:
        """Convert a ranked web-page response into ``SearchSnippet`` objects.

        Only ``rankingResponse.mainline.items`` entries whose ``answerType`` is
        ``"WebPages"`` and whose id is present in ``webPages.value`` are kept,
        in ranking order.

        Args:
            response: Mapping with optional ``webPages`` and
                ``rankingResponse`` sections.

        Returns:
            The snippets built from the matching pages.
        """
        webpages = {w["id"]: w for w in response.get("webPages", {}).get("value", [])}
        ranked_items = response.get("rankingResponse", {}).get("mainline", {}).get("items", [])

        results: list[SearchSnippet] = []
        for item in ranked_items:
            if item["answerType"] != "WebPages":
                continue
            webpage = webpages.get(item["value"]["id"])
            if not webpage:
                continue
            results.append(
                SearchSnippet(
                    url=webpage["url"],
                    description=webpage.get("snippet", ""),
                    title=webpage.get("name", ""),
                )
            )
        return results

    def _search_cache_key(self, query, tracker=None, topk=None, retry_time=0):
        """Build the cache key for ``google_search`` from ``(query, topk)``.

        ``self``, ``tracker`` and ``retry_time`` do not influence the returned
        results, so they are left out: cached entries then do not pin handler
        instances, an unhashable tracker cannot break the lookup, and
        positional and keyword call styles share one entry.
        """
        return (query, topk)

    @cached(cache=TTLCache(maxsize=100, ttl=600), key=_search_cache_key)
    def google_search(
        self, query: str, tracker: Optional[TokenTracker] = None, topk: Optional[int] = None, retry_time: float = 0
    ) -> list[dict]:
        """Run a GoogleSerperSearch query with retries and a TTL cache.

        Results are cached for 600 seconds (up to 100 entries). The empty list
        returned once all retries are exhausted is cached as well.

        Args:
            query: Search string.
            tracker: Accepted for interface compatibility; not used here.
            topk: If given, keep only the first ``topk`` results.
            retry_time: Accepted for interface compatibility; not used here.

        Returns:
            The items of the response's ``search_results`` field (read with
            ``.get`` by the formatting helpers), or ``[]`` if every attempt
            failed or returned nothing.
        """
        self.search_count += 1
        max_retries = 30

        for attempt in range(1, max_retries + 1):
            try:
                response = self._call_google_api(query)
                search_results = response.get("search_results", [])
                if not search_results:
                    raise ValueError("GoogleSerperSearch API returned empty items.")

                return search_results[:topk] if topk is not None else search_results

            except Exception as e:
                # Best-effort boundary: any failure is retried, and the final
                # one degrades to an empty result instead of propagating.
                if attempt == max_retries:
                    logger.error(
                        "GoogleSerperSearch abandoned after %d attempts: %s",
                        max_retries, e
                    )
                    return []
                logger.warning(
                    "GoogleSerperSearch search failed (%s), retrying in 2s... (attempt %d/%d)",
                    e, attempt, max_retries
                )
                time.sleep(2)

        # Unreachable with max_retries >= 1; keeps the return type explicit.
        return []

    def google_search_url_formatted(
        self, query: str, tracker: Optional[TokenTracker] = None, topk: Optional[int] = None
    ) -> list[dict]:
        """Search and return one dict per hit, including site and date fields.

        Args:
            query: Search string.
            tracker: Passed through to ``google_search``.
            topk: Optional cap on the number of hits.

        Returns:
            Dicts with keys ``doc_index``, ``url``, ``title``, ``summary``,
            ``content``, ``site_name`` and ``published_date``; missing source
            fields default to ``""``.
        """
        snippets = self.google_search(query, tracker, topk)
        return [
            {
                "doc_index": i,
                "url": snippet.get("url", ""),
                "title": snippet.get("title", ""),
                # The summary is read from the source field named "summ".
                "summary": snippet.get("summ", ""),
                "content": snippet.get("content", ""),
                "site_name": snippet.get("site_name", ""),
                "published_date": snippet.get("published_date", ""),
            }
            for i, snippet in enumerate(snippets)
        ]

    def process_url_prompt(self, args_infer):
        """Batch worker: run ``google_search_url_formatted`` for one query.

        Args:
            args_infer: ``(idx, query, topk)`` tuple.

        Returns:
            ``{"query", "information", "idx"}`` or ``None`` when the search
            produced nothing.
        """
        idx, one_query, topk = args_infer
        information = self.google_search_url_formatted(one_query, topk=topk)
        if not information:
            return None
        return {"query": one_query, "information": information, "idx": idx}

    def batch_google_search_url_formatted(self, query_list: list, topk: Optional[int] = None) -> list:
        """Run ``google_search_url_formatted`` over many queries in parallel.

        Args:
            query_list: Queries to search.
            topk: Optional cap on hits per query.

        Returns:
            Non-empty per-query results ordered by position in ``query_list``.
        """
        return self._run_batch(self.process_url_prompt, query_list, topk, max_workers=60)

    def google_search_formatted(
        self, query: str, tracker: Optional[TokenTracker] = None, topk: Optional[int] = None
    ) -> list[dict]:
        """Search and return one dict per hit with the core fields only.

        Args:
            query: Search string.
            tracker: Passed through to ``google_search``.
            topk: Optional cap on the number of hits.

        Returns:
            Dicts with keys ``doc_index``, ``title``, ``url``, ``summary`` and
            ``content``; missing source fields default to ``""``.
        """
        snippets = self.google_search(query, tracker, topk)
        return [
            {
                "doc_index": i,
                "title": snippet.get('title', ''),
                "url": snippet.get('url', ''),
                # The summary is read from the source field named "summ".
                "summary": snippet.get('summ', ''),
                "content": snippet.get('content', '')
            }
            for i, snippet in enumerate(snippets)
        ]

    def process_prompt(self, args_infer):
        """Batch worker: run ``google_search_formatted`` for one query.

        Args:
            args_infer: ``(idx, query, topk)`` tuple.

        Returns:
            ``{"query", "information", "idx"}`` or ``None`` when the search
            produced nothing.
        """
        idx, one_query, topk = args_infer
        information = self.google_search_formatted(one_query, topk=topk)
        if not information:
            return None
        return {"query": one_query, "information": information, "idx": idx}

    def batch_google_search_formatted(self, query_list: list, topk: Optional[int] = None) -> list:
        """Run ``google_search_formatted`` over many queries in parallel.

        Args:
            query_list: Queries to search.
            topk: Optional cap on hits per query.

        Returns:
            Non-empty per-query results ordered by position in ``query_list``.
        """
        return self._run_batch(self.process_prompt, query_list, topk, max_workers=30)

    def _run_batch(self, worker, query_list, topk, max_workers):
        """Map ``worker`` over the queries with ``process_map`` and tidy up.

        Args:
            worker: Bound method taking an ``(idx, query, topk)`` tuple.
            query_list: Queries to process.
            topk: Forwarded to ``worker`` inside each tuple.
            max_workers: Upper bound on the pool size.

        Returns:
            The non-``None`` worker results sorted by ``idx``.
        """
        jobs = [(idx, one_query, topk) for idx, one_query in enumerate(query_list)]
        if not jobs:
            # Nothing to do; avoid starting a pool at all.
            return []

        from tqdm.contrib.concurrent import process_map

        results = process_map(
            worker,
            jobs,
            # Never ask for more workers than there are jobs.
            max_workers=min(max_workers, len(jobs)),
            chunksize=8,
        )
        results = [result for result in results if result is not None]
        results.sort(key=lambda x: x['idx'])
        return results


def get_search_handler(
    config_path: str = "../search/config",
    config_name: str = "eval_search",
) -> SearchHandler:
    """Build a ``SearchHandler`` from a Hydra-composed configuration.

    Args:
        config_path: Config directory handed to ``hydra.initialize``.
        config_name: Name of the config passed to ``hydra.compose``.

    Returns:
        A ``SearchHandler`` wrapping the composed config.
    """
    hydra_context = initialize(config_path=config_path, version_base="1.1")
    with hydra_context:
        composed_cfg = compose(config_name=config_name)
    # The handler is constructed after the Hydra context has been exited.
    return SearchHandler(composed_cfg)

if __name__ == "__main__":
    # Manual smoke test: run one batched search and print the outcome.
    search_handler = get_search_handler()
    demo_queries = [ "SU7 Ultra"]
    demo_results = search_handler.batch_google_search_formatted(demo_queries, topk=5)

    # Imported only here, after the search, and under an alias so the
    # builtin ``print`` is left untouched.
    from rich import print as rich_print

    rich_print(f"总共调用 google_search: {search_handler.search_count} 次")
    rich_print(f"共返回 {len(demo_results)} 条数据。")
    rich_print(demo_results)
