from typing import Optional, Union

import os
import json
import jsonschema
import requests

from pebble import ThreadPool
from time import sleep
from ray.util import pdb
from tqdm import tqdm

from qwen_agent.tools.base import BaseTool, register_tool
from qwen_agent.utils.utils import json_loads
from qwen_agent.settings import DEFAULT_WORKSPACE
from qwen_agent.tools.storage import KeyNotExistsError, Storage
from qwen_agent.utils.utils import hash_sha256

from qwen_agent.tools.multi_agent.utils import func_input_desc


# Runtime configuration, read once from the environment at import time.
# URL / KEY: endpoint and access key of the scholar search service (None when unset).
URL = os.environ.get('SCHOLAR_SEARCH_URL')
KEY = os.environ.get('SCHOLAR_SEARCH_KEY')
# CONCURRENCY: worker count of the thread pool used to run uncached queries.
CONCURRENCY = int(os.environ.get('SCHOLAR_SEARCH_CONCURRENCY', 5))
# NUM: value sent as the "num" field of every search payload.
NUM = int(os.environ.get('SCHOLAR_NUM', 5))
# VERBOSE: enables diagnostic prints; any of true/1/yes/on (case-insensitive) turns it on.
VERBOSE = os.environ.get("VERBOSE", "0").lower() in {"true", "1", "yes", "on"}
# VERBOSE = True

# @register_tool("GoogleScholar")
class GoogleScholarTool(BaseTool):
    """Batch scholar-search tool backed by an HTTP search service.

    Every query is POSTed to the endpoint configured in ``SCHOLAR_SEARCH_URL``
    using the access key from ``SCHOLAR_SEARCH_KEY``. Non-empty results are
    cached in a ``Storage`` database keyed by the SHA-256 hash of the JSON
    request payload, so a repeated query is answered from the cache.
    """

    name = "GoogleScholarTool"
    # NOTE(review): the schema below is disabled, so `description` and
    # `parameters` are presumably inherited from BaseTool -- confirm.
    # description = "A Google Scholar Search tool that searches relevant paper information from the Internet. It supports searching multiple queries simultaneously."
    
    # parameters = {
    #     "type": "object",
    #     "properties": {
    #         "queries": {
    #             "type": "array",
    #             "items": {"type": "string"},
    #             "description": "A list of search queries. Each query should be a clear and specific question or search term.",
    #         }
    #     },
    #     "required": ["queries"],
    # }
    
    def __init__(self, cfg: Optional[dict] = None):
        """Set up the HTTP headers, retry budget and the result cache.

        Args:
            cfg: Optional tool configuration. Its ``path`` entry, when present,
                overrides the default cache directory
                ``<DEFAULT_WORKSPACE>/tools/<name>``.

        Raises:
            AssertionError: If ``SCHOLAR_SEARCH_KEY`` is not set.
        """
        super().__init__(cfg)
        assert KEY is not None, "Please set the environment variable SCHOLAR_SEARCH_KEY if your want to use Google Scholar Search"
        self.data_root = self.cfg.get('path', os.path.join(DEFAULT_WORKSPACE, 'tools', self.name))

        self.url = URL
        self.headers = {
            'X-AK': KEY,
            'Content-Type': 'application/json'
        }
        # Number of attempts `single_query` makes before giving up on a query.
        self.retry_times = 5
        self.db = Storage({'storage_root_path': self.data_root})

        # self.observation_note = "Note that, when providing multiple search queries:\n- Analyze each query separately\n- State whether relevant information was found\n- Provide findings or explicitly note when no relevant information was found"
        # self.observation_note = "Note that:\n- Analyze each query separately\n- For each query, specify which exact query you are extracting information for\n- For queries with no relevant information, explicitly state \"No relevant information found for the query: [paste the exact query here]\"\n"
        self.observation_note = ""

    def observation(self, tool: dict, tool_dict: dict, tool_results, empty_mode: bool=False):
        """Build the textual observation for a finished tool call.

        Args:
            tool: Tool-call description handed to ``func_input_desc``.
            tool_dict: Parsed arguments; ``tool_dict['queries']`` lists the queries.
            tool_results: One document list per query, in query order. Pairs are
                formed with ``zip``, so surplus entries on either side are ignored.
            empty_mode: When true, return only the per-query sections without
                the tool description header.

        Returns:
            The formatted observation string.
        """
        tool_desc = func_input_desc(tool) + "\n" + self.observation_note
        tool_observation = [
            f"# Query {idx + 1}: {query}\n" + format_web_search_content(docs)
            for idx, (query, docs) in enumerate(zip(tool_dict['queries'], tool_results))
        ]
        sections = "\n\n".join(tool_observation)
        if empty_mode:
            return sections
        return f"{tool_desc}\n\nTool Results:\n" + sections

    def single_query(self, params: dict):
        """Run one search request, retrying up to ``self.retry_times`` times.

        Args:
            params: Work item whose ``payload`` entry is the JSON body to POST.
                The dict is updated in place with a ``result`` entry.

        Returns:
            The same ``params`` dict. ``params['result']`` is the list found at
            ``data.originalOutput.organic`` of the response, or ``None`` when
            every attempt failed or returned nothing.
        """
        payload = params['payload']
        for _ in range(self.retry_times):
            try:
                response = requests.post(self.url, headers=self.headers, json=payload, timeout=150).json()
                if "errorType" in response:
                    # Payloads built by `call` store the text under "query"; an
                    # unconditional payload['uq'] lookup raised KeyError here and
                    # hid the service error. "uq" is still honoured if present.
                    query_text = payload.get('uq', payload.get('query'))
                    raise ValueError(f"[{query_text}]->{response['errorType']}")
                # Check the success flag before touching the data so that an
                # unsuccessful response is reported as such and backed off.
                if not response.get('success'):
                    raise ValueError(f"Unsuccessful response: {response}")
                organic = response['data']['originalOutput']['organic']
                if len(organic) == 0:
                    raise ValueError("No results found")
                params['result'] = organic
                return params
            except requests.Timeout:
                if VERBOSE:
                    print("Request timed out")
                sleep(0.5)
            except Exception as e:
                # Best effort: any failure (HTTP, JSON, schema) just costs one attempt.
                if VERBOSE:
                    print(f"Error in single_query: {e}")
                sleep(0.5)
        params['result'] = None
        return params

    def call(self, params: Union[str, dict], **kwargs) -> dict:
        """Search all queries, serving cached ones locally and the rest concurrently.

        Args:
            params: Tool arguments as a dict or JSON string with a ``queries``
                entry (a list of strings; a single string is also accepted).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            On success ``{"success": True, "results": [...], "params": ...}``
            where each result item holds ``index``, ``cached_name``, ``payload``
            and ``result`` (``None`` for a query that yielded nothing), ordered
            like the input queries. On failure
            ``{"success": False, "error_message": ...}``; this method does not
            raise for invalid input or search errors.
        """
        try:
            params_json = self._verify_json_format_args(params)
        except (ValueError, jsonschema.ValidationError) as e:
            # jsonschema.ValidationError is not a ValueError, so it has to be
            # listed explicitly to keep the "return an error dict" contract.
            return {
                "success": False,
                "error_message": f"Parameter validation failed (Scholar): {str(e)}"
            }
        if not isinstance(params_json, dict) or "queries" not in params_json:
            return {
                "success": False,
                "error_message": "Parameter validation failed (Scholar): 'queries' is required"
            }
        try:
            queries = params_json["queries"] # queries is a list of strings
            if isinstance(queries, str):
                queries = [queries]

            # special case
            if not queries:
                return {
                    "success": False,
                    "error_message": f"Scholar Search failed: Params empty {str(queries)}"
                }

            post_params = [
                {
                    "query": q,
                    "num": NUM,
                    "extendParams": {
                        "country": "cn",
                        "locale": "zh-cn",
                        "location": "United States",
                        "page": 1
                    },
                    "platformInput": {
                        "model": "google-search"
                    }
                }
                for q in queries
            ]

            # Split the work into cache hits and items that need a request.
            query_list = []
            search_result_list = []
            for idx, payload in enumerate(post_params):
                cached_name = hash_sha256(json.dumps(payload))
                try:
                    search_result = self.db.get(cached_name)
                    search_result_list.append({"index": idx, "cached_name": cached_name, "payload": payload, "result": json.loads(search_result)})
                except KeyNotExistsError:
                    query_list.append({"index": idx, "cached_name": cached_name, "payload": payload})

            outputs = []
            if query_list:
                with ThreadPool(max_workers=CONCURRENCY) as pool:
                    future = pool.map(self.single_query, query_list)
                    outputs = list(future.result())

                # cache search result
                for item in outputs:
                    if item['result'] is None:
                        continue
                    # Drop hits with neither title nor snippet text. `.get` keeps
                    # one hit lacking a field from failing the whole batch.
                    kept = [
                        doc for doc in item['result']
                        if ((doc.get('title') or '') + (doc.get('snippet') or '')).strip()
                    ]
                    if not kept:
                        # Nothing usable: report as "no result" and do not cache.
                        item['result'] = None
                        continue
                    item['result'] = kept
                    self.db.put(item['cached_name'], json.dumps(kept, ensure_ascii=False, indent=2))

            final_results = sorted(search_result_list + outputs, key=lambda x: x['index'])
            return {
                "success": True,
                "results": final_results,
                "params": params_json
            }

        except Exception as e:
            return {
                "success": False,
                "error_message": f"Scholar Search failed: {str(e)}"
            }
    
    def _verify_json_format_args(self, params: Union[str, dict], strict_json: bool = False) -> dict:
        """Parse and validate the arguments of a tool call.

        Args:
            params: Arguments as a JSON string or an already parsed object.
            strict_json: Parse strings with ``json.loads`` when true, otherwise
                with the project's ``json_loads`` helper.

        Returns:
            The parsed arguments.

        Raises:
            ValueError: If the string is not valid JSON, a required parameter is
                missing, or ``self.parameters`` is neither a list nor a dict.
            jsonschema.ValidationError: If ``self.parameters`` is a JSON schema
                (dict) and the arguments do not satisfy it.
        """
        if isinstance(params, str):
            try:
                if strict_json:
                    params_json: dict = json.loads(params)
                else:
                    params_json: dict = json_loads(params)
            except json.decoder.JSONDecodeError:
                raise ValueError('Parameters must be formatted as a valid JSON!')
        else:
            params_json: dict = params
        if isinstance(self.parameters, list):
            for param in self.parameters:
                if 'required' in param and param['required']:
                    if param['name'] not in params_json:
                        raise ValueError('Parameters %s is required!' % param['name'])
        elif isinstance(self.parameters, dict):
            jsonschema.validate(instance=params_json, schema=self.parameters)
        else:
            raise ValueError('Tool parameters schema must be a list or a dict!')
        return params_json

    def save_cache(self):
        """Forward to ``self.search_engine.save_cache()`` when that attribute exists.

        Nothing in this class assigns ``search_engine``, so unless one is
        attached from elsewhere this method does nothing.
        """
        if hasattr(self, 'search_engine'):
            self.search_engine.save_cache()

    @property
    def function(self) -> dict:  # Bad naming. It should be `function_info`.
        """Return the tool's ``name``, ``description`` and ``parameters`` as a dict."""
        return {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters,
        }


def format_web_search_content(docs: list) -> str:
    """Render search documents as a numbered list inside a fenced block.

    Args:
        docs: Documents to render; each must provide ``title``, ``text`` and
            ``page`` entries. ``None`` is treated like an empty list, so a
            query without results does not raise.

    Returns:
        The entries, each stripped of surrounding whitespace and separated by
        a blank line, wrapped in a Markdown code fence.
    """
    entries = [
        f"[{idx + 1}]: \"{doc['title']}\"\n{doc['text']}\n{doc['page']}".strip()
        for idx, doc in enumerate(docs or [])
    ]
    return '```\n{}\n```'.format("\n\n".join(entries))


if __name__ == "__main__":
    # Manual smoke test; the constructor asserts that SCHOLAR_SEARCH_KEY is set.
    queries = {"queries": ["toolevo", "alphamath"]}

    tool = GoogleScholarTool()
    result = tool.call(queries)
    # Show the outcome; a bare print() would discard the computed result.
    print(json.dumps(result, ensure_ascii=False, indent=2))