import asyncio
from meta_researcher.tool.tools.search_engine.log import logger
import httpx
import aiohttp
from asyncache import cached as acached
from cachetools import TTLCache

from meta_researcher.tool.tools.search_engine.base_search import BaseSearch, SearchConfig


class GoogleSerperSearch(BaseSearch):
    """Search backend that queries the Serper API (google.serper.dev).

    Responses are normalised into a list of dicts with the keys ``url``,
    ``title``, ``summ`` and ``meta`` (plus ``content`` / ``published_date``
    when available).
    """

    # Maps a search type (also used as the endpoint path segment) to the key
    # that holds the result list in the JSON response.
    result_key_for_type = {
        "news": "news",
        "places": "places",
        "images": "images",
        "search": "organic",
    }

    def __init__(self, args: SearchConfig):
        """Store the search configuration and set the defaults for this backend."""
        super().__init__(args)
        self.topk = 10
        self.search_type = "search"
        self.name = "GoogleSerperSearch"

    # The TTLCache is created once, at class-definition time, so it is shared
    # by every instance: at most 100 entries, each kept for up to 600 seconds.
    @acached(cache=TTLCache(maxsize=100, ttl=600))
    async def search(self, query: str) -> list[dict]:
        """Run ``query`` against Serper, retrying on failure.

        Args:
            query: The search string.

        Returns:
            The parsed result entries (see ``_parse_results``).

        Raises:
            Exception: If every one of ``self.args.max_retry`` attempts failed.
                The last underlying error is attached as ``__cause__``.
        """
        max_retry = self.args.max_retry
        last_error = None
        for attempt in range(max_retry):
            try:
                logger.info(f"Begin search web for {query} with GoogleSerperSearch")
                response = await self._search(query)
                return self._parse_results(response)
            except Exception as e:
                last_error = e
                logger.warning(str(e))
                # Back off only when another attempt will follow; sleeping
                # after the final failure just delays the error.
                if attempt < max_retry - 1:
                    await asyncio.sleep(1)
        raise Exception(f"Failed to get search results from {self.name} after retries.") from last_error

    def _parse_results(self, results: dict) -> list[dict]:
        """Convert a Serper JSON response into a list of result dicts.

        An ``answerBox`` short-circuits everything else and is returned as the
        single result. Otherwise an optional ``knowledgeGraph`` entry is
        followed by the regular results for the current search type, and the
        combined list is passed through ``self._filter_results``.

        Args:
            results: Decoded JSON body returned by ``_search``.

        Returns:
            The list of normalised result dicts.
        """
        snippets = []

        if answer_box := results.get("answerBox"):
            # Take the first of these keys that is present, else "".
            answer = next(
                (answer_box[key] for key in ("answer", "snippet", "snippetHighlighted") if key in answer_box),
                "",
            )
            # "snippetHighlighted" arrives as a list of fragments; flatten it
            # so ``content`` is always a string like in every other entry.
            if isinstance(answer, (list, tuple)):
                answer = ", ".join(str(part) for part in answer)
            return [
                {
                    "url": answer_box.get("link", ""),
                    "title": answer_box.get("title", ""),
                    "summ": "",
                    "content": answer,
                    "meta": {
                        "from": self.name,
                        "result_type": "answerBox",
                    },
                }
            ]

        if kg := results.get("knowledgeGraph"):
            content = []
            # Default to "" so a missing title does not render as "None".
            title = kg.get("title") or ""
            entity_type = kg.get("type")
            if entity_type:
                content.append(f"{title}: {entity_type}.")

            description = kg.get("description")
            if description:
                content.append(description)
            for attribute, value in (kg.get("attributes") or {}).items():
                content.append(f"{title} {attribute}: {value}.")
            snippets.append(
                {
                    "url": kg.get("website", ""),
                    "title": title,
                    "summ": description or "",
                    "content": "\n".join(content),
                    "meta": {
                        "from": self.name,
                        "result_type": "knowledgeGraph",
                    },
                }
            )

        result_key = self.result_key_for_type[self.search_type]
        # A response without a result list is still a valid (empty) answer;
        # treating it as an error would trigger pointless retries and discard
        # any knowledge-graph snippet collected above.
        for result in results.get(result_key) or []:
            snippet = {
                "url": result.get("link", ""),
                "title": result.get("title", ""),
                "summ": result.get("snippet", ""),
                "meta": {
                    "from": self.name,
                },
            }
            # Serper reports the publication date under "date"; the historical
            # "data" key is still honoured as a fallback.
            published = result.get("date") or result.get("data")
            if published:
                snippet["published_date"] = published
            snippets.append(snippet)

        return self._filter_results(snippets)

    async def _search(self, search_term: str):
        """POST one query to Serper and return the decoded JSON body.

        Args:
            search_term: The search string sent as the ``q`` parameter.

        Returns:
            The decoded JSON response.

        Raises:
            httpx.HTTPStatusError: If the API answers with a 4xx/5xx status.
        """
        headers = {
            "X-API-KEY": self.args.google_search_serper_key,
            "Content-Type": "application/json",
        }
        params = {
            "q": search_term,
            "gl": "cn",
            "hl": "zh-cn",
            "num": self.topk,
        }
        # Use the endpoint that matches the configured search type so the
        # result key looked up in ``_parse_results`` is the one returned.
        url = f"https://google.serper.dev/{self.search_type}"
        # NOTE(review): timeout=None disables all httpx timeouts, so a stalled
        # connection blocks forever and the retry loop in ``search`` never
        # fires. Consider a finite timeout once acceptable latency is known.
        async with httpx.AsyncClient(timeout=None) as client:
            response = await client.post(url, headers=headers, params=params)
            response.raise_for_status()
            return response.json()

async def test_main():
    """Manual smoke test: run one live query against Serper and print the results.

    The API key is read from the ``SERPER_API_KEY`` environment variable so
    that no credential is stored in the source tree. Any key that was
    previously committed here should be treated as leaked and revoked.

    Raises:
        RuntimeError: If ``SERPER_API_KEY`` is not set.
    """
    import os

    api_key = os.environ.get("SERPER_API_KEY")
    if not api_key:
        raise RuntimeError("Set the SERPER_API_KEY environment variable to run this smoke test.")

    args = SearchConfig()
    args.google_search_serper_key = api_key
    searcher = GoogleSerperSearch(args)
    results = await searcher.search("巴哈马—中国友好协会举办风筝节 时间")
    print(results)


if __name__ == "__main__":
    # Script entry point: run the live smoke test. asyncio is already imported
    # at module level, so no local re-import is needed.
    asyncio.run(test_main())
