#!/usr/bin/env python3
"""
Non-linear pattern detection and parsing for optimization models.
"""

from src.utils.llm_model_manager import create_llm
from src.utils.prompt_template_manager import PromptTemplateManager


class NonLinearPatternExtractor:
    """
    Extract and parse non-linear patterns from LaTeX optimization models.

    Extraction is delegated to an LLM through a prompt template; parsing turns
    the LLM's plain-text answer into a dictionary keyed by pattern category.
    """

    # Section headers expected in the LLM answer, mapped to the result key
    # under which the lines following that header are collected.
    _CATEGORY_HEADERS = {
        'BILINEAR_PATTERNS:': 'bilinear_patterns',
        'MIN_PATTERNS:': 'min_patterns',
        'MAX_PATTERNS:': 'max_patterns',
        'ABSOLUTE_PATTERNS:': 'absolute_patterns',
        'QUOTIENT_PATTERNS:': 'quotient_patterns',
        'MONOTONE_TRANSFORMATION_PATTERNS:': 'monotone_transformation_patterns',
        'OTHER_PATTERNS:': 'other_patterns',
    }

    # Lower-cased lines that mean "this category has no patterns".
    _EMPTY_MARKERS = frozenset({'none', '* none', '0', '* 0'})

    def __init__(self, llm_model="o3"):
        """
        Initialize the non-linear pattern extractor.

        Args:
            llm_model: Either a string (model name) or LLMConfig object
        """
        self.llm_model = llm_model
        self.llm = create_llm(llm_model)
        self.template_manager = PromptTemplateManager("src/configs/prompts")

    @staticmethod
    def _preview(text, limit):
        """Return ``text`` cut to ``limit`` characters, with '...' only if it was cut."""
        return f"{text[:limit]}..." if len(text) > limit else text

    def extract_patterns(self, latex_model: str, parameter_context="", param_info="") -> str:
        """
        Extract non-linear term patterns from a LaTeX model and categorize them.
        Groups similar indexed terms into patterns for efficient processing.

        Args:
            latex_model: The LaTeX model to analyze
            parameter_context: Context about parameters (legacy)
            param_info: Concrete parameter values (should be ignored during detection)

        Returns:
            str: The LLM response content with surrounding whitespace stripped
        """
        # Log the (possibly truncated) input for debugging
        print(f"🔍 PATTERN DETECTION INPUT:")
        print(f"   Parameter Context: {parameter_context}")
        print(f"   Param Info: {self._preview(param_info, 200)}")
        print(f"   LaTeX Model: {self._preview(latex_model, 300)}")

        prompt = self.template_manager.format_template(
            "pattern_detection_prompt",
            parameter_context=parameter_context,
            param_info=param_info,
            latex_model=latex_model
        )

        response = self.llm.invoke(prompt)

        return response.content.strip()

    def parse_patterns(self, extraction_result: str) -> dict:
        """
        Parse the pattern extraction result into a structured dictionary.

        Lines following a recognised category header are collected under that
        category until the next header. Lines appearing before any header,
        lines starting with '✅' or '[', markdown fences and "no pattern"
        markers (NONE / 0, optionally bulleted) are ignored.

        Args:
            extraction_result: Raw text returned by ``extract_patterns``

        Returns:
            dict: {
                'has_nonlinearities': bool,
                'bilinear_patterns': list,
                'min_patterns': list,
                'max_patterns': list,
                'absolute_patterns': list,
                'quotient_patterns': list,
                'monotone_transformation_patterns': list,
                'other_patterns': list
            }
        """
        # Every category that a header can select must exist here, otherwise
        # appending a pattern line below would raise KeyError.
        result = {'has_nonlinearities': False}
        for category in self._CATEGORY_HEADERS.values():
            result[category] = []

        # Clean the extraction result from markdown code fences
        cleaned_result = extraction_result.replace('```plaintext', '')
        if cleaned_result.strip().endswith('```'):
            cleaned_result = cleaned_result.strip()[:-3]

        # Check if non-linearities were detected
        if ("✅ NON-LINEARITIES DETECTED: YES" in cleaned_result or
                "✅ NON-LINEARITIES DETECTED: [YES]" in cleaned_result):
            result['has_nonlinearities'] = True

        current_category = None

        for raw_line in cleaned_result.split('\n'):
            line = raw_line.strip()

            # Skip markdown artifacts and empty lines
            if not line or line.startswith('```') or line == 'plaintext':
                continue

            # A header line switches the active category; the rest of the
            # header line itself is not stored.
            matched_category = next(
                (category for header, category in self._CATEGORY_HEADERS.items()
                 if line.startswith(header)),
                None,
            )
            if matched_category is not None:
                current_category = matched_category
                continue

            # Ignore content outside any category, status lines and
            # bracketed lines.
            if current_category is None or line.startswith(('✅', '[')):
                continue

            # "No patterns" markers in their various formats
            if line.lower() in self._EMPTY_MARKERS or line.startswith('* NONE'):
                continue

            # This is a valid pattern under the current category
            result[current_category].append(line)

        return result

    def extract_and_parse(self, latex_model: str, parameter_context="", param_info="") -> dict:
        """
        Extract and parse non-linear patterns in one step.

        Args:
            latex_model: The LaTeX model to analyze
            parameter_context: Context about parameters (legacy)
            param_info: Concrete parameter values (should be ignored during detection)

        Returns:
            dict: Parsed pattern results (see ``parse_patterns``)
        """
        raw_result = self.extract_patterns(latex_model, parameter_context, param_info)
        return self.parse_patterns(raw_result)
    