"""Refactored tools for agent interaction with search engine."""

import logging
from typing import List, Dict, Any, Optional, Set, Tuple, Callable
from langchain_core.tools import tool

from .search_engine import SearchEngine

# Simple think tool for strategic reflection
@tool
def think_tool(reflection: str) -> str:
    """Tool for recording thinking process.

    Use this tool to record your thinking process.

    Args:
        reflection: Your detailed thinking process.

    Returns:
        Confirmation that thinking process was recorded
    """
    # NOTE(review): with langchain's @tool the docstring above is presumably what
    # the agent sees as the tool description, so treat it as prompt text and
    # confirm before rewording it.
    # NOTE(review): SearchTools._create_single_engine_tools imports think_tool from
    # ..agents.deep_research_utils instead of using this definition — confirm
    # whether this copy is still needed.
    # Nothing is stored here: the reflection is only echoed back inside the
    # confirmation string that is returned to the caller.
    return f"Thinking process recorded: {reflection}"



class SearchChecker:
    """Tracks search progress and validates results against ground truth.

    Every call to :meth:`update` is recorded as one "search round" and indexed
    three ways: globally (``search_rounds``), per research round
    (``research_round_data``) and per ``(research_round, engine_id)`` pair
    (``engine_research_data``). Only IDs contained in ``ground_truth_ids``
    count towards completion.
    """

    def __init__(self, ground_truth_ids: Optional[Set[str]], logger: Optional[logging.Logger] = None):
        """Initialize search checker.

        Args:
            ground_truth_ids: Document IDs that should be found. ``None`` is
                treated as "no targets"; any other iterable is copied into a set.
            logger: Logger instance; defaults to this module's logger.
        """
        # Normalise to a real set: every other method does set arithmetic on this
        # attribute, so None (or a list) would otherwise raise on first use.
        self.ground_truth_ids: Set[str] = (
            set(ground_truth_ids) if ground_truth_ids is not None else set()
        )
        self.logger = logger or logging.getLogger(__name__)

        # Overall tracking state
        self.found_ids: Set[str] = set()
        self.search_rounds: List[Dict[str, Any]] = []
        self.first_complete_round: Optional[int] = None

        # Research round tracking (outer loop); 0 means "no research round started yet"
        self.current_research_round: int = 0
        self.research_round_data: Dict[int, Dict[str, Any]] = {}

        # Engine-specific tracking by (research_round, engine_id)
        self.engine_research_data: Dict[Tuple[int, str], Dict[str, Any]] = {}

    def _completion_rate(self, found_count: int) -> float:
        """Return ``found_count`` as a fraction of the targets (1.0 when there are no targets)."""
        total = len(self.ground_truth_ids)
        return found_count / total if total > 0 else 1.0

    def start_new_research_round(self) -> None:
        """Start a new research round (outer loop) and create its empty tracking record."""
        self.current_research_round += 1
        self.research_round_data[self.current_research_round] = {
            'round': self.current_research_round,
            'engines_used': set(),
            'queries': [],
            'found_ids_this_round': set(),
            'new_target_ids_this_round': set(),
            'search_rounds': []
        }
        self.logger.info(f"Started research round {self.current_research_round}")

    def update(self, query: str, found_ids: Set[str], round_type: str = "search", engine_id: str = "engine_0") -> None:
        """Update tracking with new search results.

        Args:
            query: The search query or page number
            found_ids: Document IDs found in this round. Any iterable is
                accepted; ``None`` is treated as an empty result.
            round_type: Type of round (e.g. search or next_page)
            engine_id: ID of the search engine
        """
        # Work on a private set so lists/None from an engine do not break the set
        # operations below and the caller's container is never aliased.
        found = set(found_ids) if found_ids is not None else set()

        # Target IDs seen for the first time in this round
        new_target_ids = (found & self.ground_truth_ids) - self.found_ids

        # Update overall state (only ground-truth IDs are accumulated here)
        self.found_ids.update(new_target_ids)

        # Fetch or create the record for this (research_round, engine_id) pair
        engine_key = (self.current_research_round, engine_id)
        if engine_key not in self.engine_research_data:
            self.engine_research_data[engine_key] = {
                'research_round': self.current_research_round,
                'engine_id': engine_id,
                'search_rounds': [],
                'found_ids': set(),
                'new_target_ids': set(),
                'queries': []
            }
        engine_data = self.engine_research_data[engine_key]

        # Local search round number for this (research_round, engine_id) pair
        local_search_round = len(engine_data['search_rounds']) + 1

        # One record per call; the same dict object is shared by all three indexes
        round_data = {
            'global_round': len(self.search_rounds) + 1,  # Keep global for overall tracking
            'local_round': local_search_round,  # Local to this (research_round, engine_id)
            'research_round': self.current_research_round,
            'engine_id': engine_id,
            'type': round_type,
            'query': query,
            'found_ids': list(found),
            'new_target_ids': list(new_target_ids),
            'total_target_found': len(self.found_ids)
        }
        self.search_rounds.append(round_data)

        # Update engine-specific data
        engine_data['search_rounds'].append(round_data)
        engine_data['found_ids'].update(found)
        engine_data['new_target_ids'].update(new_target_ids)
        engine_data['queries'].append(query)

        # Research round 0 has no record (nothing was started), so skip it
        if self.current_research_round > 0:
            research_data = self.research_round_data[self.current_research_round]
            research_data['engines_used'].add(engine_id)
            research_data['queries'].append(query)
            research_data['found_ids_this_round'].update(found)
            research_data['new_target_ids_this_round'].update(new_target_ids)
            research_data['search_rounds'].append(round_data)

        # Log search progress with local search round number
        total_targets = len(self.ground_truth_ids)
        current_found = len(self.found_ids)
        proportion_found = self._completion_rate(current_found)

        self.logger.info(f"Search Round {local_search_round} [Research Round {self.current_research_round}][{engine_id}] ({round_type}): "
                        f"Found {len(new_target_ids)} new target documents. "
                        f"Total: {current_found}/{total_targets} ({proportion_found:.1%})")

        # Remember the first global round at which every target had been found
        if self.first_complete_round is None and self.is_complete():
            self.first_complete_round = len(self.search_rounds)
            self.logger.info(f"All target documents found in global search round {self.first_complete_round}")

    def is_complete(self) -> bool:
        """Check if all ground truth documents have been found (True when there are no targets)."""
        return self.ground_truth_ids.issubset(self.found_ids)

    def get_missing_ids(self) -> Set[str]:
        """Get the set of ground truth IDs not yet found."""
        return self.ground_truth_ids - self.found_ids

    def get_summary(self) -> Dict[str, Any]:
        """Get summary of search progress.

        Returns:
            Dictionary with search statistics, including the cumulative
            proportion of targets found after each global round.
        """
        proportions_per_round = []
        cumulative_found = 0
        for round_data in self.search_rounds:
            cumulative_found += len(round_data['new_target_ids'])
            proportions_per_round.append({
                'round': round_data['global_round'],
                'cumulative_found': cumulative_found,
                'proportion': self._completion_rate(cumulative_found)
            })

        return {
            'total_rounds': len(self.search_rounds),
            'first_complete_round': self.first_complete_round,
            'target_documents': len(self.ground_truth_ids),
            'found_documents': len(self.found_ids),
            'missing_documents': len(self.get_missing_ids()),
            'completion_rate': self._completion_rate(len(self.found_ids)),
            'proportions_per_round': proportions_per_round,
            'rounds': self.search_rounds
        }

    def get_detailed_tracking_data(self) -> Dict[str, Any]:
        """Get detailed search tracking data for saving.

        Returns:
            Dictionary with detailed tracking information including queries and
            found IDs. Sets are converted to lists and the tuple keys of the
            per-engine data are flattened to ``"r<round>_e<engine_id>"`` strings
            so the result can be JSON-serialized.
        """
        # Convert sets to lists for JSON serialization
        research_round_data_serializable = {
            round_num: {
                'round': data['round'],
                'engines_used': list(data['engines_used']),
                'queries': data['queries'],
                'found_ids_this_round': list(data['found_ids_this_round']),
                'new_target_ids_this_round': list(data['new_target_ids_this_round']),
                'search_rounds': data['search_rounds']
            }
            for round_num, data in self.research_round_data.items()
        }

        # Tuple keys are not valid JSON object keys, so flatten them to strings
        engine_research_data_serializable = {
            f"r{research_round}_e{engine_id}": {
                'research_round': data['research_round'],
                'engine_id': data['engine_id'],
                'search_rounds': data['search_rounds'],
                'found_ids': list(data['found_ids']),
                'new_target_ids': list(data['new_target_ids']),
                'queries': data['queries']
            }
            for (research_round, engine_id), data in self.engine_research_data.items()
        }

        return {
            'ground_truth_ids': list(self.ground_truth_ids),
            'found_ids': list(self.found_ids),
            'missing_ids': list(self.get_missing_ids()),
            'search_rounds': self.search_rounds,
            'first_complete_round': self.first_complete_round,
            'completion_rate': self._completion_rate(len(self.found_ids)),
            'total_rounds': len(self.search_rounds),
            'target_documents': len(self.ground_truth_ids),
            'found_documents': len(self.found_ids),
            'current_research_round': self.current_research_round,
            'research_round_data': research_round_data_serializable,
            'engine_research_data': engine_research_data_serializable
        }

    def get_research_round_summary(self, research_round: int) -> Dict[str, Any]:
        """Get summary for a specific research round.

        Args:
            research_round: Research round number

        Returns:
            Summary data for the research round, or an empty dict if that
            round was never started.
        """
        data = self.research_round_data.get(research_round)
        if data is None:
            return {}

        return {
            'round': research_round,
            'engines_used': list(data['engines_used']),
            'total_queries': len(data['queries']),
            'found_ids_count': len(data['found_ids_this_round']),
            'new_target_ids_count': len(data['new_target_ids_this_round']),
            'search_rounds_count': len(data['search_rounds'])
        }

    def get_engine_research_summary(self, research_round: int, engine_id: str) -> Dict[str, Any]:
        """Get summary for a specific (research_round, engine_id) pair.

        Args:
            research_round: Research round number
            engine_id: Engine ID

        Returns:
            Summary data for the engine in that research round, or an empty
            dict if the pair has no recorded searches.
        """
        data = self.engine_research_data.get((research_round, engine_id))
        if data is None:
            return {}

        return {
            'research_round': research_round,
            'engine_id': engine_id,
            'total_queries': len(data['queries']),
            'found_ids_count': len(data['found_ids']),
            'new_target_ids_count': len(data['new_target_ids']),
            'search_rounds_count': len(data['search_rounds'])
        }

    def get_formatted_summary(self) -> str:
        """Get formatted summary string.

        Returns:
            Human-readable summary of search progress; only the first 10
            rounds are listed individually.
        """
        summary = self.get_summary()

        lines = [
            "Search Summary:",
            f"  Total rounds: {summary['total_rounds']}",
            f"  Target documents: {summary['target_documents']}",
            f"  Found documents: {summary['found_documents']}",
            f"  Completion rate: {summary['completion_rate']:.1%}",
        ]

        if summary['first_complete_round'] is not None:
            lines.append(f"  All targets found in round: {summary['first_complete_round']}")

        if summary['missing_documents'] > 0:
            lines.append(f"  Missing documents: {summary['missing_documents']}")

        # Add round details
        lines.append("\nRound Details:")
        max_listed = 10
        for round_data in summary['rounds'][:max_listed]:
            lines.append(
                "  Round {} ({}): Query='{}...', New targets found: {}".format(
                    round_data['global_round'],
                    round_data['type'],
                    # str() first: update() documents that query may be a page number
                    str(round_data['query'])[:50],
                    len(round_data['new_target_ids'])
                )
            )

        hidden = len(summary['rounds']) - max_listed
        if hidden > 0:
            lines.append(f"  ... and {hidden} more rounds")

        return "\n".join(lines)


class SearchTools:
    """Factory for creating search tools for agents."""

    @staticmethod
    def create_search_tools(
        search_engine: SearchEngine|List[SearchEngine],
        problem_id: str,
        ground_truth_ids: Optional[Set[str]] = None,
        logger: Optional[logging.Logger] = None,
        with_backtracking: bool = False,
        with_think: bool = False,
        with_explore: bool = False,
    ) -> Tuple[List[Callable]|List[List[Callable]], SearchChecker]:
        """Create search tools and checker for a problem.

        Args:
            search_engine: SearchEngine instance or list of SearchEngine instances
            problem_id: ID of the problem being evaluated
            ground_truth_ids: Set of document IDs that should be found; ``None``
                is treated as an empty set (nothing to find)
            logger: Logger instance
            with_backtracking: Whether to include the backtracking tool
            with_think: Whether to include the think tool
            with_explore: Whether to include the explore tool
        Returns:
            For single engine: Tuple of (list of tool functions, SearchChecker instance)
            For multiple engines: Tuple of (list of search tool groups, SearchChecker instance)

        Raises:
            ValueError: If ``search_engine`` is None or an empty list, if
                ``problem_id`` is empty, or if an engine has no ``search`` attribute.
        """
        logger = logger or logging.getLogger(__name__)

        # Validate inputs
        if search_engine is None:
            raise ValueError("search_engine cannot be None")
        if not problem_id:
            raise ValueError("problem_id cannot be empty")

        # SearchChecker performs set arithmetic on the ground truth, so a missing
        # ground truth must reach it as an empty set rather than None.
        if ground_truth_ids is None:
            ground_truth_ids = set()

        # One checker is shared across all engines so progress is tracked jointly
        checker = SearchChecker(ground_truth_ids, logger)

        # Handle single search engine case
        if not isinstance(search_engine, list):
            if not hasattr(search_engine, 'search'):
                raise ValueError("search_engine must have a 'search' method")
            return SearchTools._create_single_engine_tools(
                search_engine, problem_id, checker, logger,
                with_backtracking=with_backtracking,
                with_think=with_think,
                with_explore=with_explore,
            )

        # Handle multiple search engines case
        if not search_engine:
            raise ValueError("search_engine list cannot be empty")

        search_tool_groups = []
        for i, engine in enumerate(search_engine):
            if not hasattr(engine, 'search'):
                raise ValueError(f"search_engine[{i}] must have a 'search' method")
            # The returned checker is the shared one passed in, so it is ignored here
            tools, _ = SearchTools._create_single_engine_tools(
                engine, problem_id, checker, logger,
                with_backtracking=with_backtracking,
                with_think=with_think,
                with_explore=with_explore,
            )
            search_tool_groups.append(tools)

        return search_tool_groups, checker

    @staticmethod
    def _create_single_engine_tools(
        search_engine: SearchEngine,
        problem_id: str,
        checker: SearchChecker,
        logger: logging.Logger,
        with_backtracking: bool = False,
        with_think: bool = False,
        with_explore: bool = False
    ) -> Tuple[List[Callable], SearchChecker]:
        """Create search tools for a single search engine.

        The tools are closures over ``search_engine`` and ``checker``; every
        tool call that fetches a page reports that page's IDs to the checker.

        Args:
            search_engine: SearchEngine instance
            problem_id: Problem ID for logging
            checker: SearchChecker instance
            logger: Logger instance
            with_backtracking: Whether to include the backtracking tool
            with_think: Whether to include the think tool
            with_explore: Whether to include the explore tool

        Returns:
            Tuple of (list of tool functions, SearchChecker instance)
        """

        def _search_first_page(query: str, tracked_query: str, round_type: str) -> str:
            """Run ``query``, record the first page in the checker, return the formatted page."""
            # Execute search - let exceptions propagate for debugging. The return
            # value is unpacked but unused: only the first page is shown and tracked.
            _results, _all_ids = search_engine.search(query)

            # Get first page of results
            formatted_results, page_ids = search_engine.get_first_page()

            # Update checker with the IDs actually shown to the agent
            checker.update(tracked_query, page_ids, round_type, search_engine.engine_id)

            return formatted_results

        # NOTE(review): with langchain's @tool the docstrings of the functions below
        # are presumably what the agent sees as tool descriptions, so they are
        # prompt text — confirm before rewording them.

        @tool
        def search_information(query: str) -> str:
            """Search for information based on a natural language query.

            Returns the most relevant documents from the database.

            Args:
                query: Natural language search query

            Returns:
                Formatted search results with document IDs and content
            """
            logger.info(f"[{problem_id}][{search_engine.engine_id}] Searching for: {query}")
            return _search_first_page(query, query, "search")

        @tool
        def next_page() -> str:
            """Get the next page of search results from the last query.

            Returns:
                Next page of search results or message if no more pages
            """
            logger.info(f"[{problem_id}][{search_engine.engine_id}] Getting next page")

            # Get next page - let exceptions propagate for debugging
            formatted_results, page_ids = search_engine.get_next_page()

            # Update checker; the literal "next_page" stands in for the query text
            checker.update("next_page", page_ids, "next_page", search_engine.engine_id)

            return formatted_results

        # Create basic tools list
        tools = [search_information, next_page]

        # Add backtracking tool if requested
        if with_backtracking:
            @tool
            def revisit(research_topic: str, reasoning: str, new_query: str) -> str:
                """Revisit a previous search topic and conduct a new search with improved query.

                Use this tool when you realize you need to revisit a previous search topic with a
                better or different search strategy based on new insights.

                Args:
                    research_topic: The search topic you want to revisit to and explore further
                    reasoning: Explanation of why you need to revisit this topic and what insights 
                             led to this decision. Describe what query changes you want to make.
                    new_query: The new, improved search query to use for this search topic

                Returns:
                    Formatted search results from the new query
                """
                logger.info(f"[{problem_id}][{search_engine.engine_id}] Backtracking to topic: '{research_topic}'")
                logger.info(f"[{problem_id}][{search_engine.engine_id}] Reasoning: {reasoning}")
                logger.info(f"[{problem_id}][{search_engine.engine_id}] New query: {new_query}")

                # The checker records the topic alongside the query for later analysis
                backtrack_query = f"BACKTRACK[{research_topic}]: {new_query}"
                return _search_first_page(new_query, backtrack_query, "backtrack")

            tools.append(revisit)

        # Add explore tool if requested
        if with_explore:
            @tool
            def explore(new_explore_topic: str, reasoning: str, query: str) -> str:
                """Explore a completely new search topic that hasn't been investigated yet.

                Use this tool when you need to explore a new angle, concept, or area that you
                haven't searched for previously. This is different from revisiting - it's for
                expanding into entirely new territories.

                Args:
                    new_explore_topic: The new search topic you want to explore
                    reasoning: Explanation of why you need to explore this new topic and what 
                             insights you hope to gain. Describe how this connects to the problem.
                    query: The search query to use for exploring this new search topic

                Returns:
                    Formatted search results for the new exploration
                """
                logger.info(f"[{problem_id}][{search_engine.engine_id}] Exploring new topic: '{new_explore_topic}'")
                logger.info(f"[{problem_id}][{search_engine.engine_id}] Reasoning: {reasoning}")
                logger.info(f"[{problem_id}][{search_engine.engine_id}] Query: {query}")

                # The checker records the topic alongside the query for later analysis
                explore_query = f"EXPLORE[{new_explore_topic}]: {query}"
                return _search_first_page(query, explore_query, "explore")

            tools.append(explore)

        # Add think tool if requested
        if with_think:
            # Imported lazily so the agents package is only needed when the tool is requested.
            # NOTE(review): this shadows the module-level think_tool defined in this
            # file — confirm the deep_research_utils version is the intended one.
            from ..agents.deep_research_utils import think_tool
            tools.append(think_tool)
            logger.info(f"✅ Added think_tool to agent. Total tools: {len(tools)}")
            # Loop variable is not called `tool` to avoid shadowing the decorator import
            logger.info(f"Tool names: {[t.name for t in tools]}")

        return tools, checker

    @staticmethod
    def clear_search_engine_caches(search_engines: List[Any]):
        """Clear caches for multiple search engines.

        Resets ``current_search_state`` to ``None`` on every engine that has
        that attribute; objects without it are left untouched.

        Args:
            search_engines: List of SearchEngine instances
        """
        for engine in search_engines:
            if hasattr(engine, 'current_search_state'):
                engine.current_search_state = None

    @staticmethod
    def load_ground_truth_ids(
        problem: dict[str, Any]
    ) -> Set[str]:
        """Load ground truth document IDs for a problem.

        Args:
            problem: Problem data; must provide a list under ``'document_ids'``

        Returns:
            Set of document IDs that should be found for this problem

        Raises:
            ValueError: If ``problem['document_ids']`` cannot be read.
            AssertionError: If the stored value is not a list.
        """
        try:
            mapping = problem['document_ids']
        except Exception as e:
            # Chain the original error so the real cause stays visible in tracebacks
            raise ValueError(f"Cannot load ground truth document IDs, may using a different dataset format: {e}") from e

        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # which would let e.g. a string silently become a set of characters.
        if not isinstance(mapping, list):
            raise AssertionError("Ground truth document IDs must be a list")

        return set(mapping)
            
