import os
import sys
from typing import List, Dict, Any, Tuple, Union, Optional

# Add parent directory to Python path so core module can be found
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import ablation variants
from core.ablations.mirror_ablation_cognitive_only import MirrorCognitiveOnly
from core.ablations.mirror_ablation_threads_only import MirrorThreadsOnly

class MirrorProviderAblation:
    """
    Model provider implementation for MIRROR ablation variants.

    Adapts the MIRROR ablation architectures ("cognitive_only" and
    "threads_only") to the benchmark's provider interface. The benchmark
    passes the full message list; only the most recent user message is
    forwarded to the wrapped ablation object's ``process_user_input``.
    """

    def __init__(self, api_key=None, model=None, temp=0.7, ablation_type="cognitive_only", **kwargs):
        """
        Initialize the MIRROR ablation provider.

        Args:
            api_key: API key for OpenRouter. When falsy, falls back to the
                OPENROUTER_API_KEY environment variable.
            model: Model name to use. A leading "openrouter/" prefix is stripped.
            temp: Temperature for generation (0.7 when None).
            ablation_type: "cognitive_only" or "threads_only".
            **kwargs: Accepted for interface compatibility; currently ignored.

        Raises:
            ValueError: If ``ablation_type`` is not a supported variant.
        """
        self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY")
        # NOTE(review): the temperature is stored but not passed to the ablation
        # variants below (they only receive api_key and model) - confirm intent.
        self.temperature = float(temp) if temp is not None else 0.7
        self.ablation_type = ablation_type

        # Fix model ID format if it has the openrouter/ prefix
        if model and model.startswith("openrouter/"):
            model = model[len("openrouter/"):]
            print(f"Stripped 'openrouter/' prefix from model ID. Using: {model}")

        # Initialize the appropriate ablation variant
        if ablation_type == "cognitive_only":
            print("Initializing MIRROR ablation: Cognitive Controller Only (no threads)")
            self.mirror = MirrorCognitiveOnly(api_key=self.api_key, model=model)
        elif ablation_type == "threads_only":
            print("Initializing MIRROR ablation: Threads Only (no cognitive controller)")
            self.mirror = MirrorThreadsOnly(api_key=self.api_key, model=model)
        else:
            raise ValueError(f"Unknown ablation type: {ablation_type}. Must be 'cognitive_only' or 'threads_only'")

        # True until the first successful call to the ablation variant.
        self.is_first_turn = True

        # NOTE(review): never appended to within this class; only cleared by reset().
        self.conversation_history = []

    @staticmethod
    def _latest_user_message(messages: Optional[List[Dict[str, str]]]) -> str:
        """Return the content of the most recent "user" message, or "" if there is none."""
        # `messages or []` lets a None message list fall through to the
        # "no user message" reply instead of raising TypeError.
        for msg in reversed(messages or []):
            if msg.get("role") == "user":
                # May be None/"" if the message has no content; callers treat that as missing.
                return msg.get("content", "")
        return ""

    def _run_turn(self, messages: Optional[List[Dict[str, str]]]) -> Tuple[str, bool]:
        """
        Run one conversational turn through the ablation variant.

        Returns:
            (response, produced) where ``produced`` is True only when the
            response came from ``self.mirror.process_user_input``. It is False
            for the "no user message" reply and for the error fallback.
        """
        latest_user_message = self._latest_user_message(messages)
        if not latest_user_message:
            return "I didn't receive a user message to respond to.", False

        try:
            response = self.mirror.process_user_input(latest_user_message)
        except Exception as e:
            # Deliberate best-effort boundary: the benchmark gets an apology
            # string instead of an exception, so one failing turn does not
            # abort the run.
            print(f"Error in MIRROR ablation ({self.ablation_type}) response generation: {e}")
            return f"I apologize, but I encountered an error while processing your message: {str(e)}", False

        self.is_first_turn = False
        return response, True

    def generate_response(self, messages: List[Dict[str, str]]) -> str:
        """
        Generate a response using the MIRROR ablation architecture.

        Args:
            messages: List of conversation messages from the benchmark. Only
                the most recent message with role "user" is forwarded.

        Returns:
            The generated response. If there is no user message, or the
            ablation variant raises, a fallback message is returned instead.
        """
        response, _ = self._run_turn(messages)
        return response

    def generate_response_with_metadata(self, messages: List[Dict[str, str]]) -> Union[str, Tuple[str, Optional[float], Optional[bool], Optional[Dict], Optional[str]]]:
        """
        Generate a response with metadata using the MIRROR ablation architecture.

        Args:
            messages: List of conversation messages.

        Returns:
            Tuple of (response, log_prob, monologue_active, monologue_output,
            consolidated_narrative):
              - log_prob is always None (not available in ablation variants).
              - monologue_active is False for "cognitive_only" and True for
                "threads_only".
              - monologue_output is ``self.mirror.last_thread_outputs`` in
                "threads_only" mode, else None.
              - consolidated_narrative is
                ``self.mirror.last_consolidated_narrative`` in "cognitive_only"
                mode, else None.
            The per-turn artifacts (monologue_output and consolidated_narrative)
            are None when this call did not produce a response through the
            ablation variant. In that case those attributes would still describe
            an earlier turn.
        """
        response, produced = self._run_turn(messages)

        def _turn_artifact(name: str) -> Any:
            # Only trust the mirror's "last_*" attributes if this turn ran.
            return getattr(self.mirror, name, None) if produced else None

        # Collect metadata based on ablation type
        if self.ablation_type == "cognitive_only":
            # For cognitive only, we have no monologue
            monologue_active = False
            monologue_output = None
            consolidated_narrative = _turn_artifact('last_consolidated_narrative')
        else:  # threads_only
            # For threads only, we have monologue but no consolidated narrative
            monologue_active = True  # Always active in threads-only mode
            monologue_output = _turn_artifact('last_thread_outputs')
            consolidated_narrative = None

        # Log prob not available in ablation variants
        log_prob = None

        return response, log_prob, monologue_active, monologue_output, consolidated_narrative

    def reset(self):
        """Reset the conversation state (and the wrapped variant's, if it supports it)."""
        if hasattr(self.mirror, 'reset_conversation'):
            self.mirror.reset_conversation()
        self.is_first_turn = True
        self.conversation_history = []