"""
Thinking agent that coordinates other agents using MCTS and ReAct framework
"""

import math
import random
import re
import json
import numpy as np
from typing import Any, Dict, List, Optional, Tuple

from langchain.llms.base import BaseLLM

from agents.base_agent import BaseAgent
from agents.planner_agent import PlannerAgent
from CycleIE.agents.retriever_agent import RetrieverAgent
from CycleIE.agents.structurer_agent import StructurerAgent
from CycleIE.agents.extractor_agent import ExtractorAgent
from CycleIE.agents.reasoner_agent import ReasonerAgent
from core.state import DocumentQAState
from core.document_manager import DocumentManager
from utils.task_utils import (
    optimize_original_steps,
    information_verifier, 
    information_refiner,
    get_llm_response,
    extract_json_from_text
)
from utils.mcts import MCTSNode

class WorkflowControl(BaseAgent):
    """Core controller agent that coordinates all other agents using MCTS and ReAct framework.

    For every step of the plan produced by ``PlannerAgent`` the controller runs an
    action loop over ``possible_actions`` (retrieve, structure, extract, verify,
    re-extract, reason, refine). After each executed action a short Monte Carlo
    Tree Search over ``action_transitions`` picks the next action, guided by the
    heuristic rewards in ``_simulate_with_react``.
    """

    # Words ignored when mining search keywords from verifier feedback.
    _KEYWORD_STOPWORDS = frozenset(
        {"this", "that", "there", "should", "would", "could", "about", "which"}
    )
    # Maximum number of new keywords added per round of verifier feedback.
    _MAX_NEW_KEYWORDS = 3

    def __init__(self, state: DocumentQAState, doc_manager: DocumentManager, llm: Optional[BaseLLM] = None):
        super().__init__(state, llm)
        self.planner = PlannerAgent(state, llm)
        self.retriever = RetrieverAgent(state, doc_manager, llm)
        self.structurer = StructurerAgent(state, llm)
        self.extractor = ExtractorAgent(state, llm)
        self.reasoner = ReasonerAgent(state, llm)
        self.doc_manager = doc_manager
        self.verifier = None  # Initialized in execute()
        self.re_extractor = None  # Initialized in execute()
        self.result = None  # Final result of the most recent execute() call

        # MCTS parameters
        self.max_iterations = 10  # MCTS rollouts per next-action decision
        self.exploration_constant = 1.414  # ~sqrt(2), the usual UCB1 constant
        # Safety cap on the number of actions executed for a single plan step
        self.max_actions_per_step = 40

        # Define available actions and correct sequences
        self.possible_actions = ["retrieve", "structure", "extract", "verify", "re-extract", "reason", "refine"]
        # Transition model: which actions MCTS may consider after a given action
        self.action_transitions = {
            "retrieve": ["structure"],
            "structure": ["extract"],
            "extract": ["verify"],
            "verify": ["retrieve", "extract", "refine", "reason"],
            "re-extract": ["verify", "reason"],
            "reason": ["reason"],
            "refine": ["retrieve"]
        }

    def execute(self, query: str, documents: Optional[List[str]] = None, doc_mode: str = "paths") -> Dict:
        """Execute the complete document QA workflow with ReAct framework and MCTS optimization.

        Args:
            query: The user question; passed to the planner.
            documents: Documents forwarded to the planner, interpreted per ``doc_mode``.
            doc_mode: Forwarded to the planner and the retriever (default ``"paths"``).

        Returns:
            The answer stored for the last plan step or, if none was stored, the
            reasoner's output for that step. ``None`` if the plan has no steps.
            The value is also kept in ``self.result``.
        """
        # Bind verifier / re-extractor to the current llm and state
        self.verifier = lambda step_info: information_verifier(step_info, self.llm, self.state)
        self.re_extractor = lambda step_info: information_refiner(step_info, self.llm, self.state)

        # Initial planning - pass documents to the planner
        plan = self.planner.execute(query, documents=documents, doc_mode=doc_mode)
        plan["steps"] = optimize_original_steps(plan["steps"])
        self.state.steps = plan["steps"]

        # Index self.state.steps live: the "refine" action may replace entries in place.
        step_idx = 0
        while step_idx < len(self.state.steps):
            self._run_step_actions(step_idx, doc_mode)
            # Move on only after the current step has finished processing
            step_idx += 1

        return self._resolve_final_answer()

    def _run_step_actions(self, step_idx: int, doc_mode: str) -> None:
        """Run the ReAct/MCTS action loop for the plan step at 0-based index ``step_idx``."""
        step = self.state.steps[step_idx]
        self.state.current_step = step_idx + 1
        self.add_thought(f"**Step {step_idx + 1}: {step.get('description', '')}**")

        current_action = "retrieve" if step.get("requires_document", False) else "reason"

        for _ in range(self.max_actions_per_step):  # prevent infinite loops
            result = self._execute_action(current_action, step, doc_mode=doc_mode)

            # A "reason" action finishes the step unless the reasoner asks for more information.
            if current_action == "reason" and not (
                isinstance(result, dict) and result.get("needs_more_info", False)
            ):
                break

            if current_action == "refine":
                # The refine action may have replaced this step with an optimized copy;
                # keep operating on the live object rather than the stale local one.
                step = self.state.steps[step_idx]
                restart_action = "retrieve" if step.get("requires_document", False) else "reason"
                if isinstance(result, dict) and result.get("optimized"):
                    # Restart the optimized step's pipeline from the top. MCTS is skipped on
                    # purpose: it treats its argument as already executed and would jump
                    # straight past restart_action (e.g. retrieve -> structure).
                    current_action = restart_action
                    continue
                # Refine was a no-op: keep the original behaviour and let MCTS decide.
                current_action = restart_action

            # Use MCTS to choose the next action
            next_action = self._mcts_search(step, current_action)
            if next_action:
                current_action = next_action

    def _resolve_final_answer(self):
        """Return (and store in ``self.result``) the answer for the last plan step."""
        if not self.state.steps:
            self.add_thought("The plan contains no steps, so no final answer can be generated.")
            self.result = None
            return None

        # The last step should now be the final synthesis step
        last_step = self.state.steps[-1]
        # _execute_action passes the 1-based step position to add_step_answer; assume that
        # is the storage key and use it when the step carries no "step_number".
        answer_key = last_step.get("step_number", len(self.state.steps))
        final_result = self.state.step_answers.get(answer_key, None)

        # If we don't have a result from the last step, ask the reasoner directly
        if not final_result:
            self.add_thought("Result from the last step is not available, using alternative approach to generate the final answer...")
            final_result = self.reasoner.execute(
                step_info=last_step
            )

        self.result = final_result
        return final_result

    @staticmethod
    def _as_keyword_list(value) -> List[str]:
        """Normalize keywords given as None, a comma-separated string, or an iterable into a list of str."""
        if not value:
            return []
        if isinstance(value, str):
            return [part.strip() for part in value.split(",") if part.strip()]
        return [str(item) for item in value]

    def _update_search_keywords(self, step: Dict, feedback: Dict):
        """Add up to ``_MAX_NEW_KEYWORDS`` new search keywords mined from verifier feedback.

        Reads ``feedback["refinement_suggestions"]`` (a string or a list of strings),
        keeps words longer than three characters that are not stopwords, and appends
        the first ones not already present in ``step["search_keywords"]``
        (case-insensitive). Mutates ``step`` in place; order is preserved.
        """
        suggestions = feedback.get("refinement_suggestions") if isinstance(feedback, dict) else None
        if not suggestions:
            return
        self.add_thought(f"Update the retrieval keywords based on the feedback: {suggestions}")

        # re.findall needs a string; suggestions may arrive as a list.
        text = suggestions if isinstance(suggestions, str) else " ".join(str(s) for s in suggestions)
        candidates = [
            word for word in re.findall(r'\b\w+\b', text)
            if len(word) > 3 and word.lower() not in self._KEYWORD_STOPWORDS
        ]
        if not candidates:
            return

        existing = self._as_keyword_list(step.get("search_keywords"))
        seen = {kw.lower() for kw in existing}
        new_keywords = []
        for word in candidates:
            if len(new_keywords) >= self._MAX_NEW_KEYWORDS:
                break
            if word.lower() not in seen:
                seen.add(word.lower())
                new_keywords.append(word)

        # dict.fromkeys removes exact duplicates while keeping insertion order
        step["search_keywords"] = list(dict.fromkeys(existing + new_keywords))
        self.add_thought(f"update the retrieval keywords: {', '.join(step['search_keywords'])}")

    def _mcts_search(self, step: Dict, current_action: str) -> Optional[str]:
        """Choose the action to run after ``current_action`` (the action just executed).

        Fixed guards for attempt limits and retrieve/verify cycles are applied first;
        otherwise ``max_iterations`` MCTS rollouts are run from ``current_action`` and the
        most-visited root child is returned. Returns ``None`` if no move exists.
        """
        retrieval_attempts = step.get("retrieve_attempts", 0)
        verification_attempts = step.get("verification_attempts", 0)
        refine_attempts = step.get("refine_attempts", 0)
        in_cycle = retrieval_attempts > 1 and verification_attempts > 1

        # Get verification results for the current (1-based) step
        step_idx = self.state.current_step
        verification_result = self.state.verification_results.get(step_idx, {})
        verification_passed = verification_result.get("verification_passed", False)

        # Retrieval budget exhausted: move forward with whatever has been gathered
        if retrieval_attempts >= 2 and current_action == "retrieve":
            self.add_thought("Retrieval attempts reached maximum, not conducting further retrievals")
            if self.state.extracted_info.get(step_idx):
                return "reason" if verification_result else "verify"
            if self.state.retrieved_docs.get(step_idx):
                return "extract" if self.state.data_structure.get(step_idx) else "structure"
            return "reason"

        if refine_attempts >= 2 and current_action == "refine":
            self.add_thought("Replanning attempts reached maximum, switching to reasoning")
            return "reason"

        # Strategy intervention to break retrieve/verify loops
        if in_cycle and current_action in ("retrieve", "structure", "extract"):
            if verification_passed:
                return "reason"
            return "re-extract" if random.random() > 0.5 else "retrieve"

        root = MCTSNode(state=current_action, parent=None)
        for _ in range(self.max_iterations):
            leaf = self._select(root)
            child = self._expand_with_react(leaf)
            reward = self._simulate_with_react(child)
            self._backpropagate(child, reward)

        if not root.children:
            return None

        # Robust-child rule: pick the most-visited action
        best_action = max(root.children, key=lambda c: c.visits).state

        # Enforce the replanning cap on the chosen action as well, so MCTS cannot
        # schedule refine indefinitely.
        if best_action == "refine" and refine_attempts >= 2:
            self.add_thought("Replanning attempts reached maximum, switching to reasoning")
            return "reason"
        return best_action

    def _execute_action(self, current_action: str, step: Dict, doc_mode: str = "paths"):
        """Execute one action for ``step`` by dispatching to the matching sub-agent.

        Returns the sub-agent's result; ``{"needs_refine": True}`` after a retrieval that
        ``_should_refine`` flags; ``{"optimized": True}`` after a refine that replaced
        the step; or ``None`` when the action had nothing to work on.
        """
        if not step:
            return None

        step_idx = self.state.current_step
        result = None

        if current_action == "retrieve":
            step["retrieve_attempts"] = step.get("retrieve_attempts", 0) + 1
            result = self.retriever.execute(
                query=step.get("detail", ""),
                keywords=step.get("search_keywords", []),
                step_info=step,
                doc_mode=doc_mode
            )
            # NOTE(review): the action loop does not currently act on this flag.
            if self._should_refine(step):
                return {"needs_refine": True}

        elif current_action == "structure":
            doc_content = self.state.retrieved_docs.get(step_idx, "")
            if doc_content:
                result = self.structurer.execute(
                    document_content=doc_content,
                    step_info=step
                )

        elif current_action == "extract":
            structured_content = self.state.data_structure.get(step_idx, "")
            if structured_content:
                result = self.extractor.execute(
                    structured_content=structured_content,
                    step_info=step
                )

        elif current_action == "verify":
            step["verification_attempts"] = step.get("verification_attempts", 0) + 1
            extracted_info = self.state.extracted_info.get(step_idx, "")
            if extracted_info:
                step_info = {
                    "extracted_info": extracted_info,
                    "query": step.get("detail", ""),
                    "step_number": step_idx
                }
                result = self.verifier(step_info)

                # Feed failed-verification suggestions back into the search keywords
                if (isinstance(result, dict) and "verification_passed" in result
                        and not result["verification_passed"]):
                    self._update_search_keywords(step, result)

        elif current_action == "re-extract":
            extracted_info = self.state.extracted_info.get(step_idx, "")
            verification_result = self.state.verification_results.get(step_idx, {})
            if extracted_info and verification_result:
                step_info = {
                    "extracted_info": extracted_info,
                    "verification_result": verification_result,
                    "query": step.get("detail", ""),
                    "step_number": step_idx
                }
                result = self.re_extractor(step_info)

        elif current_action == "reason":
            result = self.reasoner.execute(
                step_info=step
            )
            # Only non-dict results are final answers; dicts signal e.g. needs_more_info
            if result and not isinstance(result, dict):
                self.state.add_step_answer(step_idx, result)

        elif current_action == "refine":
            step["refine_attempts"] = step.get("refine_attempts", 0) + 1
            new_step = self._optimize_step(step)
            if new_step and new_step != step:
                # current_step is 1-based, the steps list is 0-based
                self.state.steps[step_idx - 1] = new_step
                self.add_thought(f"Optimized step {step_idx}: {new_step.get('description', '')}")
                result = {"optimized": True}

        return result

    def _is_fully_expanded(self, node: MCTSNode) -> bool:
        """True when ``node`` already has a child for every action allowed after its action."""
        return len(node.children) >= len(self.action_transitions.get(node.state, []))

    def _select(self, node: MCTSNode) -> MCTSNode:
        """Selection phase of MCTS: descend by UCB1 through fully expanded nodes only.

        Stopping at the first node that still has untried actions ensures every
        root action gets expanded and evaluated before UCB1 starts exploiting.
        """
        while node.children and self._is_fully_expanded(node):
            unvisited = [child for child in node.children if child.visits == 0]
            if unvisited:
                return random.choice(unvisited)

            log_parent_visits = math.log(node.visits)
            node = max(
                node.children,
                key=lambda child: child.value / child.visits
                + self.exploration_constant * math.sqrt(log_parent_visits / child.visits),
            )
        return node

    def _expand_with_react(self, node: MCTSNode) -> MCTSNode:
        """Expansion phase of MCTS: add one untried child allowed by ``action_transitions``.

        Returns the new child, or ``node`` itself when nothing is left to expand.
        """
        valid_next_actions = self.action_transitions.get(node.state, [])
        if len(node.children) == len(valid_next_actions):
            return node

        tried_actions = [child.state for child in node.children]
        available_actions = [action for action in valid_next_actions if action not in tried_actions]
        if not available_actions:
            return node

        child = MCTSNode(state=random.choice(available_actions), parent=node)
        node.children.append(child)
        return child

    def _simulate_with_react(self, node: MCTSNode) -> float:
        """Simulation phase of MCTS: heuristic reward in [0.1, 1.0] for ``node``'s path.

        Only the first action below the root is scored, against the current step's
        progress (retrieval/verification attempts, extracted info, verification status).
        """
        step = self.state.steps[self.state.current_step - 1]

        # Base reward
        reward = 0.5

        # Walk up to the node directly below the root; that action is the one evaluated.
        first_move = node
        while first_move.parent is not None and first_move.parent.parent is not None:
            first_move = first_move.parent
        if first_move.parent is None:
            # node is the root itself: nothing to evaluate
            return reward
        next_action = first_move.state

        step_idx = self.state.current_step
        has_retrieved = step.get("retrieve_attempts", 0) > 0
        has_extracted = bool(self.state.extracted_info.get(step_idx))
        has_verified = bool(self.state.verification_results.get(step_idx))
        verification_passed = has_verified and self.state.verification_results[step_idx].get("verification_passed", False)

        # Not retrieved yet: reward retrieval
        if not has_retrieved and next_action == "retrieve":
            reward += 0.2

        # Documents but no extraction yet: reward the extraction path
        if has_retrieved and not has_extracted:
            if next_action == "structure":
                reward += 0.3
            elif next_action == "extract" and bool(self.state.data_structure.get(step_idx)):
                reward += 0.3

        # Extracted but not verified: reward verification
        if has_extracted and not has_verified and next_action == "verify":
            reward += 0.3

        # Verification failed: reward re-extraction or (limited) re-retrieval
        if has_verified and not verification_passed:
            if next_action == "re-extract":
                reward += 0.3
            elif next_action == "retrieve" and step.get("retrieve_attempts", 0) < 2:
                reward += 0.2

        # Verification passed: reward reasoning
        if verification_passed and next_action == "reason":
            reward += 0.4

        # Repeated retrievals: replanning may help
        if step.get("retrieve_attempts", 0) >= 2 and next_action == "refine":
            reward += 0.2

        # Penalize continuing a retrieve/verify loop, reward breaking out of it
        retrieval_attempts = step.get("retrieve_attempts", 0)
        verification_attempts = step.get("verification_attempts", 0)
        if retrieval_attempts > 1 and verification_attempts > 1:
            if next_action in ["retrieve", "verify"]:
                reward -= 0.2
            if next_action in ["reason", "refine"]:
                reward += 0.3

        # Clamp reward to [0.1, 1.0]
        return max(0.1, min(reward, 1.0))

    def _backpropagate(self, node: MCTSNode, reward: float):
        """Backpropagation phase of MCTS: add ``reward`` and one visit along the path to the root."""
        while node:
            node.visits += 1
            node.value += reward
            node = node.parent

    def _should_retrieve_more(self, step: Dict) -> bool:
        """Determine if more document retrieval is needed for ``step``."""
        step_idx = self.state.current_step

        verification_result = self.state.verification_results.get(step_idx, {})
        if verification_result:
            # Verification passed: no need for more retrieval
            if verification_result.get("verification_passed", False):
                return False
            # Verifier explicitly asked for more information
            if verification_result.get("needs_more_information", False):
                return True

        # Nothing retrieved yet
        if step.get("retrieve_attempts", 0) == 0:
            return True

        # Avoid retrieval loops
        if step.get("retrieve_attempts", 0) >= 2:
            self.add_thought("Already attempted retrieval multiple times, not retrieving more.")
            return False

        has_retrieved_docs = bool(self.state.retrieved_docs.get(step_idx))
        has_extracted_info = bool(self.state.extracted_info.get(step_idx))
        if not has_retrieved_docs and not has_extracted_info:
            return True

        if step.get("requires_more_context", False):
            return True

        # Default: don't retrieve more if we already have some information
        return False

    def _should_refine(self, step: Dict) -> bool:
        """Determine whether ``step`` should be refined (re-planned) based on retrieval results."""
        step_idx = self.state.current_step

        # Two or more retrievals yielded no documents: consider refining
        if step.get("retrieve_attempts", 0) >= 2:
            if not self.state.retrieved_docs.get(step_idx):
                self.add_thought("Multiple retrieval attempts yielded no useful documents. Considering refining.")
                return True

        # The step requires documents but none could be retrieved
        if step.get("requires_document", True) and not self.state.retrieved_docs.get(step_idx):
            return True

        # Repeated retrieval and verification without passing verification
        retrieval_attempts = step.get("retrieve_attempts", 0)
        verification_attempts = step.get("verification_attempts", 0)
        if retrieval_attempts > 1 and verification_attempts > 1:
            verification_result = self.state.verification_results.get(step_idx, {})
            if not verification_result.get("verification_passed", False):
                return True

        return False

    def _optimize_step(self, step: Dict) -> Optional[Dict]:
        """Return an optimized copy of ``step``, or ``None`` when ``_should_refine`` says no.

        If verification feedback exists, an LLM is asked for a sharper description,
        detail and keywords. Retrieval/verification counters are reset on the copy,
        and ``requires_document`` is dropped after two fruitless retrievals.
        """
        step_idx = self.state.current_step

        if not self._should_refine(step):
            return None

        optimized_step = step.copy()

        verification_result = self.state.verification_results.get(step_idx, {})
        feedback = verification_result.get("feedback", "")

        if feedback:
            system_prompt = """
            You are an AI assistant that helps optimize search and reasoning steps.
            Extract key information needs and create a more focused step description.
            """

            user_prompt = f"""
            Original step: {step.get('description', '')}
            Detail: {step.get('detail', '')}
            
            Verification feedback: {feedback}
            
            Based on this feedback, create:
            1. A more focused step description
            2. A more detailed query
            3. 3-5 specific keywords for search
            
            Format your answer as a JSON with keys: description, detail, keywords
            """

            response = get_llm_response(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                model="gpt-4",  # Use a good model for optimization
                temperature=0.3
            )

            # Optimization is best-effort: parsing problems are reported, not raised
            try:
                improved_step = extract_json_from_text(response)
                if isinstance(improved_step, dict):
                    if "description" in improved_step:
                        optimized_step["description"] = improved_step["description"]
                    if "detail" in improved_step:
                        optimized_step["detail"] = improved_step["detail"]
                    if "keywords" in improved_step:
                        # The LLM may return a comma-separated string instead of a list
                        optimized_step["search_keywords"] = self._as_keyword_list(improved_step["keywords"])
            except Exception as e:
                self.add_thought(f"Error optimizing step: {str(e)}")

        # Reset attempt counters (refine_attempts is intentionally carried over)
        optimized_step["retrieve_attempts"] = 0
        optimized_step["verification_attempts"] = 0

        # No documents after repeated retrieval: the step may not need documents at all
        if step.get("retrieve_attempts", 0) >= 2 and not self.state.retrieved_docs.get(step_idx):
            optimized_step["requires_document"] = False
            self.add_thought("Adjusted step to no longer require document retrieval.")

        return optimized_step