import os
import os.path as osp
import json
import re
import requests
from time import sleep
from typing import List
from pebble import ThreadPool
from tqdm import tqdm
# import fcntl
from filelock import FileLock
import tempfile

from qwen_agent.tools.storage import KeyNotExistsError, Storage
from qwen_agent.utils.utils import hash_sha256

from ray.util import pdb

# Timeout budget (presumably seconds); not referenced in the visible part of this file.
timeout_duration = 360

# Maximum number of characters kept from each page body; override with the
# GOOGLE_MAXPAGELENGTH environment variable.
MAXPAGELENGTH = int(os.environ.get('GOOGLE_MAXPAGELENGTH', 4000))
# Diagnostic printing is enabled when VERBOSE is set to one of the accepted
# truthy spellings (case-insensitive).
VERBOSE = os.environ.get("VERBOSE", "0").lower() in {"true", "1", "yes", "on"}

class Search_Engine():
    """Client for an HTTP search service with a ``Storage``-backed result cache.

    Each query is sent as a JSON ``POST`` to ``server_url``. Responses that the
    service marks as successful are stored under a key derived from the
    JSON-encoded request payload (via ``hash_sha256``), so an identical payload
    is later answered from the cache instead of the network.
    """

    def __init__(self, server_url, search_scene, top_k, authorization, user_name, search_engine_cache_file, concurrency=5, **kwargs):
        """Store connection settings and open the result cache.

        Args:
            server_url: Endpoint that receives the search ``POST`` requests.
            search_scene: Value sent as the ``scene`` field of every payload.
            top_k: Number of passages per result page (coerced to ``int``).
            authorization: Token placed in the ``Authorization: Bearer`` header.
            user_name: Value sent in the ``__d_head_app`` header.
            search_engine_cache_file: Root path handed to ``Storage``.
            concurrency: Maximum number of worker threads used by ``retrieve``.
            **kwargs: Kept on ``self.kwargs``; not otherwise used in this class.
        """
        self.server_url = server_url  # TODO: maybe load-balance across several URLs
        self.search_scene = search_scene
        # NOTE(review): the Host header is hard-coded regardless of server_url.
        self.headers = {
            'Content-Type': 'application/json',
            '__d_head_qto': '8000',
            '__d_head_app': user_name,
            "Authorization": f"Bearer {authorization}",
            "Host": "pre-nlp-cn-hangzhou.aliyuncs.com",
        }
        self.top_k = int(top_k)
        self.retry_times = 10
        self.concurrency = concurrency
        self.kwargs = kwargs
        self.db = Storage({'storage_root_path': search_engine_cache_file})

    def single_query(self, params: dict):
        """Run one search request, retrying on failure.

        Args:
            params: Dict holding at least ``'payload'`` (the JSON request body).
                It is mutated in place: ``'result'`` is set to the parsed
                response on success, or to ``None`` once every attempt failed.

        Returns:
            The same ``params`` dict, with ``'result'`` filled in.

        Every failed attempt (timeout, service error, empty result, unexpected
        status, malformed response) is followed by a 2-second pause, up to
        ``self.retry_times`` attempts in total. Nothing is raised to the caller.
        """
        payload = params['payload']
        for _ in range(self.retry_times):
            try:
                response = requests.post(self.server_url, headers=self.headers, json=payload, timeout=60).json()
                if "errorType" in response:
                    raise ValueError(f"[{payload['uq']}]->{response['errorType']}")
                if len(response['data']['docs']) == 0:
                    raise ValueError("No results found")
                if response['status'] != 0:
                    # Route an unexpected status through the common error path so
                    # it is reported and backed off like any other failure,
                    # instead of silently re-hitting the server right away.
                    raise ValueError(f"[{payload['uq']}]->unexpected status {response['status']}")
                params['result'] = response
                return params
            except requests.Timeout:
                if VERBOSE:
                    print("Request timed out")
                sleep(2)
            except Exception as e:
                # Deliberately broad: any failure of one attempt just triggers a retry.
                if VERBOSE:
                    print(f"Error in single_query: {e}")
                sleep(2)
        params['result'] = None
        return params

    def retrieve(self, messages: List[str], retrieve_times: List[int]) -> List[List[dict]]:
        """Search for every message and return one page of passages per message.

        Args:
            messages: Query strings.
            retrieve_times: Zero-based page index for each query, paired
                positionally with ``messages`` (extra items on either side are
                ignored by ``zip``). Page ``r`` covers results
                ``[r * top_k, (r + 1) * top_k)``.

        Returns:
            One list per query, in input order. Each passage is a dict with the
            keys ``id``, ``text``, ``title``, ``page`` and ``url``. An inner
            list is never empty: when the requested page has no passages (query
            failed, no documents, or fewer documents than the page offset) it
            holds a single "query result empty" placeholder passage.
        """
        cached_items, pending_items = [], []
        for idx, (message, r_time) in enumerate(zip(messages, retrieve_times)):
            payload = self._build_payload(message, (r_time + 1) * self.top_k)
            cached_name = hash_sha256(json.dumps(payload))
            item = {"index": idx, "cached_name": cached_name, "payload": payload}
            try:
                item["result"] = json.loads(self.db.get(cached_name))
            except KeyNotExistsError:
                pending_items.append(item)
            else:
                cached_items.append(item)

        fetched_items = []
        if pending_items:
            with ThreadPool(max_workers=self.concurrency) as pool:
                fetched_items = list(pool.map(self.single_query, pending_items).result())

            # Only responses the service marked as successful are cached, so a
            # failed query is retried on the next call.
            for item in fetched_items:
                if self._is_cacheable(item['result']):
                    self.db.put(item['cached_name'], json.dumps(item['result'], ensure_ascii=False, indent=2))

        # Restore input order so results line up with retrieve_times again.
        ordered = sorted(cached_items + fetched_items, key=lambda entry: entry['index'])
        return [
            self._select_page(
                self._extract_passages(entry['result']),
                r_time,
                self.top_k,
                entry['payload']['uq'],
            )
            for r_time, entry in zip(retrieve_times, ordered)
        ]

    def _build_payload(self, message: str, n_docs: int) -> dict:
        """Build the JSON request body asking for ``n_docs`` results for ``message``."""
        return {
            "rid": "",
            "scene": self.search_scene,
            "uq": message,
            "debug": False,
            "fields": [],
            "page": 1,
            "rows": n_docs,
            "customConfigInfo": {
                "readpage": True,
                "inspection": False,  # turn off the "green net" inspection (presumably content filtering)
                "readpageConfig": {"tokens": 4000, "topK": n_docs, "onlyCache": False},
            },
        }

    @staticmethod
    def _is_cacheable(result) -> bool:
        """Return True when ``result`` is a successful response worth caching."""
        if not isinstance(result, dict) or "errorType" in result:
            return False
        # .get() so a response lacking 'success' is skipped instead of raising
        # KeyError and discarding the whole batch; `== True` keeps the original
        # acceptance of a numeric 1.
        return result.get('success') == True  # noqa: E712

    @staticmethod
    def _extract_passages(response) -> List[dict]:
        """Convert a raw service response into passage dicts (``[]`` if it has none)."""
        if not isinstance(response, dict):
            return []
        data = response.get('data')
        docs = data.get('docs') if isinstance(data, dict) else None
        return [
            {
                "id": doc['_id'],
                "text": _rm_html(doc['snippet']),
                "title": _rm_html(doc['title']),
                # `or ""` guards against a present-but-null body.
                "page": _rm_html((doc.get('web_main_body') or "")[:MAXPAGELENGTH]).strip(),
                'url': doc['url'],
            }
            for doc in (docs or [])
        ]

    @staticmethod
    def _empty_result_passage(query: str) -> dict:
        """Build the placeholder passage returned when a query yields nothing."""
        preview = " ".join(query.split()[:5])
        return {
            "id": "query result empty",
            "text": "None relevant information under the query of `" + preview + "... `",
            "title": query,
            "page": "",
            'url': "None relevant information",
        }

    @staticmethod
    def _select_page(passages: List[dict], r_time: int, top_k: int, query: str) -> List[dict]:
        """Return page ``r_time`` of ``passages``, or the placeholder if that page is empty."""
        page = passages[r_time * top_k:(r_time + 1) * top_k]
        # Slice first, then fall back: otherwise the single placeholder would be
        # sliced away for every page index above zero.
        return page if page else [Search_Engine._empty_result_passage(query)]

def _rm_html(text: str) -> str:
    _HTML_TAG_RE = re.compile(r" ?</?(a|span|em|br).*?> ?")
    text = text.replace("\xa0", " ")
    text = text.replace("\t", "")  # quark uses \t to split chinese words
    text = text.replace("...", "……")
    text = _HTML_TAG_RE.sub("", text)
    text = text.strip()
    if text.endswith("……"):
        text = text[: -len("……")]
    return text


