"""Problem evaluator that runs agents on specific problems."""

import os
import time
import logging
from typing import Dict, Any, Optional, Tuple

from .agent_factory import AgentFactory, AgentType
from ..config import UnifiedBatchConfig, EvaluationTask, EvaluationResult
from ..search import SearchEngine, SearchTools
from ..logging import JSONLogger, MetricsCollector

# Import for LLM-based answer parsing
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage


# Default answer-extraction strategy, used by ProblemEvaluator._extract_final_answer
# when a task's agent_stop_kwargs carries no 'answer_judge' entry.
# 'parsing_only' and 'llm_only' are matched explicitly; any other value is
# handled as the combined 'parsing_and_llm' mode.
ANSWER_JUDGE_METHOD = 'llm_only'

class ProblemEvaluator:
    """Evaluates problems using configured agents."""
    
    def __init__(
        self,
        config: UnifiedBatchConfig,
        search_engine: SearchEngine,
        logger: Optional[logging.Logger] = None
    ):
        """Initialize problem evaluator.
        
        Args:
            config: Batch configuration
            search_engine: Search engine instance
            logger: Logger instance
        """
        self.config = config
        self.search_engine = search_engine
        self.logger = logger or logging.getLogger(__name__)
        
        # Initialize components
        self.json_logger = JSONLogger(
            output_dir=config.evaluation.output_dir,
            experiment_name=config.experiment_name
        )
        self.metrics_collector = MetricsCollector()
    
    def evaluate_task(self, task: EvaluationTask) -> EvaluationResult:
        """Evaluate a single task.
        
        Args:
            task: Evaluation task to process
            
        Returns:
            Evaluation result
        """
        start_time = time.time()
        self.logger.info(f"Starting evaluation for task: {task.task_id}")
        
        try:
            # Extract problem data
            problem = task.problem_data
            question = problem.get('question', '')
            ground_truth = problem.get('final_answer', '')
            problem_id = problem.get('question_id', task.task_id)
            
            # Get agent type from task config (default to REACT)
            agent_type_str = task.metadata['setting'].agent_type
            
            # Load ground truth document IDs
            ground_truth_ids = SearchTools.load_ground_truth_ids(
                problem
            )
            
            # Create search engines (multiple for deep research, single for react)
            if agent_type_str.lower() == 'deep_research':
                # Get max_concurrent_research_units from deep_research_config
                try:
                    evaluation_params = task.metadata['setting'].__dict__
                    dr_config = evaluation_params['deep_research_config']
                    max_concurrent = dr_config['max_concurrent_research_units']
                except (KeyError, IndexError, AttributeError) as e:
                    raise ValueError(f"Failed to access deep research config from task: {e}")
                
                if max_concurrent <= 0:
                    raise ValueError(f"max_concurrent_research_units must be positive, got {max_concurrent}")
                
                # Create multiple search engines with unique IDs
                if not self.search_engine:
                    raise ValueError("Primary search engine is None, cannot create multiple engines")
                
                search_engines = []
                for i in range(max_concurrent):
                    try:
                        engine = SearchEngine(
                            chromadb_path=self.search_engine.chromadb_path,
                            collection_name=self.search_engine.collection_name,
                            embedding_model=self.search_engine.embedding_model,
                            cache_size=self.search_engine.cache_size,
                            engine_id=f"engine_{i}",
                            logger=self.logger
                        )
                        search_engines.append(engine)
                    except Exception as e:
                        raise ValueError(f"Failed to create search engine {i}: {e}") from e
                
                # Create search tools for multiple engines
                search_tools, search_checker = SearchTools.create_search_tools(
                    search_engines,
                    problem_id,
                    ground_truth_ids,
                    self.logger
                )
            else:
                # Determine if backtracking is needed
                with_backtracking = agent_type_str.lower() == 'react_with_backtracking'
                # Determine if think tool is needed
                with_think = agent_type_str.lower() == 'react_with_think'
                # Determine if explore tool is needed
                with_explore = agent_type_str.lower() in ['react_with_explore', 'react_with_explore_revisit']
                # For explore_revisit, we need both explore and revisit tools
                if agent_type_str.lower() == 'react_with_explore_revisit':
                    with_backtracking = True  # Enable revisit tool as well
                
                # Create single search engine for react agent
                search_tools, search_checker = SearchTools.create_search_tools(
                    self.search_engine,
                    problem_id,
                    ground_truth_ids,
                    self.logger,
                    with_backtracking=with_backtracking,
                    with_think=with_think,
                    with_explore=with_explore
                )
            try:
                agent_type = AgentType(agent_type_str.lower())
            except ValueError:
                raise ValueError(f"Unknown agent type '{agent_type_str}'")
            
            # Get system prompt
            system_prompt = self.config.prompts.system_prompt
            if not system_prompt and agent_type != AgentType.DEEP_RESEARCH:
                raise ValueError("System prompt is required for REACT agent")

            # Get agent stop configuration from task metadata
            setting = task.metadata.get('setting')
            agent_stop_type = getattr(setting, 'agent_stop_type', 'default') if setting else 'default'
            agent_stop_kwargs = getattr(setting, 'agent_stop_kwargs', {}) if setting else {}

            # Create agent
            if agent_type_str.lower() == 'deep_research':
                # For deep research, pass the list of search engines
                dr_config = task.metadata['setting'].deep_research_config
                agent = AgentFactory.create_agent(
                    model_config=task.model_config.__dict__,
                    system_prompt=system_prompt,
                    search_tools=search_tools,
                    agent_type=agent_type,
                    search_engine=self.search_engine,
                    search_checker=search_checker,
                    search_engines=search_engines,
                    logger=self.logger,
                    dr_config=dr_config,
                    agent_stop_type=agent_stop_type,
                    agent_stop_kwargs=agent_stop_kwargs
                )
            else:
                # For react agent, no need for multiple engines
                agent = AgentFactory.create_agent(
                    model_config=task.model_config.__dict__,
                    system_prompt=system_prompt,
                    search_tools=search_tools,
                    agent_type=agent_type,
                    search_engine=self.search_engine,
                    search_checker=search_checker,
                    logger=self.logger,
                    agent_stop_type=agent_stop_type,
                    agent_stop_kwargs=agent_stop_kwargs
                )

            

            # Run agent based on type
            self.logger.debug(f"Running {agent_type.value} agent for problem: {problem_id}")
            
            if agent_type == AgentType.DEEP_RESEARCH:
                # Deep research workflow takes a question string directly
                import asyncio
                response = asyncio.run(agent.run(question))
                
                # Convert to standard response format for compatibility
                final_report = response.get("final_report", "No report generated")
                response = {
                    "messages": [{"content": final_report,
                                  "role": "assistant"}],
                    # Expose intermediates for downstream saving
                    "deep_research": {
                        "final_report_prompt": response.get("final_report_prompt"),
                        "research_topics_by_round": response.get("research_topics_by_round", [])
                    }
                }
            else:
                # Standard React agent invocation
                # import pdb; pdb.set_trace()
                input_message = {"role": "user", "content": question}
                # Set the recursion limit \approx # nodes in the graph * the interaction rounds, we use 7 as a safe upper bound
                recursion_limit = agent_stop_kwargs.get("recursion_limit", 50) * 7
                response = agent.invoke(
                    {"messages": [input_message]},
                    {"recursion_limit": recursion_limit}
                )
            
            # Process response
            # import pdb; pdb.set_trace()
            final_answer, follow_format = self._extract_final_answer(response, task)
            is_correct = self._check_answer(final_answer, ground_truth)
            # Calculate metrics
            tool_calls = self._count_tool_calls(response)
            total_tokens = self._count_tokens(response)

            
            
            # Log evaluation (this now includes full conversation)
            log_data = self.json_logger.log_evaluation(
                task_id=task.task_id,
                model_name=task.model_config.name,
                problem_data=problem,
                response=response,
                metrics={
                    'is_correct': is_correct,
                    'follow_format': follow_format,
                    'tool_calls': tool_calls,
                    'total_tokens': total_tokens,
                    'search_summary': search_checker.get_summary(),
                    'detailed_search_tracking': search_checker.get_detailed_tracking_data()
                },
                config=task.model_config.__dict__
            )
            
            # Record metrics
            duration = time.time() - start_time
            self.metrics_collector.record_evaluation_metrics(
                task_id=task.task_id,
                model_name=task.model_config.name,
                is_correct=is_correct,
                duration=duration,
                tokens_used=total_tokens,
                tool_calls=tool_calls,
                follow_format=follow_format
            )
            
            # Create result (include full conversation from log_data)
            result = EvaluationResult(
                task_id=task.task_id,
                problem_index=task.problem_index,
                model_name=task.model_config.name,
                success=True,
                answer=final_answer,
                ground_truth=ground_truth,
                is_correct=is_correct,
                metrics={
                    'follow_format': follow_format,
                    'tool_calls': tool_calls,
                    'total_tokens': total_tokens,
                    'search_complete': search_checker.is_complete(),
                    'search_rounds': len(search_checker.search_rounds),
                    'full_conversation': log_data['response'].get('full_conversation', []),  # Include full conversation
                    'detailed_search_tracking': search_checker.get_detailed_tracking_data(),  # Include detailed search tracking
                    # Surface deep research intermediates (if any) into metrics for saving
                    'deep_research': response.get('deep_research', {}) if agent_type == AgentType.DEEP_RESEARCH else {}
                },
                duration=duration,
                attempt=task.attempt
            )
            
            self.logger.info(f"Completed task {task.task_id}: correct={is_correct}")
            return result
            
        except Exception as e:
            self.logger.error(f"Error evaluating task {task.task_id}: {e}")
            
            return EvaluationResult(
                task_id=task.task_id,
                problem_index=task.problem_index,
                model_name=task.model_config.name,
                success=False,
                error=str(e),
                ground_truth=task.problem_data.get('final_answer'),
                duration=time.time() - start_time,
                attempt=task.attempt
            )
    
    def _extract_final_answer(self, response: Dict[str, Any], task: EvaluationTask = None) -> Optional[str]:
        """Extract the final answer from agent response using configured judge method.
        
        Args:
            response: Agent response dictionary
            task: Evaluation task (to get answer judge configuration)
            
        Returns:
            Extracted answer or None
        """
        
        # Get answer judge method from task configuration
        answer_judge = ANSWER_JUDGE_METHOD  # default
        if task and task.metadata.get('setting'):
            setting = task.metadata['setting']
            agent_stop_kwargs = getattr(setting, 'agent_stop_kwargs', {})
            answer_judge = agent_stop_kwargs.get('answer_judge', ANSWER_JUDGE_METHOD)
        
        if answer_judge == 'parsing_only':
            return self._parse_final_answer(response)
        elif answer_judge == 'llm_only':
            return self._llm_extract_final_answer(response)
        else:  # parsing_and_llm
            parsed = self._parse_final_answer(response)
            if parsed:
                return parsed, True
            return self._llm_extract_final_answer(response)
    
    def _parse_final_answer(self, response: Dict[str, Any]) -> Optional[str]:
        """Parse final answer looking for #### pattern."""
        messages = response.get('messages', [])
        
        for message in reversed(messages):
            if hasattr(message, 'content') or 'content' in message:
                content = message.content if hasattr(message, 'content') else message['content']
                if isinstance(content, str) and '####' in content:
                    # Extract answer after ####
                    parts = content.split('####')
                    if len(parts) > 1:
                        answer = parts[-1].strip()
                        return answer, True if answer else None, True
                    else:
                        self.logger.warning("No #### found in the response.")
                        return content, False
        
        return None, False
    
    def _llm_extract_final_answer(self, response: Dict[str, Any]) -> Optional[str]:
        """Use LLM to extract final answer from response."""
        try:
            # Create a simple extraction tool
            @tool
            def extract_final_answer(final_answer: str, follow_format: bool) -> dict:
                """Extract the final answer from the agent's response.
                
                Args:
                    final_answer: The final answer (including "I don't know" if applicable)
                    follow_format: Whether the answer is properly placed after #### or ****
                """
                return {"final_answer": final_answer, "follow_format": follow_format}
            
            # Initialize extraction LLM
            extraction_llm = ChatOpenAI(
                model="gpt-5-mini",
                temperature=0.1,
                api_key=os.getenv("OPENAI_API_KEY")
            ).bind_tools([extract_final_answer])
            
            # Get the last message for context
            messages = response.get('messages', [])
            if not messages:
                return None, False
                
            last_msg = messages[-1]
            content = ""
            if hasattr(last_msg, 'content'):
                content = last_msg.content
            elif isinstance(last_msg, dict) and 'content' in last_msg:
                content = last_msg['content']
            else:
                return None, False
            
            extraction_prompt = f"""Look at this agent's response and extract the final answer.

Response:
{content}

Please:
1. Check if the final answer is properly formatted after ####
2. Extract the final answer regardless of formatting
3. If the answer is "I don't know", simply return "I don't know"

Use the extract_final_answer tool with:
- final_answer: the actual answer (or "I don't know")
- follow_format: True if answer appears after ####, False otherwise"""
            
            extraction_response = extraction_llm.invoke([HumanMessage(content=extraction_prompt)])
            
            if extraction_response.tool_calls:
                tool_call = extraction_response.tool_calls[0]
                if tool_call['name'] == 'extract_final_answer':
                    result = tool_call['args']
                    final_answer = result.get('final_answer')
                    follow_format = result.get('follow_format', False)
                    
                    # Log format compliance for debugging
                    self.logger.debug(f"Answer extraction - follow_format: {follow_format}, answer: {final_answer}")
                    
                    return final_answer, follow_format
            
            return None, False
            
        except Exception as e:
            self.logger.warning(f"LLM-based answer extraction failed: {e}")
            return None, False
    
    def _check_answer(self, answer: Optional[str], ground_truth: str) -> bool:
        """Check if the answer matches ground truth using LLM-based comparison.
        
        Args:
            answer: Extracted answer
            ground_truth: Ground truth answer
            
        Returns:
            True if answers match according to GPT-4o-mini
        """
        
        if not answer:
            return False
        
        try:
            return self._llm_based_answer_check(answer, ground_truth)
        except Exception as e:
            self.logger.warning(f"LLM-based answer checking failed: {e}.")
            raise e
    
    def _llm_based_answer_check(self, answer: str, ground_truth: str) -> bool:
        """Use GPT-5-mini to compare answers with tool-based response.
        
        Args:
            answer: The extracted answer from the agent
            ground_truth: The correct ground truth answer
            
        Returns:
            True if the LLM determines the answers are equivalent
        """
        # Create the answer comparison tool
        @tool
        def answer_comparison_result(is_correct: bool, reasoning: str) -> dict:
            """Report whether the given answer is correct compared to the ground truth.
            
            Args:
                is_correct: True if the given answer is mathematically equivalent to the ground truth
                reasoning: Brief explanation of the comparison
            """
            return {"is_correct": is_correct, "reasoning": reasoning}
        
        # Initialize GPT-5-nano
        llm = ChatOpenAI(
            model="gpt-5-mini",
            temperature=0.1,
            api_key=os.getenv("OPENAI_API_KEY")
        ).bind_tools([answer_comparison_result])
        
        # Create system prompt
        system_prompt = """You are a mathematical answer evaluator. Your task is to determine if a given answer is mathematically equivalent to the ground truth answer, even if they are formatted differently. The given answer may contain the reasoning process. You need to identify the final answer and compare it to the ground truth answer.

Consider the following when comparing:
- Different number formats (1000 vs 1,000 vs 1000.0)
- Different units (if both have same units or no units)
- Mathematical equivalence (0.5 vs 1/2 vs 50%)
- Rounding differences (if the difference is due to reasonable rounding)

You must use the answer_comparison_result tool to report your decision."""
        
        # Create human prompt
        human_prompt = f"""Compare these two answers:

Given Answer: {answer}
Ground Truth: {ground_truth}

Are they mathematically equivalent? Use the answer_comparison_result tool to report your decision."""
        
        # Get LLM response
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=human_prompt)
        ]
        
        response = llm.invoke(messages)
        # Extract the tool call result
        if response.tool_calls:
            tool_call = response.tool_calls[0]
            if tool_call['name'] == 'answer_comparison_result':
                result = tool_call['args']
                self.logger.debug(f"LLM answer comparison: {result.get('reasoning', 'No reasoning provided')}")
                return result.get('is_correct', False)
        
        # If no tool call was made, fallback to False
        self.logger.warning("LLM did not use the answer_comparison_result tool. Defaulting to False.")
        return False
    
    def _exact_match_check(self, answer: str, ground_truth: str) -> bool:
        """Fallback exact match checking method.
        
        Args:
            answer: Extracted answer
            ground_truth: Ground truth answer
            
        Returns:
            True if answers match exactly (with basic normalization)
        """
        # Normalize answers
        answer = str(answer).strip().replace(',', '').replace('$', '')
        ground_truth = str(ground_truth).strip().replace(',', '').replace('$', '')
        
        # Try exact match first
        if answer == ground_truth:
            return True
        
        # Try numeric comparison
        try:
            answer_num = float(answer)
            truth_num = float(ground_truth)
            return abs(answer_num - truth_num) < 0.01
        except (ValueError, TypeError):
            pass
        
        # Handle percentage format
        if '%' in answer and '%' in ground_truth:
            try:
                answer_pct = float(answer.replace('%', ''))
                truth_pct = float(ground_truth.replace('%', ''))
                return abs(answer_pct - truth_pct) < 0.01
            except (ValueError, TypeError):
                pass
        
        return False
    
    def _count_tool_calls(self, response: Dict[str, Any]) -> int:
        """Count the number of tool calls in the response.
        
        Args:
            response: Agent response dictionary
            
        Returns:
            Number of tool calls
        """
        count = 0
        messages = response.get('messages', [])
        
        for message in messages:
            if hasattr(message, 'tool_calls') and message.tool_calls:
                count += len(message.tool_calls)
        
        return count
    
    def _count_tokens(self, response: Dict[str, Any]) -> int:
        """Count total tokens used in the response.
        
        Args:
            response: Agent response dictionary
            
        Returns:
            Total token count
        """
        total_tokens = 0
        messages = response.get('messages', [])
        
        for message in messages:
            if hasattr(message, 'response_metadata'):
                metadata = message.response_metadata
                if metadata and 'usage' in metadata:
                    usage = metadata['usage']
                    total_tokens += usage.get('input_tokens', 0)
                    total_tokens += usage.get('output_tokens', 0)
        
        return total_tokens
    
    def get_metrics_summary(self) -> Dict[str, Any]:
        """Get summary of evaluation metrics.
        
        Returns:
            Dictionary with metrics summary
        """
        return self.metrics_collector.get_summary()
