from typing import Optional, Union, List

import os
import json
import jsonschema

from qwen_agent.tools.base import BaseTool, register_tool
from qwen_agent.utils.utils import json_loads
from qwen_agent.settings import DEFAULT_WORKSPACE

# from .search_engine import Search_Engine
from qwen_agent.tools.multi_agent.googlesearch.search_engine import Search_Engine
from qwen_agent.tools.multi_agent.utils import func_input_desc, readpage_func_desc
from ray.util import pdb

# Search-service configuration, read once from the environment at import time.
URL = os.getenv('QWEN_SEARCH_URL')
KEY = os.getenv('QWEN_SEARCH_KEY')
SCENE = os.getenv('QWEN_SEARCH_SCENE')
# Numeric settings are cast to int: os.getenv returns a str whenever the
# variable is set, so without the cast TOPK would be the int 10 by default
# but a str (e.g. '5') when GOOGLE_TOPK is configured.
TOPK = int(os.getenv('GOOGLE_TOPK', 10))
USERNAME = os.getenv('GOOGLE_USERNAME', 'test')
CONCURRENCY = int(os.getenv('GOOGLE_CONCURRENCY', 5))



@register_tool("GoogleSearch")
class GoogleSearch(BaseTool):
    """Web search tool that forwards one or more queries to ``Search_Engine``.

    The engine is configured from the module-level constants URL, KEY, SCENE,
    TOPK, USERNAME and CONCURRENCY, which are read from environment variables
    when the module is imported.
    """

    name = "GoogleSearch"
    description = "A Google search tool that searches information from the Internet. It supports searching multiple queries simultaneously."

    parameters = {
        "type": "object",
        "properties": {
            "queries": {
                "type": "array",
                "items": {"type": "string"},
                "description": "A list of search queries. Each query should be a clear and specific question or search term.",
            }
        },
        "required": ["queries"],
    }

    def __init__(self, cfg: Optional[dict] = None):
        """Create the tool and its backing search engine.

        Args:
            cfg: Optional tool config. ``cfg['path']`` overrides the path passed
                to the engine as ``search_engine_cache_file``; it defaults to
                ``<DEFAULT_WORKSPACE>/tools/GoogleSearch``.
        """
        super().__init__(cfg)
        self.data_root = self.cfg.get('path', os.path.join(DEFAULT_WORKSPACE, 'tools', self.name))
        self.search_engine = Search_Engine(
            server_url=URL,
            search_scene=SCENE,
            top_k=TOPK,
            authorization=KEY,
            user_name=USERNAME,
            search_engine_cache_file=self.data_root,
            concurrency=CONCURRENCY,
        )
        # Extra guidance appended after the tool description in every
        # observation; currently empty.
        self.observation_note = ""

    def observation(self, tool: dict, tool_dict: dict, tool_results, empty_mode: bool = False, readpage: bool = False, max_observation_length: Optional[int] = None, tokenizer=None):
        """Format search results as observation text for the model.

        ``tool_dict['queries']`` and ``tool_results`` are walked in lockstep
        with ``zip``, so any surplus entries in the longer one are ignored.

        Args:
            tool: Tool-call description; passed to ``func_input_desc`` (or
                ``readpage_func_desc`` in readpage mode) to build header text.
            tool_dict: Parsed tool arguments containing ``'queries'``.
            tool_results: One list of result docs per query, rendered with
                ``format_web_search_content``.
            empty_mode: In non-readpage mode, return only the per-query
                sections without the description / "Tool Results:" header.
                Must be False when ``readpage`` is True.
            readpage: If True, pack the per-query sections into chunks that fit
                ``max_observation_length`` tokens and return a list of dicts.
            max_observation_length: Token budget per chunk (readpage only). The
                token length of the tool description is subtracted first.
            tokenizer: Object with ``encode(text, truncation=False)`` used to
                count tokens (readpage only).

        Returns:
            A string when ``readpage`` is False. Otherwise the list produced by
            ``merge_observations`` ({'query_idxs', 'content'} dicts), with the
            tool description appended to each chunk's ``'content'``.

        Raises:
            ValueError: If ``readpage`` is True but ``tokenizer`` or
                ``max_observation_length`` is missing.
        """
        pairs = list(enumerate(zip(tool_dict['queries'], tool_results)))

        if not readpage:
            tool_desc = func_input_desc(tool) + "\n" + self.observation_note
            sections = [
                f"# Query {idx + 1}: {query}\n" + format_web_search_content(docs)
                for idx, (query, docs) in pairs
            ]
            if empty_mode:
                return "\n\n".join(sections)
            return f"{tool_desc}\n\nTool Results:\n" + "\n\n".join(sections)

        assert not empty_mode, "can not use empty mode when using readpage"
        if tokenizer is None or max_observation_length is None:
            raise ValueError("readpage mode requires both `tokenizer` and `max_observation_length`")

        tool_desc = readpage_func_desc(tool) + "\n" + self.observation_note
        # The description is appended to every chunk, so reserve its tokens.
        # Separator tokens ("\n\n") are not counted; the budget is approximate.
        budget = max_observation_length - len(tokenizer.encode(tool_desc, truncation=False))

        tool_observation = [
            {
                'query_idx': idx,
                'content': f"# Query {idx + 1}: {query}\n" + format_web_search_content(docs),
                'tool_results': docs,
            }
            for idx, (query, docs) in pairs
        ]
        observation = self.merge_observations(tool_observation, budget, tokenizer)
        for obs in observation:
            obs['content'] = obs['content'] + "\n\n" + tool_desc
        return observation

    def merge_observations(self, tool_observation, max_observation_length, tokenizer):
        """Greedily pack per-query observations into token-bounded chunks.

        Items are taken in order. Each is appended to the current chunk while
        the summed token count stays within ``max_observation_length``;
        otherwise the current chunk is closed and a new one started. An item
        that alone exceeds the limit becomes its own chunk and is not
        truncated. The ``"\\n\\n"`` separators used to join contents are not
        counted.

        Args:
            tool_observation: Non-empty list of dicts with ``'query_idx'`` and
                ``'content'`` keys (other keys are ignored).
            max_observation_length: Token budget per chunk.
            tokenizer: Object with ``encode(text, truncation=False)``.

        Returns:
            List of ``{'query_idxs': [int, ...], 'content': str}`` dicts.
        """
        assert len(tool_observation) != 0, "tool_observation must not be empty"

        merged_results = []
        chunk_idxs: List[int] = []
        chunk_contents: List[str] = []
        chunk_length = 0

        for obs in tool_observation:
            content_length = len(tokenizer.encode(obs['content'], truncation=False))
            # Close the current (non-empty) chunk if this item would overflow it.
            if chunk_contents and chunk_length + content_length > max_observation_length:
                merged_results.append({'query_idxs': chunk_idxs, 'content': "\n\n".join(chunk_contents)})
                chunk_idxs, chunk_contents, chunk_length = [], [], 0
            chunk_idxs.append(obs['query_idx'])
            chunk_contents.append(obs['content'])
            chunk_length += content_length

        if chunk_contents:
            merged_results.append({'query_idxs': chunk_idxs, 'content': "\n\n".join(chunk_contents)})

        return merged_results

    def call(self, params: Union[str, dict], **kwargs) -> dict:
        """Search every query in ``params`` and report the outcome as a dict.

        Args:
            params: JSON string or dict matching ``self.parameters``, i.e.
                ``{"queries": [str, ...]}``.

        Returns:
            ``{"success": True, "results": ..., "params": ...}`` where
            ``results`` is whatever ``Search_Engine.retrieve`` returns, or
            ``{"success": False, "error_message": ...}`` for invalid or empty
            parameters and for any error raised during the search.
        """
        try:
            params_json = self._verify_json_format_args(params)
        except (ValueError, jsonschema.ValidationError) as e:
            # Schema violations raise jsonschema.ValidationError, which is not
            # a ValueError; catch it so bad input yields a failure dict instead
            # of an exception escaping to the caller.
            return {
                "success": False,
                "error_message": f"Parameter validation failed: {str(e)}"
            }
        try:
            queries = params_json["queries"]
            # Defensive only: the schema requires an array, so a bare string
            # is normally rejected by the validation above.
            if isinstance(queries, str):
                queries = [queries]

            if not queries:
                return {
                    "success": False,
                    "error_message": f"Search failed: Params empty {str(queries)}"
                }

            results = self.search_engine.retrieve(queries, [0] * len(queries))

            return {
                "success": True,
                "results": results,
                "params": params_json
            }

        except Exception as e:
            return {
                "success": False,
                "error_message": f"Search failed: {str(e)}"
            }

    def _verify_json_format_args(self, params: Union[str, dict], strict_json: bool = False) -> dict:
        """Parse ``params`` (if a string) and validate it against ``self.parameters``.

        Args:
            params: JSON string or already-parsed dict.
            strict_json: Parse with ``json.loads`` instead of the lenient
                ``json_loads`` helper.

        Returns:
            The parsed parameters.

        Raises:
            ValueError: On unparseable JSON, a missing required parameter
                (list-style ``parameters``), or an unsupported schema type.
            jsonschema.ValidationError: If a dict-style schema rejects the
                parameters.
        """
        if isinstance(params, str):
            try:
                if strict_json:
                    params_json: dict = json.loads(params)
                else:
                    params_json: dict = json_loads(params)
            except json.decoder.JSONDecodeError:
                raise ValueError('Parameters must be formatted as a valid JSON!')
        else:
            params_json: dict = params
        if isinstance(self.parameters, list):
            for param in self.parameters:
                if 'required' in param and param['required']:
                    if param['name'] not in params_json:
                        raise ValueError('Parameters %s is required!' % param['name'])
        elif isinstance(self.parameters, dict):
            import jsonschema
            jsonschema.validate(instance=params_json, schema=self.parameters)
        else:
            raise ValueError
        return params_json

    @property
    def function(self) -> dict:  # Bad naming. It should be `function_info`.
        """Function-calling metadata: tool name, description and JSON schema."""
        return {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters,
        }


def format_web_search_content(docs: list) -> str:
    """Render search-result documents as a single fenced text block.

    Each doc is a mapping read through its ``'title'``, ``'text'`` and
    ``'page'`` keys. Docs are numbered from 1, wrapped in
    ``[webpage N begin]`` / ``[webpage N end]`` markers and separated by a
    blank line; the whole listing is enclosed in a ``` fence. An empty
    ``docs`` list yields an empty fence.
    """
    sections = []
    for number, doc in enumerate(docs, start=1):
        sections.append(
            f"[webpage {number} begin]\n"
            f"Title: {doc['title']}\n"
            f"{doc['text']}\n"
            f"{doc['page']}\n"
            f"[webpage {number} end]"
        )
    body = "\n\n".join(sections)
    return f"```\n{body}\n```"

if __name__ == "__main__":
    # Manual smoke test: builds the tool (using the search service configured
    # through the QWEN_SEARCH_* / GOOGLE_* environment variables), runs two
    # sample queries through ``call`` and prints the resulting dict.
    search_tool = GoogleSearch()
    sample_params = "{\"queries\": [\"Classical Greek non-classical word in Passage 1: οἴμοι νεκροὺς ἐγὼ δὲ οὔκ εἶμαι νεκρός μὴ οὖν κόμιζέ με εἰς Ἀϊδου\", \"Classical Greek non-classical word in Passage 2: τὰς πέτρας ἀποβλέπων κελεύω τοὺς ἑταίρους ἐκπέμπειν ὄρνιθα ἐκ τῆς νηός\"]}"
    outcome = search_tool.call(sample_params)
    print(outcome)