#!/usr/bin/env python3
"""
MIP Coordinator for managing pattern-based linearization.
"""

import re
from src.agents.pattern_reformulation_agents import PatternReformulationAgents
from src.core.nonlinear_detector import NonLinearPatternExtractor


class MIPCoordinator:
    """
    Enhanced pattern-based coordinator that processes ALL detected non-linear patterns.
    Applies multiple specialized agents and tracks all reformulations with clear marking.

    Pattern groups are processed in the fixed order given by ``_PATTERN_GROUPS``.
    Every successful reformulation replaces the working model, so later agents
    operate on the output of earlier ones; a failed or skipped pattern leaves the
    working model unchanged.
    """

    # Processing order: (key in the parsed-patterns dict, pattern_type).
    # ``pattern_type`` selects both the agent method
    # ``<pattern_type>_pattern_reformulation_agent`` on the reformulation agents
    # object and the coordinator handler ``_process_<pattern_type>_patterns``.
    _PATTERN_GROUPS = (
        ('bilinear_patterns', 'bilinear'),
        ('min_patterns', 'min'),
        ('max_patterns', 'max'),
        ('absolute_patterns', 'absolute_value'),
        ('quotient_patterns', 'quotient'),
        ('monotone_transformation_patterns', 'monotone_transformation'),
    )

    # Display names recorded in the log, the per-pattern marking and the summary.
    _AGENT_NAMES = {
        'bilinear': 'Bilinear Pattern Reformulation Agent',
        'min': 'Min Pattern Reformulation Agent',
        'max': 'Max Pattern Reformulation Agent',
        'absolute_value': 'Absolute Value Pattern Reformulation Agent',
        'quotient': 'Quotient Pattern Reformulation Agent',
        'monotone_transformation': 'Monotone Transformation Pattern Reformulation Agent',
    }

    # Headings that precede the rewritten model in an agent response, tried in order.
    _UPDATED_MODEL_MARKERS = ("UPDATED MODEL:", "## UPDATED MODEL", "UPDATED MODEL")

    # Pattern strings that are markdown/extraction residue rather than real patterns
    # ('' covers whitespace-only patterns after stripping).
    _INVALID_PATTERN_TOKENS = frozenset({'', '```', '```plaintext', 'plaintext', '```latex'})

    _SUMMARY_RULE = "% ==============================================================================="

    def __init__(self, llm_model="o3", verbose=True):
        """
        Initialize the MIP coordinator.

        Args:
            llm_model: Either a string (model name) or LLMConfig object
            verbose (bool): Whether to print detailed progress information
        """
        self.llm_model = llm_model
        self.verbose = verbose
        self.pattern_extractor = NonLinearPatternExtractor(llm_model)
        self.reformulation_agents = PatternReformulationAgents(llm_model)

    def coordinate_reformulation(self, latex_model: str, extracted_patterns_result: str,
                                 problem_id=None, parameters=None, parameter_context=""):
        """
        Main coordination method for pattern-based MIP reformulation.

        Args:
            latex_model: The original LaTeX model
            extracted_patterns_result: Raw result from pattern extraction
            problem_id: Problem identifier for parameter loading
            parameters: Pre-loaded parameters
            parameter_context: Extra free-text context forwarded unchanged to every agent

        Returns:
            str: A summary (statistics, agents applied, per-pattern log) followed by the
            final reformulated model with all reformulations marked, or the message
            "No non-linearities requiring reformulation were identified." when the
            parsed patterns report none.
        """
        parsed_patterns = self.pattern_extractor.parse_patterns(extracted_patterns_result)

        if not parsed_patterns.get('has_nonlinearities'):
            return "No non-linearities requiring reformulation were identified."

        current_model = latex_model
        reformulation_log = []
        reformulation_details = []
        agents_applied = []

        total_patterns = self._count_total_patterns(parsed_patterns)
        self._print_initial_summary(parsed_patterns, total_patterns)

        # Only groups that actually contain patterns get an agent pass. Dispatch goes
        # through the named _process_<type>_patterns methods so overrides still apply.
        for key, pattern_type in self._PATTERN_GROUPS:
            if not self._patterns_of(parsed_patterns, key):
                continue
            handler = getattr(self, f"_process_{pattern_type}_patterns")
            current_model = handler(
                current_model, parsed_patterns, reformulation_log,
                reformulation_details, agents_applied, problem_id, parameters, parameter_context
            )

        return self._create_reformulation_summary(
            total_patterns, reformulation_details, agents_applied,
            reformulation_log, current_model
        )

    @staticmethod
    def _patterns_of(parsed_patterns, key):
        """Return the pattern list stored under ``key``, or [] when missing or None."""
        return parsed_patterns.get(key) or []

    @staticmethod
    def _pattern_label(pattern_type):
        """Human-readable label for a pattern type, e.g. 'absolute_value' -> 'absolute value'."""
        return pattern_type.replace('_', ' ')

    @staticmethod
    def _single_line(text):
        """Collapse every whitespace run (including newlines) in ``text`` to one space."""
        return " ".join(str(text).split())

    def _count_total_patterns(self, parsed_patterns):
        """Count total patterns to process across all pattern groups."""
        return sum(len(self._patterns_of(parsed_patterns, key))
                   for key, _ in self._PATTERN_GROUPS)

    def _print_initial_summary(self, parsed_patterns, total_patterns):
        """Print per-group pattern counts and every detected pattern (verbose mode only)."""
        if not self.verbose:
            return

        print(f"  📊 Found {total_patterns} non-linear patterns to reformulate:")
        for key, pattern_type in self._PATTERN_GROUPS:
            count = len(self._patterns_of(parsed_patterns, key))
            print(f"    - {count} {self._pattern_label(pattern_type)} patterns")

        # Debug: show all detected patterns, for every group
        for key, pattern_type in self._PATTERN_GROUPS:
            patterns = self._patterns_of(parsed_patterns, key)
            if not patterns:
                continue
            print(f"  🔍 {self._pattern_label(pattern_type).capitalize()} patterns detected:")
            for i, pattern in enumerate(patterns, 1):
                print(f"    {i}. '{pattern}'")

        print("  🚀 Applying specialized reformulation agents...")

    def _process_pattern_group(self, key, pattern_type, current_model, parsed_patterns,
                               reformulation_log, reformulation_details, agents_applied,
                               problem_id, parameters, parameter_context):
        """
        Run the agent for ``pattern_type`` over every pattern stored under ``key``.

        Appends one log entry and one details dict per pattern (success, failure or
        skip) and records the agent in ``agents_applied``. Invalid patterns are
        skipped for every group.

        Returns:
            str: The working model after all successful reformulations of this group.
        """
        label = self._pattern_label(pattern_type)
        patterns = self._patterns_of(parsed_patterns, key)

        if self.verbose:
            print(f"\n  🔧 {label.upper()} REFORMULATION AGENT")

        agents_applied.append(self._AGENT_NAMES.get(
            pattern_type, f'{pattern_type.capitalize()} Pattern Reformulation Agent'))

        for i, pattern in enumerate(patterns, 1):
            if self.verbose:
                print(f"    🔄 Processing {label} pattern {i}/{len(patterns)}: {pattern}")

            # Safety check: skip extraction residue instead of sending it to an agent
            if self._is_invalid_pattern(pattern):
                self._handle_invalid_pattern(pattern_type, i, pattern,
                                             reformulation_log, reformulation_details)
                continue

            result = self._process_single_pattern(
                pattern_type, pattern, current_model,
                problem_id, parameters, i, parameter_context
            )
            reformulation_log.append(result['log_entry'])
            reformulation_details.append(result['details'])

            if result['success']:
                current_model = result['updated_model']
                if self.verbose:
                    print(f"      ✅ {label.capitalize()} reformulation applied successfully")
            elif self.verbose:
                print(f"      ❌ Failed to extract updated model from {label} reformulation")

        return current_model

    def _process_bilinear_patterns(self, current_model, parsed_patterns,
                                   reformulation_log, reformulation_details,
                                   agents_applied, problem_id, parameters, parameter_context):
        """Process all bilinear patterns."""
        return self._process_pattern_group(
            'bilinear_patterns', 'bilinear', current_model, parsed_patterns,
            reformulation_log, reformulation_details, agents_applied,
            problem_id, parameters, parameter_context
        )

    def _process_min_patterns(self, current_model, parsed_patterns,
                              reformulation_log, reformulation_details,
                              agents_applied, problem_id, parameters, parameter_context):
        """Process all min patterns."""
        return self._process_pattern_group(
            'min_patterns', 'min', current_model, parsed_patterns,
            reformulation_log, reformulation_details, agents_applied,
            problem_id, parameters, parameter_context
        )

    def _process_max_patterns(self, current_model, parsed_patterns,
                              reformulation_log, reformulation_details,
                              agents_applied, problem_id, parameters, parameter_context):
        """Process all max patterns."""
        return self._process_pattern_group(
            'max_patterns', 'max', current_model, parsed_patterns,
            reformulation_log, reformulation_details, agents_applied,
            problem_id, parameters, parameter_context
        )

    def _process_absolute_value_patterns(self, current_model, parsed_patterns,
                                         reformulation_log, reformulation_details,
                                         agents_applied, problem_id, parameters, parameter_context):
        """Process all absolute value patterns."""
        return self._process_pattern_group(
            'absolute_patterns', 'absolute_value', current_model, parsed_patterns,
            reformulation_log, reformulation_details, agents_applied,
            problem_id, parameters, parameter_context
        )

    def _process_quotient_patterns(self, current_model, parsed_patterns,
                                   reformulation_log, reformulation_details,
                                   agents_applied, problem_id, parameters, parameter_context):
        """Process all quotient patterns."""
        return self._process_pattern_group(
            'quotient_patterns', 'quotient', current_model, parsed_patterns,
            reformulation_log, reformulation_details, agents_applied,
            problem_id, parameters, parameter_context
        )

    def _process_monotone_transformation_patterns(self, current_model, parsed_patterns,
                                                  reformulation_log, reformulation_details,
                                                  agents_applied, problem_id, parameters,
                                                  parameter_context):
        """Process all monotone transformation patterns."""
        return self._process_pattern_group(
            'monotone_transformation_patterns', 'monotone_transformation',
            current_model, parsed_patterns,
            reformulation_log, reformulation_details, agents_applied,
            problem_id, parameters, parameter_context
        )

    def _is_invalid_pattern(self, pattern):
        """Return True for empty/whitespace-only patterns or bare markdown-fence residue."""
        return not pattern or str(pattern).strip() in self._INVALID_PATTERN_TOKENS

    def _handle_invalid_pattern(self, pattern_type, pattern_num, pattern,
                                reformulation_log, reformulation_details):
        """Record (and optionally print) that an invalid pattern was skipped."""
        if self.verbose:
            print(f"      ⚠️ Skipping invalid pattern: '{pattern}'")

        reformulation_log.append(f"⚠️ {pattern_type.capitalize()} pattern {pattern_num}: '{pattern}' -> Skipped (invalid pattern)")
        reformulation_details.append({
            'type': pattern_type,
            'pattern': pattern,
            'technique': 'Skipped',
            'status': 'skipped'
        })

    def _process_single_pattern(self, pattern_type, pattern, current_model,
                                problem_id, parameters, pattern_num, parameter_context=""):
        """
        Run one agent on one pattern and package the outcome.

        Returns:
            dict: ``success`` (bool), ``updated_model`` (the marked new model, or
            ``current_model`` unchanged on failure), ``log_entry`` (str) and
            ``details`` (dict with type/pattern/technique/status).
        """
        try:
            agent_method = getattr(self.reformulation_agents,
                                   f"{pattern_type}_pattern_reformulation_agent")
            reformulation = agent_method(current_model, pattern, problem_id,
                                         parameters, parameter_context)
            updated_model = self._extract_updated_model(reformulation)
        except Exception as agent_error:
            # Deliberate boundary around an LLM-backed agent: one failing pattern is
            # recorded and the remaining patterns are still processed.
            if self.verbose:
                print(f"      ❌ Agent execution failed: {str(agent_error)}")

            return {
                'success': False,
                'updated_model': current_model,
                'log_entry': f"❌ {pattern_type.capitalize()} pattern {pattern_num}: {pattern} -> Agent execution failed: {str(agent_error)}",
                'details': {
                    'type': pattern_type,
                    'pattern': pattern,
                    'technique': 'Agent execution failed',
                    'status': 'failed'
                }
            }

        return {
            'success': True,
            'updated_model': self._add_reformulation_marking(pattern_type, pattern, updated_model),
            'log_entry': f"✅ {pattern_type.capitalize()} pattern {pattern_num}: {pattern} -> Reformulation applied",
            'details': {
                'type': pattern_type,
                'pattern': pattern,
                'technique': self._get_technique_name(pattern_type),
                'status': 'success'
            }
        }

    def _extract_updated_model(self, reformulation):
        """
        Extract the updated model from an agent response.

        Takes everything after the first marker found (markers tried in
        ``_UPDATED_MODEL_MARKERS`` order) and strips a surrounding markdown code fence.

        Raises:
            ValueError: if the response is not a string, contains no marker, or the
                text after the marker is empty (an empty model must never replace
                the working model).
        """
        markers = list(self._UPDATED_MODEL_MARKERS)
        if not isinstance(reformulation, str):
            raise ValueError(f"Agent response is not text (got {type(reformulation).__name__}).")

        for marker in markers:
            _, found, tail = reformulation.partition(marker)
            if not found:
                continue
            # partition keeps everything after the marker; split(marker)[1] would stop
            # at a second occurrence and silently truncate the model.
            updated_model = self._strip_code_fence(tail.strip())
            if not updated_model:
                raise ValueError(f"Found '{marker}' marker but the updated model after it is empty.")
            return updated_model

        raise ValueError(f"Could not find UPDATED MODEL marker in response. Available markers tried: {markers}")

    @staticmethod
    def _strip_code_fence(text):
        """
        Remove a markdown code fence around ``text``.

        If ``text`` opens with a fence (any language tag: ```latex, ```tex, ```),
        the opening line is dropped and everything from the closing fence on is
        discarded (this removes trailing commentary after the block). Otherwise a
        bare trailing ``` is removed.
        """
        if text.startswith('```'):
            newline = text.find('\n')
            text = text[newline + 1:] if newline != -1 else ''
            closing = text.find('```')
            if closing != -1:
                text = text[:closing]
        elif text.endswith('```'):
            text = text[:-3]
        return text.strip()

    def _add_reformulation_marking(self, pattern_type, pattern, updated_model):
        """
        Prefix ``updated_model`` with LaTeX comment lines naming the pattern and agent.

        The pattern text is collapsed to one line so a multi-line pattern cannot
        spill uncommented text into the LaTeX model.
        """
        agent_name = self._AGENT_NAMES.get(
            pattern_type, f'{pattern_type.capitalize()} Pattern Reformulation Agent')

        return (
            f"% {pattern_type.upper()}_PATTERNS REFORMULATION APPLIED: Reformulation applied\n"
            f"% Pattern: {self._single_line(pattern)}\n"
            f"% Agent: {agent_name}\n"
            f"\n"
            f"{updated_model}\n"
        )

    def _get_technique_name(self, pattern_type):
        """Get the technique name for a pattern type."""
        return 'Reformulation applied'

    def _create_reformulation_summary(self, total_patterns, reformulation_details,
                                      agents_applied, reformulation_log, final_model):
        """
        Build the final report: statistics, agents, per-pattern log, and the final
        model prefixed with a LaTeX-comment summary header.
        """
        successful = [d for d in reformulation_details if d.get('status') == 'success']
        failed = [d for d in reformulation_details if d.get('status') == 'failed']
        skipped = [d for d in reformulation_details if d.get('status') == 'skipped']

        if self.verbose:
            print("\n  ✅ Reformulation completed:")
            print(f"    - {len(successful)}/{total_patterns} patterns successfully linearized")
            print(f"    - {len(agents_applied)} specialized agents applied")
            if failed:
                print(f"    - ⚠️ {len(failed)} patterns failed to linearize")
            if skipped:
                print(f"    - ⚠️ {len(skipped)} patterns skipped (invalid pattern text)")

        # Every header line is a complete LaTeX comment line (no doubled "% %" prefixes).
        header_lines = [
            self._SUMMARY_RULE,
            "% MIP REFORMULATION SUMMARY",
            self._SUMMARY_RULE,
            f"% Total Patterns Detected: {total_patterns}",
            f"% Successful Reformulations: {len(successful)}",
            f"% Failed Reformulations: {len(failed)}",
        ]
        if skipped:
            header_lines.append(f"% Skipped Patterns: {len(skipped)}")
        header_lines += ["%", "% AGENTS APPLIED:"]
        header_lines += [f"% - {agent}" for agent in agents_applied]
        header_lines += ["%", "% REFORMULATION TECHNIQUES USED:"]
        header_lines += [
            f'% - {d["type"].upper()}: {d["technique"]} ({self._single_line(d["pattern"])})'
            for d in successful
        ]
        header_lines += [self._SUMMARY_RULE, "", ""]
        final_model_with_header = "\n".join(header_lines) + final_model

        stats_lines = [
            f"- Total patterns detected: {total_patterns}",
            f"- Successful reformulations: {len(successful)}",
            f"- Failed reformulations: {len(failed)}",
        ]
        if skipped:
            stats_lines.append(f"- Skipped patterns: {len(skipped)}")

        sections = [
            "",
            "COMPREHENSIVE PATTERN-BASED MIP REFORMULATION COMPLETED",
            "",
            "REFORMULATION STATISTICS:",
            *stats_lines,
            "",
            "AGENTS APPLIED:",
            *[f"- {agent}" for agent in agents_applied],
            "",
            "DETAILED REFORMULATION LOG:",
            *reformulation_log,
            "",
            "FINAL LINEARIZED MODEL WITH REFORMULATION MARKINGS:",
            final_model_with_header,
        ]
        return "\n".join(sections) + "\n"