"""
Decision Context Capture Module

This module captures and aggregates real-time decision contexts from execution agents
to provide strategic insights for the orchestration agent.
"""

from typing import Dict, Any, List
from orchestrator_maze_implementation.state.maze_state import MazeState
from orchestrator_maze_implementation.agents.maze_execution_agent import (
    _calculate_movement_scores, _calculate_unexplored_directions, 
    _calculate_dead_end_confidence, _generate_dynamic_behavioral_prompts,
    _extract_strategic_waypoints, _get_next_position, safe_get
)


class DecisionContextCapture:
    """Capture and aggregate decision-time context from execution agents.

    All methods are static; the class is a namespace for the per-agent
    capture step and the cross-agent aggregation/analysis steps.
    """

    # Number of most recent moves reported as "efficiency_penalties".
    _RECENT_WINDOW = 6
    # Max Manhattan distance at which two agents count as converging.
    _CONVERGENCE_DISTANCE = 3
    # Max Manhattan distance at which two agents are considered for
    # exploration coordination.
    _COORDINATION_DISTANCE = 4

    @staticmethod
    def _manhattan_distance(pos1, pos2):
        """Return |dx| + |dy| between two positions indexed as [0] and [1]."""
        return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])

    @staticmethod
    def capture_agent_decision_context(state: MazeState, agent_id: str) -> Dict[str, Any]:
        """Capture the full decision context that an execution agent sees.

        Recomputes, from ``state``, the inputs the execution agent's helpers
        use when choosing a move: movement scores, unexplored directions,
        dead-end confidence, teammate positions, backtracking flags and
        strategic waypoints extracted from teammates.

        Args:
            state: Shared maze state. ``maze_wrappers``, ``all_agents`` and
                ``agent_backtracking_state`` are read from it.
            agent_id: Identifier of the agent whose context is captured.

        Returns:
            A dict describing the agent's decision context, or
            ``{"error": <message>}`` when the agent has no maze wrapper or
            any step of the capture raises.
        """
        try:
            maze_wrappers = safe_get(state, "maze_wrappers", {}, "state")
            if agent_id not in maze_wrappers:
                return {"error": f"Agent {agent_id} not found in maze_wrappers"}

            maze_wrapper = maze_wrappers[agent_id]

            current_position = maze_wrapper.get_agent_position()
            possible_moves = maze_wrapper.get_possible_moves()

            move_history = getattr(maze_wrapper, 'move_history', [])
            # Both lists are fresh copies so the returned context never
            # aliases the wrapper's live history. Slicing always returns a new
            # list, even when the history is shorter than the window.
            previously_visited_tiles = move_history.copy()
            recent_positions = move_history[-DecisionContextCapture._RECENT_WINDOW:]

            marked_dead_ends = list(maze_wrapper.get_marked_dead_ends())

            unexplored_directions = _calculate_unexplored_directions(
                current_position, possible_moves, previously_visited_tiles
            )

            # Collect (wrapper, move_history) for every other agent with a
            # wrapper once; reused for teammate positions and waypoints below.
            all_agents = safe_get(state, "all_agents", [], "state")
            teammates = []
            for other_agent_id in all_agents:
                if other_agent_id != agent_id and other_agent_id in maze_wrappers:
                    other_wrapper = maze_wrappers[other_agent_id]
                    teammates.append(
                        (other_wrapper, getattr(other_wrapper, 'move_history', []))
                    )

            # Flattened histories of all teammates, in all_agents order.
            other_agents_positions = [
                position for _, history in teammates for position in history
            ]

            dynamic_prompts, weights = _generate_dynamic_behavioral_prompts(state, agent_id)

            movement_scores = _calculate_movement_scores(
                current_position, possible_moves, previously_visited_tiles,
                recent_positions, other_agents_positions, weights, marked_dead_ends
            )

            dead_end_confidence = _calculate_dead_end_confidence(
                current_position, possible_moves, previously_visited_tiles
            )

            backtracking_state = safe_get(state, "agent_backtracking_state", {}, "state")
            agent_backtrack = (
                safe_get(backtracking_state, agent_id, {}, "backtracking_state")
                if backtracking_state else {}
            )

            strategic_waypoints = set()
            for other_wrapper, other_history in teammates:
                strategic_waypoints.update(
                    _extract_strategic_waypoints(other_wrapper, other_history)
                )

            return {
                "agent_id": agent_id,
                "current_position": current_position,
                "possible_moves": possible_moves,
                "unexplored_directions": unexplored_directions,
                "movement_scores": movement_scores,
                "efficiency_penalties": recent_positions,
                "teammate_positions": other_agents_positions,
                "dynamic_weights": weights,
                "backtracking_status": {
                    "is_backtracking": safe_get(agent_backtrack, "is_backtracking", False, "agent_backtrack"),
                    "lock_mode": safe_get(agent_backtrack, "lock_mode", False, "agent_backtrack")
                },
                "dead_end_confidence": dead_end_confidence,
                "strategic_waypoints": list(strategic_waypoints),
                "marked_dead_ends": marked_dead_ends,
                "previously_visited": previously_visited_tiles
            }

        except Exception as e:
            # Deliberate best-effort boundary: failures are reported in-band
            # via the "error" key, which the aggregation methods check for.
            return {"error": f"Failed to capture context for {agent_id}: {str(e)}"}

    @staticmethod
    def aggregate_all_agent_contexts(state: MazeState) -> Dict[str, Any]:
        """Capture every agent's context and run the cross-agent analyses.

        Args:
            state: Shared maze state; agents are enumerated from its
                ``maze_wrappers`` entry.

        Returns:
            A dict with ``individual_contexts`` (agent_id -> context),
            ``global_patterns`` and ``optimization_opportunities``.
        """
        maze_wrappers = safe_get(state, "maze_wrappers", {}, "state")
        agent_contexts = {
            agent_id: DecisionContextCapture.capture_agent_decision_context(state, agent_id)
            for agent_id in maze_wrappers
        }

        global_patterns = DecisionContextCapture._identify_global_patterns(agent_contexts)
        optimization_opportunities = DecisionContextCapture._find_optimization_opportunities(agent_contexts)

        return {
            "individual_contexts": agent_contexts,
            "global_patterns": global_patterns,
            "optimization_opportunities": optimization_opportunities
        }

    @staticmethod
    def _identify_global_patterns(agent_contexts: Dict) -> Dict[str, Any]:
        """Identify patterns across agents that no single agent can see.

        Contexts carrying an ``"error"`` key are skipped.

        Returns:
            A dict with:
              * ``convergence_zones`` - agent pairs within
                ``_CONVERGENCE_DISTANCE`` (Manhattan) of each other; each pair
                is listed once, with the smaller agent id first.
              * ``exploration_gaps`` - currently always empty; nothing here
                populates it (kept so the result keys stay stable).
              * ``efficiency_conflicts`` - moves scoring below 0.5 whose target
                is among the agent's recent positions while the direction is
                also reported as unexplored.
        """
        convergence_zones = []
        exploration_gaps = []
        efficiency_conflicts = []

        agent_positions = {
            agent_id: context["current_position"]
            for agent_id, context in agent_contexts.items()
            if "error" not in context
        }

        for agent1, pos1 in agent_positions.items():
            for agent2, pos2 in agent_positions.items():
                if agent1 < agent2:  # each unordered pair once
                    distance = DecisionContextCapture._manhattan_distance(pos1, pos2)
                    if distance <= DecisionContextCapture._CONVERGENCE_DISTANCE:
                        convergence_zones.append({
                            "agents": [agent1, agent2],
                            "positions": [pos1, pos2],
                            "distance": distance
                        })

        for agent_id, context in agent_contexts.items():
            if "error" in context:
                continue
            movement_scores = context.get("movement_scores", {})
            efficiency_penalties = context.get("efficiency_penalties", [])
            unexplored = context.get("unexplored_directions", [])
            current_pos = context.get("current_position")

            for direction, score in movement_scores.items():
                next_pos = _get_next_position(current_pos, direction)

                # Low score caused by revisiting a recent position, even though
                # the direction is still reported as unexplored.
                if (score < 0.5 and
                        next_pos in efficiency_penalties and
                        direction in unexplored):
                    efficiency_conflicts.append({
                        "agent_id": agent_id,
                        "position": current_pos,
                        "blocked_direction": direction,
                        "score": score,
                        "reason": "efficiency_penalty_blocking_exploration"
                    })

        return {
            "convergence_zones": convergence_zones,
            "exploration_gaps": exploration_gaps,
            "efficiency_conflicts": efficiency_conflicts
        }

    @staticmethod
    def _find_optimization_opportunities(agent_contexts: Dict) -> List[Dict]:
        """Find coordination and efficiency-override opportunities.

        Contexts carrying an ``"error"`` key are skipped.

        Returns:
            A list of opportunity dicts, each tagged by ``"type"``:
              * ``exploration_coordination`` - two agents within
                ``_COORDINATION_DISTANCE`` (Manhattan) sharing more than one
                unexplored direction.
              * ``efficiency_override`` - an agent whose ``efficiency_weight``
                exceeds 1.5 and that has unexplored directions scoring below 0.7.
        """
        opportunities = []

        for agent1_id, context1 in agent_contexts.items():
            if "error" in context1:
                continue

            for agent2_id, context2 in agent_contexts.items():
                if agent1_id >= agent2_id or "error" in context2:
                    continue

                distance = DecisionContextCapture._manhattan_distance(
                    context1["current_position"], context2["current_position"]
                )

                if distance <= DecisionContextCapture._COORDINATION_DISTANCE:
                    unexplored1 = set(context1.get("unexplored_directions", []))
                    unexplored2 = set(context2.get("unexplored_directions", []))
                    overlap = unexplored1 & unexplored2

                    if len(overlap) > 1:  # more than one shared direction
                        opportunities.append({
                            "type": "exploration_coordination",
                            "agents": [agent1_id, agent2_id],
                            "overlapping_directions": list(overlap),
                            "recommendation": f"Coordinate exploration: divide {list(overlap)} between agents"
                        })

        for agent_id, context in agent_contexts.items():
            if "error" in context:
                continue
            movement_scores = context.get("movement_scores", {})
            weights = context.get("dynamic_weights", {})
            efficiency_weight = weights.get("efficiency_weight", 1.0)

            # Only flag agents whose efficiency weighting is strongly elevated.
            if efficiency_weight > 1.5:
                unexplored = context.get("unexplored_directions", [])
                low_score_unexplored = [
                    direction for direction, score in movement_scores.items()
                    if score < 0.7 and direction in unexplored
                ]

                if low_score_unexplored:
                    opportunities.append({
                        "type": "efficiency_override",
                        "agent_id": agent_id,
                        "directions": low_score_unexplored,
                        "efficiency_weight": efficiency_weight,
                        "recommendation": f"Consider overriding efficiency penalty for {low_score_unexplored[0]} direction"
                    })

        return opportunities
