from typing import Optional, Union

import os
import json
import jsonschema
import requests

from pebble import ThreadPool
from time import sleep
from ray.util import pdb
from tqdm import tqdm

from qwen_agent.tools.base import BaseTool, register_tool
from qwen_agent.utils.utils import json_loads
from qwen_agent.settings import DEFAULT_WORKSPACE
from qwen_agent.tools.storage import KeyNotExistsError, Storage
from qwen_agent.utils.utils import hash_sha256

from qwen_agent.tools.multi_agent.utils import func_input_desc, readpage_func_desc
from qwen_agent.tools.multi_agent.scholarsearch import GoogleScholarTool, PDFParserJina


# Number of worker threads handed to the rerank ThreadPool in `GoogleScholar.call`.
GTE_CONCURRENCY = int(os.environ.get('GTE_CONCURRENCY', 32))

# Enables the diagnostic print() calls in this module.
# Accepted truthy spellings (case-insensitive): true / 1 / yes / on.
VERBOSE = os.environ.get("VERBOSE", "0").lower() in {"true", "1", "yes", "on"}

# When set, `GoogleScholar.call` returns the raw search hits and skips PDF parsing.
# Accepted truthy spellings (case-insensitive): true / 1 / yes ("on" is NOT accepted here).
ONLY_SCHOLAR = os.environ.get('ONLY_SCHOLAR', "false").lower() in {"true", "1", "yes"}

# When set (the default), parsed PDF text is attached to each hit as `content`;
# when unset, `GoogleScholar.call` takes the rerank branch instead.
READPAGE = os.environ.get('READPAGE', "true").lower() in {"true", "1", "yes"}


@register_tool("GoogleScholar")
class GoogleScholar(BaseTool):
    name = "GoogleScholar"
    description = "A Google Scholar search tool that can search academic papers, technical terms, or other scholarly information. It supports searching multiple queries simultaneously."
    
    parameters = {
        "type": "object",
        "properties": {
            "queries": {
                "type": "array",
                "items": {"type": "string"},
                "description": "A list of search queries. Each query can be a paper title, author name, research topic, or technical term.",
            }
        },
        "required": ["queries"],
    }
    

    def __init__(self, cfg: Optional[dict] = None):
        """Create the tool and the two backends it delegates to.

        Args:
            cfg: Optional configuration dict, forwarded unchanged to the
                base-class initializer.
        """
        super().__init__(cfg)

        # Backends: scholar search first, then PDF parsing of the hits.
        self.googlescholar = GoogleScholarTool()
        self.pdfparser = PDFParserJina()

        # NOTE(review): no reranker is created here, yet `call` uses
        # `self.rerank` when READPAGE is disabled -- confirm where it is
        # expected to come from.

        # Free-text note appended to the tool description in `observation`;
        # currently empty.
        self.observation_note = ""

    def observation(self, tool: dict, tool_dict: dict, tool_results, empty_mode: bool=False, readpage: bool=False, max_observation_length: int=None, tokenizer=None):
        """Turn raw tool results into the observation handed back to the caller.

        Two output shapes are produced, selected by ``readpage``:

        * ``readpage=False``: a single string with one ``# Query N: ...``
          section per query. Unless ``empty_mode`` is set, the string is
          prefixed with the tool description and a ``Tool Results:`` header.
        * ``readpage=True``: a list of dicts, one per token-bounded group of
          papers (see ``merge_papers``), each holding ``query_idx``,
          ``paper_index``, ``papers``, ``content`` and ``tool_results``.

        Args:
            tool: Tool-call description passed to ``func_input_desc`` /
                ``readpage_func_desc``.
            tool_dict: Parsed tool arguments; ``tool_dict['queries']`` is
                zipped with ``tool_results``.
            tool_results: One entry per query, each a mapping with at least
                ``payload`` (containing ``query``) and ``result``.
            empty_mode: Omit the tool description (non-readpage mode only).
            readpage: Select the list-of-groups output described above.
            max_observation_length: Token budget per group; required when
                ``readpage`` is true.
            tokenizer: Object exposing ``encode(text, truncation=False)`` and
                ``decode(tokens)``; required when ``readpage`` is true.

        Returns:
            ``str`` when ``readpage`` is false, ``list`` of dicts otherwise.

        Raises:
            AssertionError: if a query does not match its result payload, or
                if ``empty_mode`` is combined with ``readpage``.
            ValueError: if ``readpage`` is requested without a tokenizer or
                a token budget.
        """
        if not readpage:
            tool_desc = func_input_desc(tool) + "\n" + self.observation_note
            tool_observation = []
            for idx, (query, docs) in enumerate(zip(tool_dict['queries'], tool_results)):
                assert query == docs['payload']['query'], f"query should be equal, but now query: {query}\npayload:\n{docs['payload']}"
                if docs['result'] is None:
                    # The search returned nothing usable for this query.
                    body = "No information."
                else:
                    body = format_papers_to_str(docs)
                tool_observation.append(f"# Query {idx + 1}: {query}\n" + body)

            joined = "\n\n".join(tool_observation)
            if empty_mode:
                return joined
            return f"{tool_desc}\n\nTool Results:\n" + joined

        assert not empty_mode, "can not use empty mode when using readpage"
        if tokenizer is None or max_observation_length is None:
            # Fail with an explicit message instead of an opaque
            # AttributeError / TypeError further down.
            raise ValueError("readpage mode requires both `tokenizer` and `max_observation_length`")

        tool_desc = readpage_func_desc(tool) + "\n" + self.observation_note
        # The tool description is appended to every group, so reserve its
        # tokens up front.
        budget = max_observation_length - len(tokenizer.encode(tool_desc, truncation=False))

        tool_observation = []
        for idx, (query, docs) in enumerate(zip(tool_dict['queries'], tool_results)):
            assert query == docs['payload']['query'], f"query should be equal, but now query: {query}\npayload:\n{docs['payload']}"

            query_prefix = f"# Query {idx + 1}: {query}\n"
            if docs['result'] is None:
                # The search returned nothing usable for this query.
                query_observations = [{"paper_indices": None, 'content': "No information.", "papers": None}]
            else:
                # Papers are grouped per query (queries are never mixed);
                # several papers of one query may share a group.
                # Original author's note: each group corresponds to one
                # "system1" call downstream.
                contents = format_papers_to_str_in_readpage(docs)
                query_observations = self.merge_papers(contents, budget, tokenizer)

            # Errors here propagate to the caller; the previous bare
            # `except` that dropped into a debugger has been removed.
            for content in query_observations:
                tool_observation.append({
                    'query_idx': idx,
                    'paper_index': content['paper_indices'],
                    'papers': content['papers'],
                    'content': query_prefix + content['content'] + "\n\n" + tool_desc,
                    'tool_results': docs,
                })

        return tool_observation

    def merge_papers(self, papers, max_observation_length, tokenizer):
        def get_token_length(text):
            return len(tokenizer.encode(text, truncation=False))

        def truncate_content(content, max_length, tokenizer):
            tokens = tokenizer.encode(content, truncation=False)
            if len(tokens) <= max_length:
                return content
            return tokenizer.decode(tokens[:max_length])

        # 计算每个文档的长度 (都是一个query对应的所有的paper)
        paper_lengths = [(i, paper, get_token_length(paper)) 
                        for i, paper in enumerate(papers)]
        
        # 分离需要截断的长文档和可以合并的短文档
        long_papers = []
        short_papers = []
        for idx, paper, length in paper_lengths:
            if length > max_observation_length:
                long_papers.append((idx, paper))
            else:
                short_papers.append((idx, paper, length))
        
        merged_results = []

        # 处理长文档
        for idx, paper in long_papers:
            merged_results.append({
                'paper_indices': [idx],
                'papers': [truncate_content(paper, max_observation_length, tokenizer)]
            })
        
        # 使用First Fit Decreasing (FFD)算法处理短文档
        # 1. 按长度降序排序
        short_papers.sort(key=lambda x: x[2], reverse=True)
        
        bins = []  # 每个bin是一个(剩余空间, [文档索引], [文档内容])的元组
        
        # 2. 遍历每个文档，放入能容纳它的第一个bin
        for idx, paper, length in short_papers:
            placed = False
            # 尝试放入现有的bin
            for bin_idx, bin_info in enumerate(bins):
                remaining_space, indices, contents = bin_info
                if remaining_space >= length:
                    # 可以放入这个bin
                    bins[bin_idx] = (
                        remaining_space - length,
                        indices + [idx],
                        contents + [paper]
                    )
                    placed = True
                    break
            
            if not placed:
                # 创建新的bin
                bins.append((
                    max_observation_length - length,
                    [idx],
                    [paper]
                ))
        
        # 将bins转换为最终结果
        for _, indices, contents in bins:
            sorted_pairs = sorted(zip(indices, contents), key=lambda x: x[0])
            sorted_indices, sorted_contents = zip(*sorted_pairs)
            merged_results.append({
                'paper_indices': list(sorted_indices),
                'papers': list(sorted_contents)
            })
        
        # 按原始文档索引排序
        merged_results.sort(key=lambda x: x['paper_indices'][0])

        paper_index = 0
        for result in merged_results:
            result['content'] = [
                f'[Paper {paper_index + idx + 1} begin]\n' 
                + p
                + f'\n[Paper {paper_index + idx + 1} end]' 
                for idx, p in enumerate(result['papers'])]
            paper_index += len(result['papers'])

            result['content'] = "\n\n".join(result['content'])

        return merged_results

    def call(self, params: Union[str, dict], **kwargs) -> dict:
        """Search Google Scholar for every query and enrich hits that expose a PDF.

        Steps: validate ``params``; run the scholar search; unless
        ONLY_SCHOLAR is set, send each hit's ``pdfUrl`` to the PDF parser.
        With READPAGE enabled the parsed text is stored on the hit as
        ``content``; otherwise the parsed documents are reranked and the
        chosen chunks are stored as ``selected_chunks``.

        Args:
            params: JSON string or dict matching ``self.parameters``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``{"success": True, "results": papers, "params": params_json}``
            on success, otherwise ``{"success": False, "error_message": ...}``.
            Validation failures and runtime errors are reported through the
            returned dict rather than raised.
        """
        try:
            params_json = self._verify_json_format_args(params)
        except (ValueError, jsonschema.ValidationError) as e:
            # jsonschema.ValidationError does not derive from ValueError, so
            # it has to be listed explicitly; otherwise a schema violation
            # would escape instead of producing a failure dict.
            return {
                "success": False,
                "error_message": f"Parameter validation failed: {str(e)}"
            }
        try:
            # Original author's note: no cache is kept at this level because
            # search and PDF parsing each have their own parameters (and the
            # search is cached on its own), and reading documents should not
            # be slow.
            queries = params_json['queries']
            search_results = self.googlescholar.call(params_json)  # results arrive already sorted
            if not search_results['success']:
                # search_results: {'success': bool, 'error_message': str}
                return {
                    'success': search_results['success'],
                    'error_message': search_results['error_message']
                }

            papers = search_results['results']  # one entry per query
            if ONLY_SCHOLAR:
                return {
                    "success": True,
                    "results": papers,
                    "params": params_json
                }

            # Collect, per query, the hits that expose a PDF link.
            query_list = []
            for paper in papers:
                if paper['result'] is None:  # empty scholar search, or it failed for another reason
                    continue
                urls = [
                    {'paper_idx': idx, "url": item['pdfUrl']}
                    for idx, item in enumerate(paper['result'])
                    if 'pdfUrl' in item
                ]
                if urls:
                    query_list.append({"index": paper['index'], 'urls': urls})

            if query_list:  # at least one valid PDF link exists
                parsed_results = []
                for query in query_list:
                    parsed_result = self.pdfparser.call(query)
                    if not parsed_result['success']:
                        # Raised explicitly (not `assert False`) so the check
                        # survives `python -O`; the outer handler turns it
                        # into a failure dict.
                        raise RuntimeError(f"parsed failed:\n{parsed_result}")
                    parsed_result['results']['query'] = queries[parsed_result['results']['index']]
                    parsed_results.append(parsed_result['results'])

                if not READPAGE:
                    reranker = getattr(self, 'rerank', None)
                    if reranker is None:
                        # No reranker is created in __init__; report that
                        # plainly instead of a bare AttributeError.
                        raise RuntimeError("READPAGE is disabled but no reranker is configured (self.rerank is missing)")

                    # Chunk and rerank the parsed documents of every query in parallel.
                    with ThreadPool(max_workers=GTE_CONCURRENCY) as pool:
                        future = pool.map(reranker.call, parsed_results)
                        reranked_results = [item['results'] for item in future.result()]

                    # Attach the selected chunks to the paper they belong to.
                    reranked_by_query = {item['index']: item for item in reranked_results}
                    for paper in papers:
                        query_index = paper['index']
                        if query_index in reranked_by_query:
                            for reranked_item in reranked_by_query[query_index]['urls']:
                                p_index = reranked_item['paper_idx']
                                assert reranked_item['url'] == paper['result'][p_index]['pdfUrl'], "parsed result must be equal with paper"
                                paper['result'][p_index]['selected_chunks'] = reranked_item.get('reranked_chunk', None)
                        else:
                            # Either no hit of this query had a PDF link, or
                            # paper['result'] is None (empty/failed search).
                            if VERBOSE:
                                print(f"query [{paper['payload']['query']}] has no pdfUrl.")
                else:
                    # Attach the parsed full text to the paper it belongs to.
                    parsed_by_query = {item['index']: item for item in parsed_results}
                    for paper in papers:
                        query_index = paper['index']
                        if query_index in parsed_by_query:
                            for parsed_item in parsed_by_query[query_index]['urls']:
                                p_index = parsed_item['paper_idx']
                                assert parsed_item['url'] == paper['result'][p_index]['pdfUrl'], "parsed result must be equal with paper"
                                if parsed_item['parsed_results'] is not None:
                                    paper['result'][p_index]['content'] = parsed_item['parsed_results']['content']
                assert len(papers) == len(queries), f"the results after calling (papers) should be equal size with queries, but now papers num: {len(papers)}\nqueries:\n{queries}"

            return {
                "success": True,
                "results": papers,
                "params": params_json
            }

        except Exception as e:
            # Top-level boundary: convert any failure into an error dict that
            # names the innermost frame where the exception was raised.
            import traceback
            import sys
            exc_type, exc_value, exc_traceback = sys.exc_info()
            tb_info = traceback.extract_tb(exc_traceback)
            filename, line, func, text = tb_info[-1]
            detailed_message = f"Google Scholar Error in file '{filename}', line {line}, in {func}: {str(e)}\nCode: {text}"
            if VERBOSE:
                print(detailed_message)
            return {
                "success": False,
                "error_message": f"Scholar Search failed: {str(e)}\ndetailed_message: {detailed_message}"
            }

    def _verify_json_format_args(self, params: Union[str, dict], strict_json: bool = False) -> dict:
        """Verify the parameters of the function call"""
        if isinstance(params, str):
            try:
                if strict_json:
                    params_json: dict = json.loads(params)
                else:
                    params_json: dict = json_loads(params)
            except json.decoder.JSONDecodeError:
                raise ValueError('Parameters must be formatted as a valid JSON!')
        else:
            params_json: dict = params
        if isinstance(self.parameters, list):
            for param in self.parameters:
                if 'required' in param and param['required']:
                    if param['name'] not in params_json:
                        raise ValueError('Parameters %s is required!' % param['name'])
        elif isinstance(self.parameters, dict):
            import jsonschema
            jsonschema.validate(instance=params_json, schema=self.parameters)
        else:
            raise ValueError
        return params_json

    def save_cache(self):
        """automatically save cache when the tool is deleted"""
        if hasattr(self, 'search_engine'):
            self.search_engine.save_cache()

    @property
    def function(self) -> dict:  # Bad naming. It should be `function_info`.
        return {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters,
        }

def format_papers_to_str(doc: dict) -> str:
    """Render the hits of one query as a fenced, human-readable text block.

    Args:
        doc: Mapping whose ``result`` entry is an iterable of paper dicts.
            Each paper must have ``title`` and ``snippet``; ``year``,
            ``citedBy`` and ``selected_chunks`` are optional and are
            rendered as ``None`` when absent.

    Returns:
        The papers, numbered from 1 and separated by blank lines, wrapped
        in a triple-backtick fence.
    """
    rendered = []
    for number, paper in enumerate(doc['result'], start=1):
        pieces = (
            f"[Paper {number}]: \"{paper['title']}\"",
            f"\nSnippet: {paper['snippet']}",
            f"\nPublished Year: {paper.get('year', 'None')}",
            f"\nCitation: {paper.get('citedBy', 'None')}",
            f"\nDetailed Chunks: {paper.get('selected_chunks', 'None')}",
        )
        rendered.append("".join(pieces).strip())

    body = "\n\n".join(rendered)
    return '```\n{}\n```'.format(body)

def format_papers_to_str_in_readpage(doc: dict, max_words: int = 50000) -> list:
    """Render each hit of one query as its own text block for readpage mode.

    Unlike ``format_papers_to_str`` this returns one string per paper (not
    a single joined string), so the caller can group them by token budget.

    Args:
        doc: Mapping whose ``result`` entry is an iterable of paper dicts.
            Each paper must have ``title`` and ``snippet``; ``content``,
            ``publicationInfo``, ``year`` and ``citedBy`` are optional and
            rendered as ``None`` when absent.
        max_words: Maximum number of space-separated words of ``content``
            kept per paper.

    Returns:
        A list of strings, one per paper, in the order of ``doc['result']``.
    """
    formatted = []
    for item in doc['result']:
        body = item.get('content')
        if body is None:
            # Covers both a missing key and an explicit None value, which
            # would otherwise fail on `.split`.
            body = 'None'
        clipped = ' '.join(body.split(' ')[:max_words])
        formatted.append(
            (f"Title: {item['title']}"
             + f"\nSnippet: {item['snippet']}"
             + f"\nContent: {clipped}"
             + f"\nPublished Info: {item.get('publicationInfo', 'None')}"
             + f"\nPublished Year: {item.get('year', 'None')}"
             + f"\nCitation: {item.get('citedBy', 'None')}"
             ).strip()
        )
    return formatted


if __name__ == "__main__":
    # Manual smoke test: run three sample queries through the tool and show
    # what came back (the result used to be computed and then discarded).
    queries = {"queries": ["toolevo", "alphamath", "c-3po"]}

    tool = GoogleScholar()
    result = tool.call(queries)
    print(result)