from dataclasses import dataclass
import asyncio
from meta_researcher.tool.tools.search_engine.log import logger
from meta_researcher.tool.tools.search_engine.website_crawler.content_process import (
    URL_BLACK_LISTS,
)


@dataclass
class SearchConfig:
    """Configuration for the search-engine backends.

    Field order and defaults define the dataclass-generated ``__init__``
    signature, so positional callers depend on the order below.
    """

    # Which search backend to use; presumably matched against a backend
    # class name by the caller/factory — confirm.
    search_engine_type: str = "GoogleSerperSearch"

    # Per-provider API keys. The defaults are placeholders, not working keys;
    # real values must be supplied by whoever constructs this config.
    bing_search_key: str = "XXXXXXXXXXXXXXXXXXXXXXXX"
    google_search_key: str = "XXXXXXXXXXXXXXXXXXXXXXXX"
    google_search_serper_key: str = "XXXXXXXXXXXXXXXXXXXXXXXX"
    bocha_search_key: str = "XXXXXXXXXXXXXXXXXXXXXXXX"
    zhipu_search_key: str = "XXXXXXXXXXXXXXXXXXXXXXXX"

    # Optional log file path; None presumably disables file logging — confirm.
    log_file: str | None = None
    # Result-count limit; presumably results per query — confirm how backends
    # use it (NOTE(review): BaseSearch hard-codes its own self.topk = 10).
    topk: int = 20
    # Presumably the number of result pages fetched per query — confirm.
    pages: int = 2
    # Request timeout; units are not shown here (presumably seconds) — confirm.
    timeout: int | float = 5
    # Maximum number of retries; presumably per request — confirm.
    max_retry: int = 2
    ## RAG config
    # Settings for a retrieval-augmented (RAG) search path; their exact
    # semantics are not visible in this file — TODO document at the consumer.
    base_url: str = ""
    optional_search_engine: str = ""
    search_method: str = ""


class BaseSearch:
    """Base class for asynchronous search-engine backends.

    Subclasses implement :meth:`search`. This class provides concurrent
    batching (:meth:`batch_search`), URL blacklist checks (:meth:`check_url`)
    and de-duplication/truncation of results (:meth:`_filter_results`).
    """

    def __init__(self, args: "SearchConfig"):
        self.args = args
        # NOTE(review): args.topk is not read here; the cap applied by
        # _filter_results is this hard-coded 10 unless a subclass overrides
        # it — confirm this is intended.
        self.topk = 10
        self.name = "BaseSearch"

    async def search(self, query: str) -> list[dict]:
        """Return the result dicts for a single query. Must be overridden."""
        raise NotImplementedError

    async def batch_search(self, querys: list[str]) -> list[list[dict]]:
        """Run :meth:`search` for every query concurrently.

        Returns one result list per query, in input order.

        If any search raises ``asyncio.TimeoutError``, a warning is logged and
        partial results are returned. Searches that finished successfully keep
        their results. Searches that failed, were cancelled or are still
        running contribute ``[]``, and the still-running ones are cancelled.

        Raises:
            Exception: any non-timeout exception raised by a search is
                re-raised after the outstanding searches are cancelled.
        """
        all_tasks = [asyncio.create_task(self.search(query)) for query in querys]
        try:
            return await asyncio.gather(*all_tasks)
        except asyncio.TimeoutError:
            logger.warning(f"{self.name} timeout occurred! Partial results")
            return self._collect_partial_results(all_tasks)
        except Exception:
            # gather() does not cancel the sibling tasks when one of them
            # fails; stop them instead of leaving them running unobserved.
            self._cancel_pending(all_tasks)
            raise

    @staticmethod
    def _cancel_pending(tasks: list[asyncio.Task]) -> None:
        """Cancel every task in *tasks* that has not finished yet."""
        for task in tasks:
            if not task.done():
                task.cancel()

    @staticmethod
    def _collect_partial_results(tasks: list[asyncio.Task]) -> list[list[dict]]:
        """Gather the results of finished tasks, using ``[]`` for the rest.

        Unfinished tasks are cancelled.
        """
        results = []
        for task in tasks:
            if not task.done():
                task.cancel()
                results.append([])
            elif task.cancelled() or task.exception() is not None:
                # task.result() would re-raise here (e.g. for the search that
                # timed out), so a failed search contributes no results.
                results.append([])
            else:
                results.append(task.result())
        return results

    def check_url(self, url):
        """Return False if *url* contains any URL_BLACK_LISTS entry, else True."""
        return not any(blocked in url for blocked in URL_BLACK_LISTS)

    def _filter_results(self, results: list[dict]) -> list[dict]:
        """De-duplicate *results* by URL and keep at most ``self.topk`` entries.

        The blacklist is applied only to results without a ``"content"`` key.
        A missing or ``None`` URL is treated as ``""``.
        """
        filtered_results = []
        url_set = set()
        for result in results:
            # `or ""` guards against an explicit None URL, which would make
            # check_url's substring test raise TypeError.
            url = result.get("url") or ""
            if "content" not in result and not self.check_url(url):
                logger.info(f"skip url: {url}")
                continue
            if url not in url_set:
                filtered_results.append(result)
                url_set.add(url)
            if len(filtered_results) >= self.topk:
                break
        return filtered_results
