# Copyright 2025 ZTE Corporation.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union

import arxiv
from arxiv2text import arxiv_to_text

class CSArxivSearcher:
    """
    A tool to search for academic papers on arXiv.

    Each result contains the paper's metadata plus a truncated excerpt of the
    text extracted from its PDF.
    """

    # Must agree with the "Defaults to 3" wording in the tool schema.
    _DEFAULT_MAX_RESULTS = 3

    def __init__(self, WORKSPACE_PATH: Optional[str] = None, paper_text_max_length: int = 2000):
        """Initializes the Arxiv Searcher tool.

        Args:
            WORKSPACE_PATH: Base directory for the tool; the current working
                directory is used when this is falsy.
            paper_text_max_length: Maximum number of characters of extracted
                PDF text kept for each paper.
        """
        self.WORKSPACE_PATH = WORKSPACE_PATH or os.getcwd()
        self.paper_text_max_length = paper_text_max_length
        self.tool_json_schema = {
            "name": 'search_arxiv_paper',
            "description": 'Finds academic papers on arXiv and returns metadata like titles, authors, and abstracts.',
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Keywords or the title to search for on arXiv."
                    },
                    "paper_ids": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "(Optional) A list of arXiv article IDs to which to limit the search."
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "(Optional) The maximum number of search results to return. Defaults to 3.",
                        "default": self._DEFAULT_MAX_RESULTS
                    }
                },
                "required": ["query"]
            }
        }
        self.client = arxiv.Client()

    @staticmethod
    def _paper_to_dict(paper) -> Dict[str, Any]:
        """Build the metadata dictionary reported for a single search result."""
        return {
            "title": paper.title,
            "published_date": paper.published.date().isoformat(),
            "authors": [author.name for author in paper.authors],
            "entry_id": paper.entry_id,
            "summary": paper.summary,
            "pdf_url": paper.pdf_url,
        }

    def _extract_text(self, pdf_url: str) -> str:
        """Return at most ``paper_text_max_length`` characters of the PDF's text.

        Extraction is best-effort. If it fails for one paper, that paper gets
        an error placeholder and the rest of the search still completes.
        """
        try:
            text = arxiv_to_text(pdf_url)
        except Exception as exc:  # third-party PDF download/parse; failure modes are open-ended
            return f"Error: could not extract text from {pdf_url}: {exc}"
        # Guard against an empty/None extraction result before slicing.
        return (text or "")[:self.paper_text_max_length]

    async def call_tool(self, arguments: dict, **kwargs) -> Tuple[Union[List[Dict[str, Any]], str], bool]:
        """
        Searches for academic papers on arXiv using a query string and optional paper IDs.

        Args:
            arguments (dict): A dictionary containing the arguments for the tool call.
                - query (str): The search query string.
                - paper_ids (List[str], optional): A list of specific arXiv paper IDs.
                - max_results (int, optional): Maximum number of results to return. Defaults to 3.

        Returns:
            A ``(result, ok)`` tuple. On success, ``result`` is a list of
            dictionaries, each holding a paper's metadata and extracted text,
            and ``ok`` is True. If the search itself fails, ``result`` is an
            error message string and ``ok`` is False.
        """
        # NOTE(review): the arxiv client and arxiv_to_text are called synchronously
        # (no await), so they block the event loop while they run.
        query = arguments["query"]
        # Treat an explicit null the same as an omitted argument.
        paper_ids = arguments.get("paper_ids") or []
        max_results = arguments.get("max_results")
        if max_results is None:
            # arxiv.Search treats max_results=None as "fetch everything",
            # which would also trigger a PDF download per result.
            max_results = self._DEFAULT_MAX_RESULTS

        search_query = arxiv.Search(
            query=query,
            id_list=paper_ids,
            max_results=max_results,
        )
        try:
            # Materialize inside the try so errors raised while iterating
            # the (possibly lazy) result stream are reported, not raised.
            papers = list(self.client.results(search_query))
        except Exception as exc:  # tool boundary: report failure to the caller
            return f"Error: arXiv search failed: {exc}", False

        papers_data = []
        for paper in papers:
            paper_info = self._paper_to_dict(paper)
            paper_info["paper_text"] = self._extract_text(paper_info["pdf_url"])
            papers_data.append(paper_info)

        return papers_data, True


class CSArxivDownloader:
    """
    A tool to download PDF files of academic papers from arXiv.

    Files are written to a directory inside ``WORKSPACE_PATH``. Requests
    whose ``output_dir`` resolves outside the workspace are rejected.
    """

    # Must agree with the "Defaults to 3" wording in the tool schema.
    _DEFAULT_MAX_RESULTS = 3
    # Byte budget for the filename stem; leaves headroom for ".pdf" under the
    # typical 255-byte filename limit even for non-ASCII titles.
    _MAX_FILENAME_BYTES = 200

    def __init__(self, WORKSPACE_PATH: Optional[str] = None):
        """Initializes the Arxiv Downloader tool.

        Args:
            WORKSPACE_PATH: Base directory for downloads; the current working
                directory is used when this is falsy.
        """
        self.WORKSPACE_PATH = WORKSPACE_PATH or os.getcwd()
        self.tool_json_schema = {
            "name": 'search_and_download_arxiv_paper',
            "description": 'Searches for and downloads the PDF files of academic papers from arXiv.',
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Keywords or the title of the paper to search for and download."
                    },
                    "paper_ids": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "(Optional) A list of arXiv article IDs to which to limit the search."
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "(Optional) The maximum number of search results to return. Defaults to 3.",
                        "default": self._DEFAULT_MAX_RESULTS
                    },
                    "output_dir": {
                        "type": "string",
                        "description": "(Optional) The folder path where the downloaded PDFs will be saved. Defaults to the current directory.",
                        "default": "./"
                    }
                },
                "required": ["query"]
            }
        }
        self.client = arxiv.Client()

    def _sanitize_filename(self, filename: str) -> str:
        """Turn a paper title into a filesystem-safe file stem (no extension).

        Whitespace runs, including line breaks, are collapsed to single
        spaces. Characters invalid on common filesystems and ASCII control
        characters are removed. The result is truncated to
        ``_MAX_FILENAME_BYTES`` UTF-8 bytes, and "paper" is returned if
        nothing usable remains.
        """
        cleaned = re.sub(r"\s+", " ", filename)
        cleaned = re.sub(r'[\\/*?:"<>|\x00-\x1f]', "", cleaned).strip()
        # Truncate on bytes, dropping any multi-byte character split at the cut.
        truncated = cleaned.encode("utf-8")[:self._MAX_FILENAME_BYTES].decode("utf-8", "ignore")
        return truncated.rstrip() or "paper"

    def _resolve_output_dir(self, output_dir: str) -> Optional[str]:
        """Return the absolute download directory for ``output_dir``.

        ``output_dir`` is interpreted relative to ``WORKSPACE_PATH``. Returns
        None when the resolved path lies outside the workspace, whether via an
        absolute path, ``..`` segments, or symlinks.
        """
        workspace = os.path.realpath(self.WORKSPACE_PATH)
        target = os.path.realpath(os.path.join(workspace, output_dir))
        try:
            inside = os.path.commonpath([workspace, target]) == workspace
        except ValueError:  # e.g. paths on different drives
            return None
        return target if inside else None

    async def call_tool(self, arguments: dict, **kwargs) -> Tuple[str, bool]:
        """
        Downloads PDFs of academic papers from arXiv based on a search query.

        Args:
            arguments (dict): A dictionary containing the arguments for the tool call.
                - query (str): The search query string.
                - paper_ids (List[str], optional): A list of specific arXiv paper IDs.
                - max_results (int, optional): Maximum number of papers to download. Defaults to 3.
                - output_dir (str, optional): Directory relative to the workspace. Defaults to "./".

        Returns:
            A ``(message, ok)`` tuple. ``message`` is a human-readable status
            string, and ``ok`` is True only if at least one paper was
            downloaded.
        """
        # NOTE(review): the arxiv client calls below are synchronous (no await),
        # so they block the event loop while downloads run.
        query = arguments["query"]
        # Treat explicit nulls the same as omitted arguments.
        paper_ids = arguments.get("paper_ids") or []
        max_results = arguments.get("max_results")
        if max_results is None:
            # arxiv.Search treats max_results=None as "fetch everything".
            max_results = self._DEFAULT_MAX_RESULTS
        output_dir = arguments.get("output_dir") or "./"

        target_dir = self._resolve_output_dir(output_dir)
        if target_dir is None:
            return (f"Error: output_dir '{output_dir}' resolves outside the workspace. "
                    "Nothing was downloaded."), False

        search_query = arxiv.Search(
            query=query,
            id_list=paper_ids,
            max_results=max_results,
        )

        downloaded_files = []
        try:
            os.makedirs(target_dir, exist_ok=True)
            for paper in self.client.results(search_query):
                safe_filename = self._sanitize_filename(paper.title) + ".pdf"
                filepath = paper.download_pdf(dirpath=target_dir, filename=safe_filename)
                downloaded_files.append(filepath)
        except Exception as exc:  # tool boundary: report failure to the caller
            return (f"Error: download failed after {len(downloaded_files)} paper(s) "
                    f"were saved: {exc}"), False

        if not downloaded_files:
            return "Warning: No papers found matching the query. Nothing was downloaded.", False

        return f"Successfully downloaded {len(downloaded_files)} paper(s) to {output_dir}.", True
