"""
Dynamic Prompt Injection Module for Free Energy-Driven Agent Optimization

This module provides dynamic prompt generation based on real-time free energy metrics
to optimize agent behavior through context-aware behavioral modulation.
"""

import numpy as np
from typing import Dict, List
from orchestrator_maze_implementation.state.maze_state import MazeState
from orchestrator_maze_implementation.config.config_service import ConfigService


def calculate_fe_trend(fe_history: List[float], window_size: int = 3) -> float:
    """Return the least-squares slope of the most recent free energy values.

    The last ``window_size`` entries of ``fe_history`` are fitted against their
    positions ``0, 1, ..., window_size - 1`` and the slope of that line is
    returned, i.e. the average change in free energy per step.

    Args:
        fe_history: Free energy values, oldest first.
        window_size: Number of trailing values to fit. Values below 2 cannot
            define a slope.

    Returns:
        The slope as a builtin ``float``. ``0.0`` is returned when
        ``window_size`` is below 2 or the history holds fewer than
        ``window_size`` values.
    """
    # A slope needs at least two points. Rejecting small windows up front also
    # avoids the slicing traps below: ``fe_history[-0:]`` is the *whole* list
    # and a negative window would slice from the front instead of the back.
    if window_size < 2 or len(fe_history) < window_size:
        return 0.0

    window = np.array(fe_history[-window_size:], dtype=float)
    n = len(window)
    steps = np.arange(n)

    # Closed-form ordinary least squares slope. The denominator is strictly
    # positive because ``steps`` holds n >= 2 distinct integers.
    numerator = n * np.sum(steps * window) - np.sum(steps) * np.sum(window)
    denominator = n * np.sum(steps * steps) - np.sum(steps) ** 2

    # Convert the NumPy scalar so callers get the annotated builtin float.
    return float(numerator / denominator)


def calculate_second_derivative(values: List[float]) -> float:
    """Return the mean second finite difference of a series.

    For every interior point the central difference
    ``next - 2 * current + previous`` is taken, and the mean of those
    differences is returned. Series with fewer than three points have no
    interior point, so ``0.0`` is returned for them.
    """
    if len(values) < 3:
        return 0.0

    # Slide a three-wide window over the series; with at least three values
    # there is always at least one window, so the mean is well defined.
    curvature = [
        following - 2 * centre + preceding
        for preceding, centre, following in zip(values, values[1:], values[2:])
    ]
    return np.mean(curvature)


def classify_agent_type(fe_data: Dict[str, float]) -> str:
    """Label an agent from its free energy profile.

    Reads ``epistemic_uncertainty`` and ``pragmatic_cost`` from ``fe_data``
    (each defaulting to 0.5 when absent) and returns:

    * ``"explorer"``  - uncertainty above 0.6 and cost below 0.4
    * ``"optimizer"`` - uncertainty below 0.4 and cost above 0.6
    * ``"balanced"``  - anything else
    """
    uncertainty = fe_data.get("epistemic_uncertainty", 0.5)
    cost = fe_data.get("pragmatic_cost", 0.5)

    uncertain_and_cheap = uncertainty > 0.6 and cost < 0.4
    if uncertain_and_cheap:
        return "explorer"

    confident_and_costly = uncertainty < 0.4 and cost > 0.6
    if confident_and_costly:
        return "optimizer"

    return "balanced"
    
def generate_dynamic_execution_weights(fe_data: Dict, agent_state: Dict) -> Dict[str, float]:
    """Derive decision weight modifiers from an agent's free energy metrics.

    Inputs read:
        * ``fe_data["epistemic_uncertainty"]`` (default 0.5)
        * ``fe_data["pragmatic_cost"]`` (default 0.5)
        * ``agent_state["fe_history"]`` (default ``[]``), reduced to a trend
          with ``calculate_fe_trend``

    Returns a dict with the keys ``exploration_weight``, ``efficiency_weight``,
    ``teammate_avoidance``, ``backtrack_threshold`` and ``dead_end_confidence``.
    Adjustments are applied in a fixed order (uncertainty, then cost, then
    trend); later stages overwrite or scale what earlier stages produced.
    """
    uncertainty = fe_data.get("epistemic_uncertainty", 0.5)
    cost = fe_data.get("pragmatic_cost", 0.5)
    trend = calculate_fe_trend(agent_state.get("fe_history", []))

    # Baseline used when no adjustment applies.
    weights = dict(
        exploration_weight=1.0,    # weight for unexplored directions
        efficiency_weight=1.0,     # weight for avoiding recent positions
        teammate_avoidance=1.0,    # weight for avoiding teammate areas
        backtrack_threshold=0.7,   # threshold for triggering backtracking
        dead_end_confidence=0.8,   # confidence threshold for marking dead ends
    )

    # Stage 1: uncertainty band (the two bands are mutually exclusive).
    if uncertainty > 0.85:
        weights.update(
            exploration_weight=2.0,
            efficiency_weight=2.0,   # raised together with exploration to prevent oscillation loops
            dead_end_confidence=0.9,
            teammate_avoidance=3.0,
        )
    elif uncertainty < 0.4:
        weights.update(
            exploration_weight=0.8,
            efficiency_weight=1.5,
            dead_end_confidence=0.6,
            teammate_avoidance=1.1,
        )

    # Stage 2: cost. Scales whatever efficiency weight stage 1 left behind and
    # replaces the teammate avoidance chosen there.
    # NOTE(review): the 0.5 default for a missing pragmatic_cost is above this
    # 0.35 threshold, so this stage also runs when the metric is absent.
    if cost > 0.35:
        weights["efficiency_weight"] = weights["efficiency_weight"] * 1.5
        weights.update(backtrack_threshold=0.5, teammate_avoidance=1.9)

    # Stage 3: falling free energy trend. Overrides the backtrack threshold and
    # scales the exploration and teammate weights produced so far.
    if trend < -0.005:
        weights["backtrack_threshold"] = 0.3
        weights["exploration_weight"] = weights["exploration_weight"] * 0.7
        weights["teammate_avoidance"] = weights["teammate_avoidance"] * 1.2

    return weights

def generate_dynamic_execution_prompts(fe_data: Dict, agent_state: Dict) -> Dict[str, str]:
    """Translate the dynamic execution weights into named prompt modifiers.

    The weights come from ``generate_dynamic_execution_weights``. The result
    contains only the modifiers whose condition holds, keyed by
    ``exploration_modifier``, ``efficiency_modifier`` and ``adaptation_signal``
    (in that order); it is empty when no weight crosses its threshold.
    """
    weights = generate_dynamic_execution_weights(fe_data, agent_state)

    explore = weights["exploration_weight"]
    efficiency = weights["efficiency_weight"]
    backtrack = weights["backtrack_threshold"]

    # (key, applies?, text). The two exploration rows share a key but their
    # conditions cannot both hold, so at most one of them is emitted.
    rules = (
        ("exploration_modifier", explore > 1.5,
         "🔍 EXPLORE: Strongly prioritize unexplored paths (2x weight)"),
        ("exploration_modifier", explore < 0.9,
         "🎯 FOCUS: Prioritize known efficient paths"),
        ("efficiency_modifier", efficiency > 1.3,
         "⚡ EFFICIENCY: Avoid last 5 positions (1.5x penalty)"),
        ("adaptation_signal", backtrack < 0.5,
         "🔄 ADAPTATION: Consider backtracking earlier if stuck"),
    )

    return {key: text for key, applies, text in rules if applies}

def format_weighted_guidance(weights: Dict[str, float]) -> str:
    """Render decision weights as newline-separated guidance lines.

    ``weights`` must contain ``exploration_weight``, ``efficiency_weight`` and
    ``backtrack_threshold``. One line is emitted per threshold crossed; when
    none is crossed the text ``"Standard exploration mode"`` is returned.
    """
    explore = weights["exploration_weight"]
    efficiency = weights["efficiency_weight"]
    backtrack = weights["backtrack_threshold"]

    # (applies?, text) in output order. The two exploration conditions are
    # mutually exclusive, so at most one exploration line appears.
    rules = (
        (explore > 1.5, "🔍 EXPLORE: Strongly prioritize unexplored paths (2x weight)"),
        (explore < 0.9, "🎯 FOCUS: Prioritize known efficient paths"),
        (efficiency > 1.3, "⚡ EFFICIENCY: Avoid last 5 positions (1.5x penalty)"),
        (backtrack < 0.5, "🔄 ADAPTATION: Consider backtracking earlier if stuck"),
    )

    selected = [text for applies, text in rules if applies]
    if not selected:
        return "Standard exploration mode"
    return "\n".join(selected)

def generate_orchestration_prompts(system_fe_metrics: Dict) -> str:
    """Pick the system-level directive that matches the current metrics.

    Reads ``system_entropy`` and ``exploration_coverage`` from
    ``system_fe_metrics`` (both default to 0.5). Low entropy takes priority
    over low coverage; when neither is below its threshold an empty string
    is returned.
    """
    entropy = system_fe_metrics.get("system_entropy", 0.5)
    coverage = system_fe_metrics.get("exploration_coverage", 0.5)

    # Entropy below 0.4: push the agents apart.
    if entropy < 0.4:
        return """
        SYSTEM DIRECTIVE - DIVERGENCE INJECTION:
        - MANDATE: Force agent specialization into different maze quadrants
        - INJECT: Random exploration seeds to break convergence
        - OVERRIDE: Individual agent efficiency preferences for diversity
        """

    # Coverage below 0.5 (only checked when entropy is acceptable).
    if coverage < 0.5:
        return """
        SYSTEM DIRECTIVE - COVERAGE OPTIMIZATION:
        - REASSIGN: Agents to unexplored maze regions
        - PENALIZE: Agents exploring already-covered areas
        - REWARD: Agents discovering new maze structures
        """

    return ""

def generate_adaptive_agent_guidance(agent_fe_profiles: Dict) -> Dict[str, str]:
    """Build a role briefing for every agent from its free energy profile.

    Each profile in ``agent_fe_profiles`` (keyed by agent id) is labelled with
    ``classify_agent_type`` and mapped to the briefing text for that label.
    A label without a briefing maps to an empty string.
    """
    # Briefing text per label returned by classify_agent_type.
    role_briefings = {
        "explorer": """
            ROLE: PRIMARY EXPLORER
            - MISSION: Discover new maze regions
            - TOLERANCE: Accept higher movement costs for information gain
            - DIRECTIVE: Prioritize breadth over depth in exploration
            """,
        "optimizer": """
            ROLE: EFFICIENCY SPECIALIST
            - MISSION: Optimize known paths and eliminate redundancy
            - TOLERANCE: Lower exploration in favor of goal completion
            - DIRECTIVE: Focus on shortest-path solutions in known regions
            """,
        "balanced": """
            ROLE: ADAPTIVE COORDINATOR
            - MISSION: Balance exploration and efficiency based on context
            - TOLERANCE: Switch between explorer/optimizer as needed
            - DIRECTIVE: Respond to system-level needs dynamically
            """,
    }

    return {
        agent_id: role_briefings.get(classify_agent_type(profile), "")
        for agent_id, profile in agent_fe_profiles.items()
    }


def calculate_temporal_fe_derivatives(agent_id: str, state: MazeState) -> Dict:
    """Summarise how one agent's total free energy has been changing.

    Looks at the last five entries of ``state["entropy_history"]`` and, for
    every entry that has an ``agent_fe_scores`` record for ``agent_id``,
    collects its ``total_fe`` (1.0 when that field is missing).

    Returns:
        A dict with ``trend``, ``acceleration`` and ``volatility``. All three
        are ``0`` when the history has fewer than three entries or fewer than
        three samples exist for the agent.
    """
    no_signal = {"trend": 0, "acceleration": 0, "volatility": 0}

    history = state.get("entropy_history", [])
    if not history or len(history) < 3:
        return no_signal

    samples = [
        entry["agent_fe_scores"][agent_id].get("total_fe", 1.0)
        for entry in history[-5:]
        if "agent_fe_scores" in entry and agent_id in entry["agent_fe_scores"]
    ]
    if len(samples) < 3:
        return no_signal

    # NOTE(review): the net change is divided by the number of samples, not by
    # the number of intervals between them (len - 1) - confirm this scaling is
    # what consumers of "trend" expect.
    net_change = samples[-1] - samples[0]

    return {
        "trend": net_change / len(samples),
        "acceleration": calculate_second_derivative(samples),
        # At least three samples are present here, so the spread is defined.
        "volatility": float(np.std(samples)),
    }


def calculate_system_fe_metrics(state: MazeState) -> Dict[str, float]:
    """Compute system-wide free energy metrics used for orchestration.

    Reads ``state["entropy_history"]`` and ``state["maze_wrappers"]``.

    Returns:
        A dict with three entries:

        * ``system_entropy`` - standard deviation of the agents' ``total_fe``
          in the latest history entry (0.5 when only one agent is present).
        * ``convergence_rate`` - relative drop in that standard deviation
          compared with the entry two steps earlier; positive when the agents'
          values have moved closer together, 0.0 when there is not enough
          history or the earlier entry has no scores.
        * ``exploration_coverage`` - unique positions divided by total recorded
          moves across all wrappers exposing ``move_history`` (0.5 when no
          moves are recorded).

        All three are 0.5 when the history is empty or the latest entry has no
        agent scores.
    """
    neutral = {"system_entropy": 0.5, "convergence_rate": 0.5, "exploration_coverage": 0.5}

    def _fe_spread(agent_scores: Dict) -> float:
        # Standard deviation of total_fe across agents; 0.5 for a lone agent.
        totals = [record.get("total_fe", 1.0) for record in agent_scores.values()]
        return float(np.std(totals)) if len(totals) > 1 else 0.5

    history = state.get("entropy_history", [])
    if not history:
        return neutral

    latest_scores = history[-1].get("agent_fe_scores", {})
    if not latest_scores:
        return neutral

    current_spread = _fe_spread(latest_scores)

    # Compare against the entry two steps back. The floor of 0.1 on the
    # divisor keeps the ratio bounded when the earlier spread was tiny.
    convergence = 0.0
    if len(history) >= 3:
        earlier_scores = history[-3].get("agent_fe_scores", {})
        if earlier_scores:
            earlier_spread = _fe_spread(earlier_scores)
            convergence = (earlier_spread - current_spread) / max(earlier_spread, 0.1)

    # Rough coverage estimate: how many of the recorded moves landed on
    # distinct positions, pooled over every wrapper that tracks its moves.
    seen_positions = set()
    recorded_moves = 0
    for wrapper in state.get("maze_wrappers", {}).values():
        if not hasattr(wrapper, "move_history"):
            continue
        seen_positions.update(wrapper.move_history)
        recorded_moves += len(wrapper.move_history)

    if recorded_moves > 0:
        coverage = len(seen_positions) / recorded_moves
    else:
        coverage = 0.5

    return {
        "system_entropy": current_spread,
        "convergence_rate": convergence,
        "exploration_coverage": coverage,
    }