import os
import math
import threading
import time
from typing import List, Dict, Any
import sys

from core.clients.openrouter_client import OpenRouterClient
# REMOVED: from core.components.inner_monologue_manager import MonologueManager
# REMOVED: from core.components.cognitive_controller import CognitiveController
from core.ablations.cognitive_controller_ablation import CognitiveControllerAblation
from core.components.talker import Talker

class MirrorCognitiveOnly:
    """
    Ablation variant of MIRROR that removes the Inner Monologue Manager/threads,
    keeping only the Cognitive Controller and Talker.
    """

    def __init__(self, api_key: str = None, model: str = "openai/gpt-4o"):
        """
        Build the cognitive-only Mirror pipeline: client, controller and talker.

        Args:
            api_key: OpenRouter API key. When falsy, the value of the
                OPENROUTER_API_KEY environment variable is used instead.
            model: OpenRouter model identifier handed to both the cognitive
                controller and the talker (e.g., "openai/gpt-4o").

        Raises:
            ValueError: If no key was passed and OPENROUTER_API_KEY is unset
                or empty.
        """
        # An explicit key wins; otherwise fall back to the environment.
        resolved_key = api_key or os.environ.get("OPENROUTER_API_KEY")
        if not resolved_key:
            raise ValueError("OPENROUTER_API_KEY environment variable must be set")

        self.model = model
        print(f"INFO: Initializing Mirror (Cognitive Only - No Threads) with model: {self.model}")
        sys.stdout.flush()

        # Per-conversation state (everything reset_conversation() clears).
        self.conversation_history = []
        self.turn_count = 0
        self.current_insights = None
        self.last_consolidated_narrative = None  # last CognitiveController output

        # One OpenRouter client shared by the controller and the talker.
        # The Inner Monologue Manager is intentionally absent in this ablation.
        self.client = OpenRouterClient(resolved_key)
        self.cognitive_controller = CognitiveControllerAblation(self.client, model=self.model)
        self.talker = Talker(client=self.client, model=self.model)

        # Background-worker bookkeeping. The event starts out set because no
        # worker is running yet, so nothing has to be waited for.
        self.insights_lock = threading.Lock()
        self.background_thread = None
        self.background_thread_active = False
        completed = threading.Event()
        completed.set()
        self.background_thread_completed = completed

    def reset_conversation(self):
        """
        Return the instance to its just-constructed conversational state.

        Waits up to 100 seconds for a live background thread to signal
        completion, then, while holding ``insights_lock``, clears the history,
        turn counter, insights and narrative, resets the cognitive controller,
        clears ``talker.last_response`` (when that attribute exists) and
        reinitialises the thread bookkeeping.

        NOTE(review): if the wait times out the worker is only forgotten, not
        stopped; it presumably can still store insights after the reset.
        """
        def announce(message):
            # Every progress line is flushed right away so it shows up in order.
            print(message)
            sys.stdout.flush()

        announce("INFO: Performing reset of Mirror (Cognitive Only) conversation state...")

        # Give an in-flight background worker a chance to finish first.
        worker = self.background_thread
        if worker and worker.is_alive():
            announce("  - Waiting for active background thread to complete")
            finished_in_time = self.background_thread_completed.wait(timeout=100.0)
            if not finished_in_time:
                announce("  - WARNING: Background thread did not complete within timeout")
                self.background_thread_active = False

        # All shared state is cleared under the lock the worker also uses.
        with self.insights_lock:
            announce("  - Resetting conversation history and turn counter")
            self.conversation_history = []
            self.turn_count = 0

            announce("  - Clearing all insights and narrative state")
            self.current_insights = None
            self.last_consolidated_narrative = None

            announce("  - Resetting CognitiveController")
            self.cognitive_controller.reset()

            announce("  - Resetting Talker state")
            if hasattr(self.talker, 'last_response'):
                self.talker.last_response = None

            announce("  - Resetting thread state")
            self.background_thread = None
            self.background_thread_active = False
            self.background_thread_completed.set()

        announce("INFO: Reset completed successfully")

    def truncate_conversation_history(self, conversation: List[Dict[str, str]], max_tokens: int = 20000) -> List[Dict[str, str]]:
        """
        Truncate conversation history to fit within an estimated token budget.

        Every system message and the first user message are always kept.
        On top of those, the last 10, 6, 4 or 2 messages are kept, using the
        largest window whose estimated size fits the budget. Each message
        appears at most once and the result preserves the original message
        order.

        Args:
            conversation: Chat messages (dicts with "role" and "content").
            max_tokens: Budget in estimated tokens (4 characters count as one
                token, plus 20 tokens of overhead per message).

        Returns:
            ``conversation`` itself when it is empty or already fits the
            budget, otherwise a new list containing the retained messages.
            When even the smallest window does not fit, only the system
            messages and the first user message are returned; that last-resort
            result is not checked against ``max_tokens``.
        """
        if not conversation:
            return conversation

        # Estimate token count (rough approximation: 4 chars ≈ 1 token).
        # A missing or None "content" counts as empty instead of crashing len().
        def estimate_tokens(messages: List[Dict[str, str]]) -> int:
            return sum(len(msg.get("content") or "") // 4 + 20 for msg in messages)

        # If conversation is already within limits, return as is
        estimated_tokens = estimate_tokens(conversation)
        if estimated_tokens <= max_tokens:
            print(f"Conversation history within limits ({estimated_tokens} tokens)")
            sys.stdout.flush()
            return conversation

        # Work with positions rather than message objects so that a message
        # which is both "essential" and inside the recent window is kept once.
        essential_indices = {
            idx for idx, msg in enumerate(conversation) if msg.get("role") == "system"
        }
        first_user_idx = next(
            (idx for idx, msg in enumerate(conversation) if msg.get("role") == "user"),
            None,
        )
        if first_user_idx is not None:
            essential_indices.add(first_user_idx)

        total = len(conversation)

        # Try keeping essential messages + last N messages, widest window first
        for n_recent in (10, 6, 4, 2):
            recent_start = max(0, total - n_recent)
            kept_indices = essential_indices.union(range(recent_start, total))
            truncated = [conversation[idx] for idx in sorted(kept_indices)]

            truncated_tokens = estimate_tokens(truncated)
            if truncated_tokens <= max_tokens:
                print(f"Truncated to essential messages + last {n_recent} messages ({truncated_tokens} tokens)")
                sys.stdout.flush()
                return truncated

        # Absolute last resort: essential messages only
        return [conversation[idx] for idx in sorted(essential_indices)]

    def process_user_input(self, user_input: str) -> str:
        """
        Process user input through the Mirror architecture (Cognitive Only variant).

        The talker answers immediately using the insights stored by the
        previous turn's background thinking (none on the first turn). After
        the answer is recorded, a daemon thread is started to run background
        thinking on a copy of the conversation taken at that moment.

        Args:
            user_input: The user's message

        Returns:
            The AI's response. If the talker raises, an apology string that
            contains the error text is returned (and recorded in the history)
            instead of propagating the exception.
        """
        # Wait for any existing background thread to complete
        previous_thread = self.background_thread
        if previous_thread and previous_thread.is_alive():
            print("Waiting for previous background thinking thread to complete...")
            sys.stdout.flush()
            if not self.background_thread_completed.wait(timeout=100.0):
                # Same reporting style as reset_conversation(): do not block
                # the user forever, but make the timeout visible.
                print("WARNING: Previous background thinking thread did not complete within timeout")
                sys.stdout.flush()

        # Increment turn counter
        self.turn_count += 1
        print(f"Processing turn {self.turn_count}")
        sys.stdout.flush()

        # Add user input to conversation history
        self.conversation_history.append({"role": "user", "content": user_input})

        if self.turn_count > 1:
            # Use insights from previous turn's thinking
            print(f"Turn {self.turn_count}: Applying insights from previous turn's thinking")
            sys.stdout.flush()
            with self.insights_lock:
                current_turn_insights = self.current_insights
        else:
            print("Turn 1: Generating immediate response (background thinking will start after)")
            sys.stdout.flush()
            current_turn_insights = None

        # Generate final response with talker (using full conversation history)
        try:
            response = self.talker.respond(
                self.conversation_history,
                current_turn_insights
            )
        except Exception as e:
            print(f"Error during response generation: {e}")
            sys.stdout.flush()
            response = f"I apologize, but I encountered an error: {str(e)}"

        # Add assistant response to conversation history
        self.conversation_history.append({"role": "assistant", "content": response})

        # Hand the worker its own list. Passing self.conversation_history
        # itself would let later appends (e.g. the next turn after a timed-out
        # wait) change the data while the worker is still reading it. The copy
        # is shallow: the message dicts themselves are shared.
        history_snapshot = list(self.conversation_history)

        # Start background thinking in a background thread
        self.background_thread_completed.clear()
        self.background_thread_active = True

        # Create and start a new background thread
        self.background_thread = threading.Thread(
            target=self._background_thinking_wrapper,
            args=(history_snapshot,),
            daemon=True
        )
        self.background_thread.start()
        print("Started background thinking thread (Cognitive Only)")

        sys.stdout.flush()
        return response
    
    def _background_thinking_wrapper(self, history_snapshot: List[Dict[str, str]]):
        """
        Thread entry point: run background thinking, then always signal completion.

        Args:
            history_snapshot: Conversation messages forwarded unchanged to
                process_background_thinking().
        """
        try:
            self.process_background_thinking(history_snapshot)
        finally:
            # Executed on success and on failure alike, so anything waiting on
            # background_thread_completed is released even if thinking raised.
            self.background_thread_active = False
            self.background_thread_completed.set()
            print("Background thinking completed (Cognitive Only)", flush=True)
    
    def process_background_thinking(self, conversation_history_snapshot: List[Dict[str, str]], timeout=120.0):
        """
        Derive insights for the next turn using only the Cognitive Controller.

        The snapshot is truncated with truncate_conversation_history() and
        passed, together with an empty thread-results list, to the
        controller's consolidate(). The result is stored under
        ``insights_lock`` as both ``current_insights`` and
        ``last_consolidated_narrative``.

        Any exception is caught and reported; in that case the two attributes
        are set to strings describing the error.
        NOTE(review): that error text becomes the "insights" read on the next
        turn - confirm this is intended.

        Args:
            conversation_history_snapshot: The conversation history state.
            timeout: Only echoed in the start-up log line; nothing in this
                method enforces it.
        """
        started_at = time.time()
        print(f"Starting background thinking (Cognitive Only, no threads) (timeout: {timeout}s)")
        sys.stdout.flush()

        try:
            trimmed_history = self.truncate_conversation_history(conversation_history_snapshot)

            # This ablation has no inner-monologue threads; an empty list is
            # passed purely to satisfy the consolidate() call signature.
            no_thread_results = []

            # Banner announcing that the threads are bypassed.
            banner = (
                "\n╔══════════════════════════════════════════════════════════════╗",
                "║         COGNITIVE-ONLY ABLATION (NO THREADS)                 ║",
                "╚══════════════════════════════════════════════════════════════╝",
                "\n⚠️  Bypassing Inner Monologue Threads - Analyzing conversation directly",
            )
            for banner_line in banner:
                print(banner_line)
            print(f"📝 Conversation length: {len(trimmed_history)} turns")
            print("─" * 60 + "\n")
            sys.stdout.flush()

            print("Running cognitive controller consolidation (ablation mode)...")
            sys.stdout.flush()

            consolidated = self.cognitive_controller.consolidate(no_thread_results, trimmed_history)

            # Publish the result for the next call to process_user_input().
            with self.insights_lock:
                self.current_insights = self.last_consolidated_narrative = consolidated

            elapsed = time.time() - started_at
            print(f"Background thinking (Cognitive Only) completed in {elapsed:.2f}s")
            sys.stdout.flush()

        except Exception as exc:
            print(f"ERROR in background thinking: {exc}")
            sys.stdout.flush()
            failure_text = str(exc)
            with self.insights_lock:
                self.current_insights = f"Background thinking error: {failure_text}"
                self.last_consolidated_narrative = f"ERROR: {failure_text}"

        # Reported on both the success and the failure path.
        overall = time.time() - started_at
        print(f"Background thinking processing time: {overall:.2f}s")
        sys.stdout.flush()

    def run_interactive(self):
        """
        Run the Mirror system in interactive mode.

        Reads messages from standard input in a loop, sends each one through
        process_user_input() and prints the reply. The loop ends when the user
        types "exit", "quit" or "bye" (case-insensitive, surrounding
        whitespace ignored), when standard input reaches end-of-file, or on
        Ctrl-C at the prompt.

        Per-turn status lines ("Turn N: ...") are printed by
        process_user_input() itself, so they are not repeated here.
        """
        print("Mirror (Cognitive Only - No Threads) system initialized! Type 'exit' to quit.")
        print("Note: This ablation variant uses only the Cognitive Controller without inner monologue threads.")
        print("The first response will be immediate, while subsequent responses will benefit from background processing.")
        print("-" * 80)
        sys.stdout.flush()

        while True:
            try:
                user_input = input("You: ")
            except (EOFError, KeyboardInterrupt):
                # Closed/exhausted stdin or Ctrl-C: leave cleanly instead of
                # ending the session with a traceback.
                print("\nGoodbye!")
                sys.stdout.flush()
                break

            if user_input.strip().lower() in ("exit", "quit", "bye"):
                print("Goodbye!")
                sys.stdout.flush()
                break

            # Get AI response first
            response = self.process_user_input(user_input)

            # Show the response immediately
            print("\nAI:", response)
            sys.stdout.flush()

            print("-" * 80)
            sys.stdout.flush()