import arxiv
import json
import re
import asyncio
from typing import Dict, Any, Union, Optional, List

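# NOTE: Tool_node (the base class used below, from the HeGFlow framework) must be in scope
# before ArxivSearcher is defined. The import is left commented out here; enable it in an
# environment where HeGFlow is installed.
# from HeGFlow.graph.tool_node import Tool_node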


class ArxivSearcher(Tool_node):
    """
    A tool for searching academic papers on arXiv.org.

    Given a task string it queries arXiv (most recently submitted papers first)
    and returns the title, authors, abstract, publication date and links of
    each hit.
    """

    # Number of results requested when the task has no ``max_results=<n>``.
    _DEFAULT_MAX_RESULTS = 5
    # Formatted summaries longer than this many characters are truncated.
    _SUMMARY_LIMIT = 300
    # Matches query='...' or query="...". The back-reference forces the closing
    # quote to be the same character as the opening one, so the *other* quote
    # character (e.g. an apostrophe inside double quotes) may appear in the query.
    _QUERY_PATTERN = re.compile(r"query=(['\"])(.+?)\1", re.DOTALL)
    _MAX_RESULTS_PATTERN = re.compile(r"max_results=(\d+)")
    _NO_QUERY_ERROR = (
        "Error: No search query provided. Please specify a query, e.g., "
        "'Artificial Intelligence' or 'query=\"Quantum Computing\", max_results=2'."
    )

    def __init__(self):
        super().__init__(
            name="ArxivSearcher",
            description="Searches for academic papers on arXiv.org and returns their key details. Accepts a search query and optional 'max_results' parameter."
        )

    @classmethod
    def _parse_task(cls, task: str) -> tuple:
        """
        Split a task string into ``(query, max_results)``.

        Args:
            task: Either a bare search term (e.g. "Transformer architectures") or a
                structured command (e.g. "query='Large Language Models', max_results=3").

        Returns:
            A ``(query, max_results)`` tuple. ``query`` is an empty string when the
            task contains no usable search text; ``max_results`` falls back to the
            class default when the task does not specify it.
        """
        limit_match = cls._MAX_RESULTS_PATTERN.search(task)
        max_results = (
            int(limit_match.group(1)) if limit_match else cls._DEFAULT_MAX_RESULTS
        )

        query_match = cls._QUERY_PATTERN.search(task)
        if query_match:
            return query_match.group(2), max_results

        if limit_match is None:
            # No structured parameters at all: the whole task is the query.
            return task.strip(), max_results

        # A bare query followed/preceded by "max_results=N": drop the parameter
        # (and the comma separating it) so it is not sent to arXiv as search text.
        remainder = task[:limit_match.start()] + task[limit_match.end():]
        return remainder.strip().strip(",").strip(), max_results

    @staticmethod
    def _fetch_papers(query: str, max_results: int) -> List[Dict[str, Any]]:
        """
        Run the (blocking) arXiv search and collect one dict per result.

        This performs synchronous network I/O and is meant to be called from a
        worker thread, not directly on the event loop.
        """
        client = arxiv.Client()
        search = arxiv.Search(
            query=query,
            max_results=max_results,
            sort_by=arxiv.SortCriterion.SubmittedDate,  # Sort by submission date
            sort_order=arxiv.SortOrder.Descending,      # Get the latest papers first
        )
        return [
            {
                "title": r.title,
                "authors": [author.name for author in r.authors],
                "summary": r.summary,
                "published_date": r.published.isoformat(),
                "pdf_url": r.pdf_url,
                "arxiv_url": r.entry_id,
            }
            for r in client.results(search)
        ]

    async def _execute_tool(self, task: str) -> str:
        """
        Executes an arXiv search based on the provided `task` string.
        The `task` can be a simple search term (e.g., "Transformer architectures")
        or a more structured command specifying parameters (e.g., "query='Large Language Models', max_results=3").

        Returns:
            A JSON string representing a list of dictionaries, where each dictionary contains
            details of a found paper, or an error message string if the search fails.
        """
        if not isinstance(task, str):
            return self._NO_QUERY_ERROR

        query, max_results = self._parse_task(task)
        if not query:
            return self._NO_QUERY_ERROR

        try:
            # arxiv.Client.results() is a regular (synchronous) generator that does
            # blocking HTTP requests, so it cannot be consumed with ``async for``.
            # Run it in the default executor to keep the event loop responsive.
            loop = asyncio.get_running_loop()
            results_list = await loop.run_in_executor(
                None, self._fetch_papers, query, max_results
            )
            return json.dumps(results_list, indent=2)
        except Exception as e:
            # Tool boundary: report the failure as text instead of raising.
            return f"Error during arXiv search: {type(e).__name__}: {e}"

    def _format_result_to_natural_language(self, raw_result: str, task_description: str) -> str:
        """
        Converts the raw JSON string output from _execute_tool into a human-readable string.

        Args:
            raw_result: The string returned by `_execute_tool` (a JSON list of paper
                dictionaries, or a plain error message).
            task_description: The task text, echoed back in the formatted output.

        Returns:
            A multi-line description of the papers, a "no papers found" message for an
            empty list, or a "failed to interpret" message when `raw_result` is not a
            JSON list.
        """
        failure_message = (
            f"Failed to interpret search results for '{task_description}'. Raw output: {raw_result}"
        )
        try:
            parsed_results = json.loads(raw_result)
        except (TypeError, ValueError):
            # Not valid JSON (json.JSONDecodeError is a ValueError) or not a string
            # at all: most likely an error message from _execute_tool.
            return failure_message

        if not isinstance(parsed_results, list):
            # Valid JSON but not the expected list of papers (e.g. a bare number).
            return failure_message

        if not parsed_results:
            return f"No papers found on arXiv for query: '{task_description}'."

        formatted_output_lines = [f"arXiv search for '{task_description}' found {len(parsed_results)} papers:"]
        for i, paper in enumerate(parsed_results, start=1):
            # ``or`` (rather than a .get default) also covers keys that are present
            # but null, e.g. a result without a PDF link.
            authors = paper.get('authors') or ['N/A']
            summary = paper.get('summary') or 'N/A'
            published_date_str = paper.get('published_date') or 'N/A'

            # Clean up and truncate the summary for better readability
            clean_summary = summary.replace('\n', ' ').strip()
            if len(clean_summary) > self._SUMMARY_LIMIT:
                clean_summary = clean_summary[:self._SUMMARY_LIMIT].strip() + "..."

            if published_date_str != 'N/A':
                # ISO timestamps look like "2024-01-31T12:00:00+00:00"; keep the date part.
                published_date_str = published_date_str.split('T')[0]

            formatted_output_lines.extend([
                f"\n--- Paper {i} ---",
                f"Title: {paper.get('title') or 'N/A'}",
                f"Authors: {', '.join(str(name) for name in authors)}",
                f"Summary: {clean_summary}",
                f"Published: {published_date_str}",
                f"PDF Link: {paper.get('pdf_url') or 'N/A'}",
                f"ArXiv Page: {paper.get('arxiv_url') or 'N/A'}",
            ])

        return "\n".join(formatted_output_lines)