import logging
import math
import re
from collections import Counter
from typing import Dict, List, Union, Any, Tuple, Optional

import numpy as np
from langchain.schema import BaseOutputParser
from langchain_core.messages import AnyMessage
from langchain_core.messages import BaseMessage, ToolMessage, AIMessage

from orchestrator_maze_implementation.state.maze_state import MazeState

####
## OUTPUT PARSER
####

# Line patterns that carry state information. Kept at module level (not as
# class attributes) because the base class is a pydantic model, and compiled
# once instead of on every parse() call.
_ESSENTIAL_LINE_PATTERNS = (
    r"^Current position:",
    r"^Possible moves:",
    r"^At exit:",
    r"^Move \w+:",
    r"^New position:",
    r"Success!",
    r"Failed",
)
_ESSENTIAL_LINE_RE = re.compile(
    "|".join(f"({p})" for p in _ESSENTIAL_LINE_PATTERNS),
    re.MULTILINE | re.IGNORECASE,
)


def _content_to_text(content: Any) -> str:
    """Flatten message content into one string.

    Strings are returned unchanged and ``None`` becomes ``""``. For a
    list/tuple of blocks, string blocks and dict blocks carrying a string
    under ``"text"`` are joined with newlines; other blocks are skipped.
    Anything else falls back to ``str(content)``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                parts.append(block["text"])
        # Newline-joined so the line-anchored patterns still see line starts.
        return "\n".join(parts)
    return str(content)


class MessageContentParser(BaseOutputParser):
    """
    Parse out only essential content from agent messages for Shannon entropy calculation.
    Removes static legend info, tool call metadata, and empty content.
    """

    def parse(self, text: Union[str, BaseMessage]) -> str:
        """Return the essential state lines of ``text`` joined by single spaces.

        Args:
            text: A message object exposing ``.content`` or a plain string.
                Objects without ``.content`` are converted with ``str()``.

        Returns:
            The stripped lines matching one of the essential patterns, joined
            with spaces; ``""`` when the content is empty or nothing matches.
        """
        raw = text.content if hasattr(text, "content") else text
        # Content is not always a str (it may be a list of blocks), so it is
        # flattened before any string method is used on it.
        content = _content_to_text(raw)
        if not content.strip():
            return ""

        # Drop the legend section and everything after it.
        content = content.split("\nLegend:", 1)[0]

        # Patterns are matched against the raw line; only kept lines are stripped.
        lines = [
            line.strip()
            for line in content.splitlines()
            if _ESSENTIAL_LINE_RE.search(line)
        ]
        return " ".join(lines).strip()

    def parse_messages(self, messages: List[AnyMessage]) -> List[str]:
        """Parse a list of messages and return only non-empty essential content"""
        return [parsed for parsed in map(self.parse, messages) if parsed]

logger = logging.getLogger(__name__)


class FreeEnergyCalculator:
    """Free Energy Calculator for agent benchmarks.

    The score for an agent is the sum of an epistemic-uncertainty term
    (normalised token entropy of its recent messages) and a pragmatic-cost
    term (movement quality derived from the maze wrapper), capped at ``cap``.
    """

    def __init__(self, n: int = 5, cap: float = 2.0):
        """Store the window size and the cap.

        Args:
            n: Number of most recent messages used for the entropy term; also
                the base size of the recent-progress window.
            cap: Upper bound applied to the reported ``total_fe``.
        """
        self.n = n
        self.rolling_window = self.n
        self.cap = cap

    def calculate_expected_free_energy(self, agent_id: str, state: "MazeState") -> Dict[str, float]:
        """Calculate Expected Free Energy for a single agent.

        Args:
            agent_id: Key into ``state["maze_wrappers"]`` and
                ``state["agent_messages"]``.
            state: Mapping providing ``"maze_wrappers"`` and
                ``"agent_messages"``, both keyed by agent id.

        Returns:
            A dict with ``epistemic_uncertainty``, ``pragmatic_cost`` and
            ``total_fe`` (their sum, capped at ``self.cap``). When the agent
            has no maze wrapper, a dict with an ``error`` message and
            ``total_fe`` of 1.0 is returned instead.
        """
        if agent_id not in state["maze_wrappers"]:
            return {"error": f"Agent {agent_id} not found", "total_fe": 1.0}

        maze_wrapper = state["maze_wrappers"][agent_id]
        agent_messages = state["agent_messages"][agent_id]

        epistemic_uncertainty = self._calculate_epistemic_uncertainty(agent_messages, maze_wrapper)
        pragmatic_cost = self._calculate_pragmatic_cost(maze_wrapper)

        # Expected Free Energy = Epistemic Uncertainty + Pragmatic Cost.
        # Both terms are costs, so they add up; subtracting the cost would
        # lower the score for worse movement and make the cap unreachable.
        total_fe = epistemic_uncertainty + pragmatic_cost

        return {
            "epistemic_uncertainty": epistemic_uncertainty,
            "pragmatic_cost": pragmatic_cost,
            "total_fe": min(total_fe, self.cap),  # cap for normalization
        }

    #####
    # FREE ENERGY CORE FUNCTIONS: Epistemic Uncertainty & Pragmatic Cost
    #####

    def _calculate_epistemic_uncertainty(self, agent_messages: List["AnyMessage"], maze_wrapper) -> float:
        """Return the normalised token entropy of the most recent messages.

        ``maze_wrapper`` is accepted for signature compatibility and is not
        used.
        """
        # Slicing already copes with histories shorter than the window.
        recent_messages = agent_messages[-self.rolling_window:]
        return self._calculate_token_entropy(recent_messages)

    def _calculate_pragmatic_cost(self, maze_wrapper) -> float:
        """Return the pragmatic cost: 1.0 when the wrapper has no move history.

        Otherwise the ``total_cost`` entry of :meth:`_pragmatic_cost_breakdown`.
        """
        move_history = getattr(maze_wrapper, "move_history", [])
        if not move_history:
            return 1.0  # Maximum cost for no movement
        return self._pragmatic_cost_breakdown(maze_wrapper, move_history)["total_cost"]

    def _pragmatic_cost_breakdown(self, maze_wrapper, move_history) -> Dict[str, float]:
        """Compute the five equally weighted cost components and their total.

        ``move_history`` must be non-empty. ``failed_move_count`` and
        ``marked_dead_ends`` are read from ``maze_wrapper`` and default to
        0 and an empty set when absent.

        Returns:
            Dict with ``total_cost`` (capped at 1.0), ``move_success_rate``,
            ``exploration_efficiency``, ``backtrack_penalty``,
            ``dead_end_efficiency`` and ``recent_progress``.
        """
        logger.debug("move_history: %s", move_history)
        move_count = len(move_history)

        # 1. Movement efficiency: recorded moves vs. recorded moves + failures.
        failed_moves = getattr(maze_wrapper, "failed_move_count", 0)
        move_success_rate = move_count / max(move_count + failed_moves, 1)

        # 2. Exploration progress: share of distinct positions in the history.
        exploration_efficiency = len(set(move_history)) / move_count

        # 3. Movement consistency: repeated positions and oscillation.
        backtrack_penalty = self._calculate_backtrack_penalty(move_history)

        # 4. Dead-end learning: share of history entries on marked dead ends.
        marked_dead_ends = getattr(maze_wrapper, "marked_dead_ends", set())
        dead_end_revisits = self._count_dead_end_revisits(move_history, marked_dead_ends)
        dead_end_efficiency = 1.0 - (dead_end_revisits / max(move_count, 1))

        # 5. Recent progress: distinct positions within the recent window.
        recent_progress = self._calculate_recent_progress(move_history)

        total_cost = (
            0.20 * (1 - move_success_rate) +        # Penalize failed moves
            0.20 * (1 - exploration_efficiency) +   # Penalize redundant exploration
            0.20 * backtrack_penalty +              # Penalize excessive backtracking
            0.20 * (1 - dead_end_efficiency) +      # Penalize visits to marked dead ends
            0.20 * (1 - recent_progress)            # Penalize lack of recent progress
        )

        breakdown = {
            "total_cost": min(total_cost, 1.0),
            "move_success_rate": move_success_rate,
            "exploration_efficiency": exploration_efficiency,
            "backtrack_penalty": backtrack_penalty,
            "dead_end_efficiency": dead_end_efficiency,
            "recent_progress": recent_progress,
        }
        logger.debug("pragmatic cost (failed_moves=%s): %s", failed_moves, breakdown)
        return breakdown

    #####
    # HELPER FUNCTIONS EPISTEMIC UNCERTAINTY
    #####

    def _calculate_token_entropy(self, messages: List["AnyMessage"]) -> float:
        """Calculate the normalised Shannon entropy of tokens in ``messages``.

        Messages are reduced to their essential lines by
        ``MessageContentParser``, lower-cased, stripped of punctuation and
        split on whitespace. The entropy (in bits) is divided by ``log2`` of
        the number of distinct tokens; with a single distinct token the
        divisor is 1.0. Returns 0.0 for no messages or no tokens.
        """
        if not messages:
            return 0.0

        logger.debug("token entropy input: %s", messages)
        essential_contents = MessageContentParser().parse_messages(messages)
        logger.debug("token entropy input (parsed): %s", essential_contents)

        token_counts = Counter(
            token
            for content in essential_contents
            for token in re.sub(r"[^\w\s]", " ", content.lower()).split()
        )
        total_tokens = sum(token_counts.values())
        if not total_tokens:
            return 0.0

        entropy = 0.0
        for count in token_counts.values():
            probability = count / total_tokens
            entropy -= probability * math.log2(probability)

        # Normalize by maximum possible entropy for this vocabulary size.
        max_entropy = math.log2(len(token_counts)) if len(token_counts) > 1 else 1.0
        return entropy / max_entropy

    ####
    ## HELPER FUNCTIONS PRAGMATIC COST
    ####

    def _calculate_backtrack_penalty(self, move_history: List[Tuple[int, int]]) -> float:
        """Calculate penalty for excessive backtracking and oscillation patterns.

        Returns 0.0 for histories shorter than three entries. Otherwise the
        share of entries that repeat an earlier position plus 1.5 times the
        oscillation penalty, capped at 1.0.
        """
        if len(move_history) < 3:
            return 0.0

        # Every occurrence of a position after its first one is a backtrack.
        backtrack_count = len(move_history) - len(set(move_history))
        backtrack_ratio = backtrack_count / len(move_history)

        oscillation_penalty = self._detect_oscillation_patterns(move_history)

        # Oscillation is weighted more heavily than plain backtracking.
        total_penalty = min(backtrack_ratio + (oscillation_penalty * 1.5), 1.0)
        logger.debug(
            "backtrack_ratio: %.3f, oscillation_penalty: %.3f, total_backtrack_penalty: %.3f",
            backtrack_ratio,
            oscillation_penalty,
            total_penalty,
        )
        return total_penalty

    def _detect_oscillation_patterns(self, move_history: List[Tuple[int, int]]) -> float:
        """Detect oscillation patterns like A→B→A→B and return penalty score.

        Only the last 10 entries are inspected. Each A→B→A→B window counts 1
        and each A→B→A window counts 0.5; the sum is divided by the number of
        three-entry windows and capped at 1.0. Histories shorter than four
        entries score 0.0.
        """
        if len(move_history) < 4:
            return 0.0

        recent_moves = move_history[-10:]

        full_cycles = sum(
            1
            for a, b, c, d in zip(recent_moves, recent_moves[1:], recent_moves[2:], recent_moves[3:])
            if a == c and b == d and a != b
        )
        half_cycles = sum(
            1
            for a, b, c in zip(recent_moves, recent_moves[1:], recent_moves[2:])
            if a == c and a != b
        )
        oscillation_count = full_cycles + 0.5 * half_cycles

        oscillation_ratio = oscillation_count / max(len(recent_moves) - 2, 1)
        return min(oscillation_ratio, 1.0)

    def _count_dead_end_revisits(self, move_history: List[Tuple[int, int]],
                                marked_dead_ends: set) -> int:
        """Count history entries whose position is in ``marked_dead_ends``.

        Every matching entry counts, including the first visit to a position.
        """
        if not marked_dead_ends:
            return 0
        return sum(1 for pos in move_history if pos in marked_dead_ends)

    def _calculate_recent_progress(self, move_history: List[Tuple[int, int]]) -> float:
        """Return the share of distinct positions among the most recent moves.

        Histories shorter than ``rolling_window`` score 1.0 (benefit of the
        doubt for new agents), as does an empty window.
        """
        if len(move_history) < self.rolling_window:
            return 1.0

        # NOTE(review): the window spans the last 2 * rolling_window moves,
        # not rolling_window; confirm this doubling is intended.
        recent_moves = move_history[-2 * self.rolling_window:]
        if not recent_moves:
            return 1.0  # nothing to judge; avoids dividing by zero

        return len(set(recent_moves)) / len(recent_moves)