from crawl4ai import AsyncWebCrawler, CrawlerRunConfig

from meta_researcher.tool.tools.search_engine.base_search import SearchConfig
from meta_researcher.tool.tools.search_engine.website_crawler.base_crawler import BaseContentFetch, FetchResult, FetchResultMeta


class Craw4aiContentFetch(BaseContentFetch):
    """Content fetcher that returns page markdown produced by crawl4ai.

    A single ``AsyncWebCrawler`` is created lazily on first use and shared by
    all later ``fetch`` / ``batch_fetch`` calls until ``close()`` is awaited.
    After ``close()`` the next fetch starts a fresh crawler.
    """

    def __init__(self, args: SearchConfig):
        """Store the search configuration and build the shared run config.

        Args:
            args: Search configuration; kept on the instance as ``self.args``.
        """
        self.args = args

        # Created on demand by _init_crawler(); None means "not started".
        self.crawler = None
        # One run configuration reused for every crawl issued by this fetcher.
        self.config = CrawlerRunConfig(
            excluded_tags=["script", "style", "nav", "head", "img", "footer", "iframe", "header"],
            exclude_external_links=True,
            exclude_internal_links=True,
            exclude_external_images=True,
            only_text=True,
        )

    async def _init_crawler(self):
        """Start the shared crawler if needed and return it.

        Returns:
            The ``AsyncWebCrawler`` instance to use for this call.

        Raises:
            Whatever ``AsyncWebCrawler.start()`` raises. In that case the
            cached crawler is cleared so a later call can retry instead of
            reusing an instance that never started.
        """
        crawler = self.crawler
        if crawler is None:
            crawler = AsyncWebCrawler()
            self.crawler = crawler
            try:
                await crawler.start()
            except BaseException:
                # Do not keep a crawler whose start-up failed or was cancelled.
                if self.crawler is crawler:
                    self.crawler = None
                raise
        return crawler

    async def close(self):
        """Close the shared crawler, if one was started.

        The reference is dropped before closing so that a later fetch creates
        and starts a new crawler rather than reusing the closed one. Calling
        ``close()`` repeatedly is harmless.
        """
        crawler, self.crawler = self.crawler, None
        if crawler is not None:
            await crawler.close()

    @staticmethod
    def _extract_markdown(result) -> str:
        """Return the markdown carried by a crawl result, or "" if it has none."""
        return result.markdown or ""  # type: ignore

    async def fetch(self, url: str, query: str | None = None) -> FetchResult:
        """Crawl a single URL and return its markdown content.

        Args:
            url: Page to crawl.
            query: Optional query string recorded in the result metadata.

        Returns:
            A ``FetchResult`` whose content is the page markdown (empty string
            when the crawl produced none) and whose meta records ``url``,
            ``query`` and the crawler name ``"crawl4ai"``.
        """
        crawler = await self._init_crawler()

        meta = FetchResultMeta(url=url, query=query, crawler="crawl4ai")
        result = await crawler.arun(url, config=self.config)
        return FetchResult(content=self._extract_markdown(result), meta=meta)

    async def batch_fetch(self, urls: list[str], querys: str | None = None) -> list[FetchResult]:
        """Crawl several URLs in one call and return one result per crawl.

        Args:
            urls: Pages to crawl.
            querys: Optional query string recorded in every result's metadata.

        Returns:
            A list of ``FetchResult`` objects. Each one's ``meta.url`` is the
            URL reported by the crawl result itself when available, so content
            stays paired with the right URL even if the crawler hands results
            back in a different order than ``urls``. An empty ``urls`` yields
            an empty list without starting the crawler.
        """
        if not urls:
            return []

        crawler = await self._init_crawler()
        results = await crawler.arun_many(urls, config=self.config)

        rets = []
        for result, requested_url in zip(results, urls):  # type: ignore
            # Prefer the URL recorded on the result; the positional URL is only
            # a fallback because result order is not assumed to match input order.
            source_url = getattr(result, "url", None) or requested_url
            meta = FetchResultMeta(url=source_url, query=querys, crawler="crawl4ai")
            rets.append(FetchResult(content=self._extract_markdown(result), meta=meta))
        return rets


async def main_test():
    """Manual smoke test: fetch one page three times and print its content.

    The same fetcher instance serves all three requests (presumably to
    exercise reuse of the lazily started crawler). The fetcher is always
    closed afterwards so the crawler it started is not left running.
    """
    fetcher = Craw4aiContentFetch(SearchConfig())
    url = "https://zhuanlan.zhihu.com/p/1890804221336614802"
    try:
        for _ in range(3):
            result = await fetcher.fetch(url)
            print(result.content)
    finally:
        # Release the crawler even when a fetch raises.
        await fetcher.close()


if __name__ == "__main__":
    # Script entry point: run the manual smoke test on a fresh event loop.
    import asyncio

    smoke_test = main_test()
    asyncio.run(smoke_test)
