from typing import Optional, Union, List, Dict

import os
import json
import jsonschema
import requests
import subprocess

from pebble import ThreadPool
from time import sleep
from ray.util import pdb
from tqdm import tqdm
from urllib.parse import urlparse
from concurrent.futures import TimeoutError

from qwen_agent.tools.base import BaseTool, register_tool
from qwen_agent.utils.utils import json_loads
from qwen_agent.settings import DEFAULT_WORKSPACE
from qwen_agent.tools.storage import KeyNotExistsError, Storage
# from qwen_agent.utils.utils import hash_sha256

# from qwen_agent.tools.multi_agent.utils import func_input_desc


# Number of worker threads used for downloads; can be overridden through the
# DOWNLOAD_CONCURRENCY environment variable.
CONCURRENCY = int(os.environ.get('DOWNLOAD_CONCURRENCY', 5))

# Timeout, in seconds, handed to the worker pool for the download tasks.
timeout = 100

# Verbose mode is currently forced on; the environment-driven switch below is
# kept for reference but disabled.
# VERBOSE = os.getenv("VERBOSE", "0").lower() in ("true", "1", "yes", "on")
VERBOSE = True

def get_filename_from_url(url: str) -> str:
    """Derive a local ``.pdf`` file name from the path component of *url*.

    The query string and fragment are ignored; only the last path segment is
    used.

    Args:
        url: The URL the PDF will be downloaded from.

    Returns:
        The last path segment of the URL, with ``.pdf`` appended unless it
        already ends in ``.pdf`` (compared case-insensitively). When the URL
        has no usable path segment, ``"download.pdf"`` is returned.
    """
    parsed = urlparse(url)
    # Strip trailing slashes so ".../papers/" yields "papers" instead of an
    # empty name (which would otherwise become the hidden file ".pdf").
    filename = os.path.basename(parsed.path.rstrip('/'))
    if not filename:
        filename = 'download'
    # Compare case-insensitively so "Paper.PDF" is not renamed "Paper.PDF.pdf".
    if not filename.lower().endswith('.pdf'):
        filename += '.pdf'
    return filename

def aria2_pdf(args) -> Dict:
    """Download a single PDF file with the ``aria2c`` command-line tool.

    Args:
        args: Task description, e.g.
            ``{"index": idx, "cached_name": cached_name, "url": url, "name": name}``.
            Only ``url`` and ``name`` (the destination file path) are read.

    Returns:
        Dict: The same ``args`` dict, updated in place with ``success``
        (bool) and ``error`` (a message string, or ``None`` on success).

    Raises:
        KeyError: If ``args`` lacks the ``url`` or ``name`` key.
    """
    url, output_name = args['url'], args['name']

    # aria2c interprets --out relative to --dir, so the destination path is
    # split into a directory and a bare file name instead of being passed
    # whole to --out.
    target_dir = os.path.dirname(output_name) or '.'
    target_file = os.path.basename(output_name)

    cmd = [
        'aria2c',
        url,
        f'--dir={target_dir}',
        f'--out={target_file}',
        # Without this, aria2c renames the download when the target already
        # exists, and the path reported to the caller would be wrong.
        '--allow-overwrite=true',
        '--split=1',
        '--max-connection-per-server=1',
        '--max-tries=3',
        '--retry-wait=3',
        '--connect-timeout=60',
        '--console-log-level=error',
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except Exception as e:
        # Best effort: e.g. aria2c is not installed. Report instead of raising
        # so one bad task does not abort the whole batch.
        args['success'] = False
        args['error'] = str(e)
        return args

    if result.returncode == 0:
        args['success'] = True
        args['error'] = None
    else:
        args['success'] = False
        # Fall back to stdout, then the exit code, so the error is never empty.
        message = (result.stderr or result.stdout or '').strip()
        args['error'] = message or f'aria2c exited with code {result.returncode}'
    return args


class PDFDownload():
    """Download PDF files with aria2c and remember where each URL was saved.

    Files are written below ``self.data_root``. For every successful download
    the local file path is stored in ``self.db`` under the key
    ``url + "_cache"`` so that later calls can skip the download.
    """
    name = "PDFDownload"

    def __init__(self, cfg: Optional[dict] = None):
        """Set up the download directory path and its backing storage.

        Args:
            cfg: Accepted for interface compatibility; it is not read here.
        """
        self.data_root = os.path.join(DEFAULT_WORKSPACE, 'tools', self.name)
        self.db = Storage({'storage_root_path': self.data_root})

    def download(self, urls: List[str]):
        """Download every URL in *urls*, reusing previously cached files.

        Args:
            urls: URLs of the PDF files to fetch.

        Returns:
            A list with one dict per input URL, ordered like *urls*. Each dict
            holds ``index``, ``cached_name``, ``url``, ``name`` (the intended
            destination path), ``local_path`` (path of the file on disk, or
            ``None`` when the download failed), ``success`` and ``error``.
        """
        os.makedirs(self.data_root, exist_ok=True)

        cached = []
        pending = []
        for idx, url in enumerate(urls):
            # NOTE(review): the storage key is the raw URL plus a suffix;
            # confirm the Storage backend accepts keys containing "/" and ":".
            task = {
                "index": idx,
                "cached_name": url + "_cache",
                "url": url,
                "name": os.path.join(self.data_root, get_filename_from_url(url)),
            }
            try:
                local_path = self.db.get(task["cached_name"])
            except KeyNotExistsError:
                local_path = None

            # A cache entry is only trusted while the file still exists on
            # disk; otherwise the URL is downloaded again.
            if local_path and os.path.isfile(local_path):
                cached.append({**task, "local_path": local_path, "success": True, "error": None})
            else:
                pending.append(task)

        outputs = self._run_downloads(pending) if pending else []

        # Cache the local path of every download that really produced a file.
        for item in outputs:
            if item.get('success') and not os.path.isfile(item['name']):
                # The downloader reported success but the file is not where
                # it was expected; do not cache a path that does not exist.
                item['success'] = False
                item['error'] = f"Downloaded file not found at {item['name']}"
            if item.get('success'):
                item['local_path'] = item['name']
                self.db.put(item['cached_name'], item['name'])
            else:
                item['local_path'] = None

        return sorted(cached + outputs, key=lambda x: x['index'])

    def _run_downloads(self, tasks: List[Dict]) -> List[Dict]:
        """Run ``aria2_pdf`` over *tasks* in a thread pool; one result dict per task."""
        outputs = []
        with ThreadPool(max_workers=CONCURRENCY) as pool:
            future = pool.map(aria2_pdf, tasks, timeout=timeout)
            iterator = future.result()

            # The iterator yields (or raises) one outcome per task in
            # submission order. A raised exception does not carry the task,
            # so outcomes are paired with tasks by position.
            for task in tasks:
                try:
                    outputs.append(next(iterator))
                except StopIteration:
                    break
                except TimeoutError:
                    outputs.append(self._failed(task, f'Download timeout after {timeout} seconds'))
                except Exception as error:
                    # Any other failure raised while running the task.
                    outputs.append(self._failed(task, str(error)))
        return outputs

    @staticmethod
    def _failed(task: Dict, message: str) -> Dict:
        """Return a copy of *task* marked as failed with *message*."""
        # Copy so a worker thread that is still running cannot later
        # overwrite the failure recorded here.
        failed = dict(task)
        failed['success'] = False
        failed['error'] = message
        return failed



if __name__ == "__main__":
    # Manual smoke test: download a few sample PDFs and show the outcome.
    urls = [
        "https://arxiv.org/pdf/2410.06617",
        "https://link.springer.com/content/pdf/10.1186/1471-2105-9-493.pdf",
        "https://surface.syr.edu/cgi/viewcontent.cgi?article=1021&context=scied_etd"
    ]

    tool = PDFDownload()
    result = tool.download(urls)
    # Print the per-URL results; previously they were computed and discarded.
    for item in result:
        print(item)