import asyncio
import aiofiles

import logging
from meta_researcher.tool.tools.search_engine.website_crawler.base_crawler import BaseContentFetch
from meta_researcher.tool.tools.search_engine.website_crawler.simple_crawler import SimpleContentFetch
from itertools import zip_longest, chain
from meta_researcher.tool.tools.search_engine.base_search import BaseSearch, SearchConfig
from meta_researcher.tool.tools.search_engine.bing_search import BingSearch
from meta_researcher.tool.tools.search_engine.bocha_search import BochaSearch
from meta_researcher.tool.tools.search_engine.google_search import GoogleSearch
from meta_researcher.tool.tools.search_engine.google_serper_search import GoogleSerperSearch
from meta_researcher.tool.tools.search_engine.google_bing_optional_search import GoogleBingOptionalSearch

# Registry of the supported search backends, keyed by the backend's class name.
# SearchEngine resolves SearchConfig.search_engine_type through this mapping.
ALL_SEAECH_ENGINE_MAP: dict[str, type[BaseSearch]] = {
    engine_cls.__name__: engine_cls
    for engine_cls in (
        BingSearch,
        GoogleSearch,
        GoogleSerperSearch,
        BochaSearch,
        GoogleBingOptionalSearch,
    )
}

def dedup_webpages(webpages: list[dict]):
    """Return *webpages* with repeated URLs removed, keeping first occurrences.

    Entries that carry no ``"url"`` key cannot be compared, so they are always
    kept. The input order is preserved and the input list is not modified.
    The number of dropped entries is logged at INFO level.
    """
    known_urls = set()
    unique_pages = []
    for page in webpages:
        if "url" in page:
            link = page["url"]
            if link in known_urls:
                # Already emitted a page for this URL; drop the repeat.
                continue
            known_urls.add(link)
        unique_pages.append(page)
    logging.info(f"dedup remove {len(webpages) - len(unique_pages)} webpages")
    return unique_pages

class Retriever:
    """Run searches for a batch of queries and fetch content for result pages.

    Searches go through a ``SearchEngine``; page bodies are downloaded with a
    separate ``SimpleContentFetch`` instance. Two semaphores, each sized by
    ``max_concurrency``, bound the number of in-flight ``controlled_search``
    and ``controlled_fetch`` calls respectively.
    """

    def __init__(
        self,
        engine: str,
        base_url: str = "",
        optional_search_engine: str = "",
        search_method: str = "",
        max_concurrency: int = 16,
    ):
        """Build the search engine, the content fetcher and the limiters.

        Args:
            engine: Search backend name, passed as
                ``SearchConfig.search_engine_type``.
            base_url: Forwarded to ``SearchConfig`` only when ``engine`` is
                ``"GoogleBingOptionalSearch"``.
            optional_search_engine: Same forwarding rule as ``base_url``.
            search_method: Same forwarding rule as ``base_url``.
            max_concurrency: Size of each of the two semaphores.
        """
        self.engine_name = engine
        if self.engine_name == "GoogleBingOptionalSearch":
            # Only this backend receives the extra routing options.
            engine_config = SearchConfig(
                search_engine_type=engine,
                base_url=base_url,
                optional_search_engine=optional_search_engine,
                search_method=search_method,
            )
        else:
            engine_config = SearchConfig(search_engine_type=engine)
        self.engine = SearchEngine(engine_config)
        # The fetcher always gets a minimal config, whatever the backend.
        self.fetcher = SimpleContentFetch(SearchConfig(search_engine_type=engine))
        self.semaphore = asyncio.Semaphore(max_concurrency)
        self.semaphore_fetch = asyncio.Semaphore(max_concurrency)
        # Passed to asyncio.wait_for, which interprets it in seconds.
        # NOTE(review): 3000 s is 50 minutes - confirm this was not meant as ms.
        self.timeout = 3000
        # Per-query search result limit, also the cap on fetched pages returned.
        self.topk = 10

    async def controlled_search(self, query, topk):
        """Run one search under the search semaphore, bounded by ``self.timeout``.

        Raises:
            asyncio.TimeoutError: If the search exceeds ``self.timeout``.
        """
        async with self.semaphore:
            return await asyncio.wait_for(
                self.engine.batch_search(query=query, topk=topk),
                timeout=self.timeout,
            )

    async def retrieve(self, queries: list[str]) -> list[list[dict]]:
        """Search every query concurrently.

        Returns:
            One result list per query, in the order of ``queries``; an empty
            list when ``queries`` is empty.
        """
        if not queries:
            return []
        search_tasks = [
            asyncio.create_task(self.controlled_search(query=q, topk=self.topk))
            for q in queries
        ]
        # gather() keeps the input order; the first failing task propagates.
        return await asyncio.gather(*search_tasks)

    async def controlled_fetch(self, webpages):
        """Run ``fetch_web_contents`` under the fetch semaphore and timeout.

        Raises:
            asyncio.TimeoutError: If the fetch exceeds ``self.timeout``.
        """
        async with self.semaphore_fetch:
            return await asyncio.wait_for(
                self.fetch_web_contents(webpages=webpages),
                timeout=self.timeout,
            )

    async def batch_fetch_web_contents(self, webpage: list) -> list[list[dict]]:
        """Fetch contents for several result lists, one list after another.

        Args:
            webpage: A list of result lists (e.g. the value returned by
                ``retrieve``).

        Returns:
            The ``fetch_web_contents`` result for each inner list, in order.
        """
        return [await self.fetch_web_contents(webpages=group) for group in webpage]

    async def fetch_web_contents(self, webpages: list) -> list[dict]:
        """Download content for pages that lack it and return pages with content.

        Page dicts are updated in place: a successfully fetched page gains a
        ``"content"`` entry and has the fetch metadata merged into its
        ``"meta"`` dict (created when absent). Pages with neither
        ``"content"`` nor ``"url"`` cannot be fetched and are left out of the
        result instead of raising ``KeyError``.

        Returns:
            At most ``self.topk`` pages that have content, in input order.
        """
        pending = {
            idx: page["url"]
            for idx, page in enumerate(webpages)
            if "content" not in page and "url" in page
        }
        if pending:
            fetched = await self.fetcher.batch_fetch(urls=list(pending.values()))
            # NOTE(review): each fetch result is unpacked as a (content, meta)
            # pair here, whereas SearchEngine.fetch_web_contents reads
            # .content / .meta attributes - confirm the fetcher's return shape.
            for (web_content, fetch_meta), idx in zip(fetched, pending):
                if web_content:
                    webpages[idx]["content"] = web_content
                    webpages[idx].setdefault("meta", {}).update(fetch_meta)

        with_content = [page for page in webpages if "content" in page]
        return with_content[: self.topk]

class SearchEngine:
    """Pair a search backend with a content fetcher.

    The backend class is looked up in ``ALL_SEAECH_ENGINE_MAP`` by
    ``args.search_engine_type``. When ``args.log_file`` is set, fetch metadata
    is appended to that file as one JSON document per line.
    """

    def __init__(self, args: SearchConfig):
        """Instantiate the backend and the fetcher from the same config.

        Raises:
            KeyError: If ``args.search_engine_type`` is not a registered backend.
        """
        self.args = args
        self.searcher: BaseSearch = ALL_SEAECH_ENGINE_MAP[args.search_engine_type](args)
        self.fetcher: BaseContentFetch = SimpleContentFetch(args)
        # Opened lazily by fetch_web_contents when args.log_file is set.
        self.log_fp = None

    async def close(self):
        """Close the fetcher and, if it was opened, the fetch log file."""
        try:
            await self.fetcher.close()
        finally:
            if self.log_fp is not None:
                # Detach first so a failing close() cannot leave a dead handle
                # that later writes would try to reuse.
                log_fp, self.log_fp = self.log_fp, None
                await log_fp.close()

    async def batch_search(self, query: str, topk: int = 6, searcher: BaseSearch | None = None) -> list[dict]:
        """Search for *query* and return up to *topk* interleaved results.

        Args:
            query: Passed unchanged to the backend's ``batch_search``.
            topk: Maximum number of results to return.
            searcher: Backend to use instead of ``self.searcher``.

        Returns:
            The ``"search_results"`` lists of every backend response, merged
            round-robin and truncated to *topk*. Any exception raised while
            searching is logged with its traceback and yields an empty list
            (best effort).
        """
        s = searcher if searcher is not None else self.searcher
        webpages = []
        try:
            search_results = await s.batch_search(query)
            webpages = self.flatten_interleaved(
                [search_result["search_results"] for search_result in search_results]
            )
            logging.debug(f"all search urls: {[r.get('url') for r in webpages]}")
        except Exception:
            # Best effort: report the failure and fall through with what we have.
            logging.warning(f'search for "{query}" failed', exc_info=True)
        return webpages[:topk]

    async def fetch_web_contents(self, webpages: list, query: str | None = None) -> list[dict]:
        """Download content for pages that lack it.

        Page dicts are updated in place: a successfully fetched page gains a
        ``"content"`` entry and has the fetch metadata merged into its
        ``"meta"`` dict (created when absent). Pages with neither
        ``"content"`` nor ``"url"`` are not fetched.

        Args:
            webpages: Result dicts to complete.
            query: Unused; kept for interface compatibility.

        Returns:
            The leading pages of *webpages* (with or without content) up to
            and including the one that brings the number of pages with content
            to ``self.args.topk``. When nothing needs fetching, simply the
            first ``self.args.topk`` pages.
        """
        urls = {
            idx: webpage["url"]
            for idx, webpage in enumerate(webpages)
            if "content" not in webpage and "url" in webpage
        }
        if not urls:
            return webpages[: self.args.topk]

        if self.args.log_file and not self.log_fp:
            self.log_fp = await aiofiles.open(self.args.log_file, "a", encoding="utf-8")

        all_contents = await self.fetcher.batch_fetch(urls=list(urls.values()))
        for fetch_result, idx in zip(all_contents, urls):
            web_content, fetch_meta = fetch_result.content, fetch_result.meta
            if self.log_fp:
                # Flushed per record so the log survives an abrupt shutdown.
                await self.log_fp.write(f"{fetch_meta.model_dump_json(exclude_none=True)}\n")
                await self.log_fp.flush()
            if web_content:
                webpages[idx]["content"] = web_content
                webpages[idx].setdefault("meta", {}).update(fetch_meta.model_dump(exclude_none=True))

        content_cnt = 0
        rets = []
        for webpage in webpages:
            rets.append(webpage)
            if "content" in webpage:
                content_cnt += 1
            if content_cnt >= self.args.topk:
                break
        return rets

    def flatten_interleaved(self, lst):
        """Merge several lists round-robin into one flat list.

        ``[[a1, a2], [b1]]`` becomes ``[a1, b1, a2]``. ``None`` items are
        dropped, because ``None`` is also the padding used for shorter lists.
        """
        transposed = zip_longest(*lst, fillvalue=None)
        return [item for item in chain.from_iterable(transposed) if item is not None]
