from typing import Optional, Union

import os
import json
import jsonschema
import requests

from pebble import ThreadPool
from time import sleep
from ray.util import pdb
from tqdm import tqdm

from qwen_agent.tools.base import BaseTool, register_tool
from qwen_agent.utils.utils import json_loads
from qwen_agent.settings import DEFAULT_WORKSPACE
from qwen_agent.tools.storage import KeyNotExistsError, Storage
from qwen_agent.utils.utils import hash_sha256

from qwen_agent.tools.multi_agent.utils import func_input_desc


# Jina Reader endpoint; the document URL to parse is appended to it verbatim.
URL = os.environ.get('JINA_URL', 'https://r.jina.ai/')
# Bearer token sent with every Jina request (required by PDFParserJina).
KEY = os.environ.get('JINA_KEY')
# Upper bound on the number of URLs fetched in parallel.
CONCURRENCY = int(os.environ.get('JINA_CONCURRENCY', '5'))
# Set VERBOSE to true/1/yes/on (case-insensitive) to print per-request diagnostics.
VERBOSE = os.environ.get("VERBOSE", "0").lower() in {"true", "1", "yes", "on"}


# @register_tool("PDFParserJina")  # intentionally not registered: this tool is not meant to be invoked directly
class PDFParserJina(BaseTool):
    """Parse remote documents (e.g. PDFs) into text through the Jina Reader API.

    Successful parses are cached in a ``Storage`` rooted at ``cfg['path']``
    (default ``<DEFAULT_WORKSPACE>/tools/PDFParserJina``), keyed by
    ``hash_sha256(json.dumps(url))``, so a URL is only fetched again when it
    has no usable cache entry.

    ``call`` expects ``{"urls": [{"paper_idx": ..., "url": "..."}, ...]}``
    (as a dict or a JSON string) and adds a ``parsed_results`` key to every
    entry of ``urls``.
    """

    name = "PDFParserJina"

    def __init__(self, cfg: Optional[dict] = None):
        super().__init__(cfg)
        assert KEY is not None, "Please set the environment variable JINA_KEY if your want to use PDFParserJina"
        # Cache location, overridable through cfg['path'].
        self.data_root = self.cfg.get('path', os.path.join(DEFAULT_WORKSPACE, 'tools', self.name))

        # The target document URL is appended to this endpoint in ``single_query``.
        self.url = URL
        # JSON response plus Jina Reader request options; see the Jina Reader
        # API docs for the meaning of each X-* header.
        self.headers = {
            "Accept": "application/json",
            'Authorization': f'Bearer {KEY}',
            "X-With-Generated-Alt": "true",
            "X-Timeout": "60",
            "X-Token-Budget": "100000",
        }
        # Attempts per URL. A 409 response, a 200 without data, a timeout or
        # any other exception each consume one attempt.
        self.retry_times = 1
        self.db = Storage({'storage_root_path': self.data_root})

    def single_query(self, params: dict) -> dict:
        """Parse one URL through Jina Reader and store the outcome in ``params['result']``.

        Args:
            params: Dict with at least a ``'url'`` key; it is mutated in place.

        Returns:
            The same ``params`` dict, with ``'result'`` set to the ``data``
            payload of a successful (code 200) Jina response, or ``None`` if
            every attempt failed or the error was not retryable.
        """
        url = self.url + params['url']

        for _ in range(self.retry_times):
            try:
                response = requests.get(url, headers=self.headers, timeout=60).json()
                code = response.get('code')
                if code == 200:
                    if response.get('data') is None:
                        # Success code without a payload: count as a failed attempt.
                        raise ValueError(f"[{params['url']}]->{str(response)}")
                    # e.g. {'title': ..., 'description': ..., 'url': ..., 'content': ..., 'usage': {'tokens': ...}}
                    params['result'] = response['data']
                    return params

                message = (
                    f"JINA Code: {code} {response.get('name')} | "
                    f"[{params['url']}]->{response.get('readableMessage')}"
                )
                if code != 409:
                    # Any error other than 409 is treated as non-retryable.
                    if VERBOSE:
                        print(message)
                    break
                # 409 is treated as transient: report it and try again.
                raise ValueError(message)
            except requests.Timeout:
                if VERBOSE:
                    print("JINA Request timed out")
            except Exception as e:
                if VERBOSE:
                    print(f"Error in single_query: {e}")
        params['result'] = None
        return params

    def call(self, params: Union[str, dict], **kwargs) -> dict:
        """Parse every URL in ``params['urls']``, using the on-disk cache when possible.

        Args:
            params: ``{"urls": [{"paper_idx": ..., "url": ...}, ...]}`` as a dict
                or a JSON string. A dict argument is updated in place.

        Returns:
            On success ``{"success": True, "results": params_json, "params": params_json}``.
            Every ``urls`` entry has gained ``parsed_results``: the Jina ``data``
            payload, or ``None`` if parsing failed or the payload has a
            ``'warning'`` key. Entries sharing a URL all receive the result.
            On any failure ``{"success": False, "error_message": ...}``.
        """
        try:
            params_json = self._verify_json_format_args(params)
        except ValueError as e:
            return {
                "success": False,
                "error_message": f"Parameter validation failed (JINA): {str(e)}"
            }
        try:
            url_list = params_json["urls"]  # [{'paper_idx': 0, 'url': 'xxx'}, ...]
            # A URL may appear in several entries; remember every position so
            # each one receives the parsed result (not just the last).
            url2indices = {}
            for idx, item in enumerate(url_list):
                url2indices.setdefault(item['url'], []).append(idx)

            cached, missing = self._lookup_cache(list(url2indices))
            fetched = self._fetch_and_cache(missing)

            for item in cached + fetched:
                result = item['result']
                # Results carrying a warning are cached but not exposed to callers.
                parsed = None if result is None or 'warning' in result else result
                for idx in url2indices[item['url']]:
                    url_list[idx]['parsed_results'] = parsed

            return {
                "success": True,
                "results": params_json,
                "params": params_json
            }

        except Exception as e:
            return {
                "success": False,
                "error_message": f"JINA failed: {str(e)}"
            }

    def _lookup_cache(self, urls: list) -> tuple:
        """Split ``urls`` into cache hits and misses.

        Returns:
            ``(cached, missing)``: lists of dicts with ``index``, ``cached_name``
            and ``url`` keys; cache hits also carry the decoded ``result``.
            An entry that is not valid JSON is treated as a miss and refetched.
        """
        cached, missing = [], []
        for idx, url in enumerate(urls):
            cached_name = hash_sha256(json.dumps(url))
            entry = {"index": idx, "cached_name": cached_name, "url": url}
            try:
                entry["result"] = json.loads(self.db.get(cached_name))
            except (KeyNotExistsError, json.JSONDecodeError):
                missing.append(entry)
            else:
                cached.append(entry)
        return cached, missing

    def _fetch_and_cache(self, query_list: list) -> list:
        """Fetch uncached URLs concurrently and persist every non-empty result.

        Returns:
            The ``query_list`` dicts, each with a ``'result'`` key (``None`` on failure).
        """
        if not query_list:
            return []
        with ThreadPool(max_workers=CONCURRENCY) as pool:
            outputs = list(pool.map(self.single_query, query_list).result())

        for item in outputs:
            result = item['result']
            if result is not None:
                text = (result.get('title') or '') + (result.get('content') or '')
                if not text.strip():
                    # An empty document counts as a failure: not cached, retried next time.
                    item['result'] = None
            if item['result'] is not None:
                # Results with a 'warning' key are still cached to avoid
                # re-requesting them; ``call`` masks them before returning.
                self.db.put(item['cached_name'], json.dumps(item['result'], ensure_ascii=False, indent=2))
        return outputs

    def _verify_json_format_args(self, params: Union[str, dict], strict_json: bool = False) -> dict:
        """Decode ``params`` (if a string) and check it against ``self.parameters``.

        Args:
            params: Tool arguments as a dict or a JSON string.
            strict_json: Use ``json.loads`` instead of the lenient ``json_loads``.

        Returns:
            The decoded (or original) parameter dict.

        Raises:
            ValueError: If the string is not valid JSON, a required parameter
                is missing, or ``self.parameters`` has an unsupported type.
        """
        if isinstance(params, str):
            try:
                if strict_json:
                    params_json: dict = json.loads(params)
                else:
                    params_json: dict = json_loads(params)
            except json.decoder.JSONDecodeError:
                raise ValueError('Parameters must be formatted as a valid JSON!')
        else:
            params_json: dict = params
        if isinstance(self.parameters, list):
            for param in self.parameters:
                if 'required' in param and param['required']:
                    if param['name'] not in params_json:
                        raise ValueError('Parameters %s is required!' % param['name'])
        elif isinstance(self.parameters, dict):
            jsonschema.validate(instance=params_json, schema=self.parameters)
        else:
            raise ValueError(f'Unsupported parameters schema type: {type(self.parameters).__name__}')
        return params_json

    @property
    def function(self) -> dict:  # Bad naming. It should be `function_info`.
        """Tool metadata (name, description, parameters) in function-call form."""
        return {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters,
        }

if __name__ == "__main__":
    # Manual smoke test: parse one PDF through Jina Reader
    # (requires JINA_KEY and network access) and print the outcome.
    query = {"url": "https://sure.sunderland.ac.uk/id/eprint/3553/1/Leonie_Marilynne_Solomons.pdf"}

    tool = PDFParserJina()
    result = tool.single_query(query)
    print(json.dumps(result, ensure_ascii=False, indent=2))