#!/usr/bin/env python3
"""
Code Converter for converting LaTeX optimization models to Gurobi Python code.
"""

import json
import re

from gurobipy import *
import numpy as np

from src.utils.llm_model_manager import create_llm
from src.utils.prompt_template_manager import PromptTemplateManager


class CodeConverter:
    """
    Class responsible for converting LaTeX mathematical models to executable Gurobi Python code.

    An LLM (built by ``create_llm``) is prompted with the ``code_conversion_prompt``
    template and its response is cleaned into runnable Python. The class can also
    execute generated code and classify execution errors to decide whether the code
    should be regenerated.
    """

    # Gurobi ``Model.Status`` codes mapped to the names reported by execute_code().
    _STATUS_NAMES = {
        1: 'loaded',
        2: 'optimal',
        3: 'infeasible',
        4: 'infeasible_or_unbounded',
        5: 'unbounded',
        6: 'cutoff',
        7: 'iteration_limit',
        8: 'node_limit',
        9: 'time_limit',
        10: 'solution_limit',
        11: 'interrupted',
        12: 'numeric_difficulty',
        13: 'suboptimal',
        14: 'inprogress',
        15: 'user_obj_limit',
        16: 'work_limit',
        17: 'mem_limit',
    }

    # Markers written into the LaTeX model by the reformulation step, paired with the
    # type recorded for them. Tested in order; the first marker found on a line wins.
    _REFORMULATION_MARKERS = (
        ('BILINEAR REFORMULATION APPLIED:', 'bilinear'),
        ('MIN_PATTERNS REFORMULATION APPLIED:', 'min'),
        ('MAX_PATTERNS REFORMULATION APPLIED:', 'max'),
        ('ABSOLUTE VALUE REFORMULATION APPLIED:', 'absolute_value'),
        ('MONOTONE_TRANSFORMATION REFORMULATION APPLIED:', 'monotone_transformation'),
    )

    # Regexes (searched in the lower-cased error text) that point at a defect in the
    # generated code itself, so regenerating the code may help.
    _CODE_ERROR_PATTERNS = (
        r"constraint has no bool value",
        r"name '",
        r"is not defined",
        r"quicksum",
        r"if.*condition",
        r"syntax error",
        r"indentation error",
        r"attributeerror",
        r"typeerror",
        r"indexerror",
        r"keyerror",
    )

    # Substrings pointing at environment problems that regenerating the code cannot fix.
    _NON_CODE_ERROR_PATTERNS = (
        'license',
        'gurobi license',
        'no module named',
        'import error',
        'file not found',
        'permission denied',
        'memory error',
        'timeout',
    )

    # Substrings of Python's IndentationError messages (plus the legacy phrase).
    _INDENTATION_MARKERS = (
        'indentation error',
        'unexpected indent',
        'expected an indented block',
        'unindent does not match',
    )

    # First fenced block (```lang ... ```) embedded in a longer LLM response.
    _FENCED_BLOCK_RE = re.compile(r"```[\w+-]*[ \t]*\n(.*?)```", re.DOTALL)

    # Rule line framing the reformulation header inserted into generated code.
    _HEADER_RULE = "# " + "=" * 79

    def __init__(self, llm_model="o3", prompts_dir="src/configs/prompts"):
        """
        Initialize the code converter.

        Args:
            llm_model: Either a string (model name) or LLMConfig object, passed to create_llm.
            prompts_dir (str): Directory handed to PromptTemplateManager. The default is
                relative, so it is resolved against the current working directory.
        """
        self.llm_model = llm_model
        self.llm = create_llm(llm_model)
        self.template_manager = PromptTemplateManager(prompts_dir)

    def convert_with_parameters(self, latex_model: str, problem_id=None, parameters=None, additional_context="", parameter_context=""):
        """
        Convert LaTeX model to Gurobi Python code with parameter integration.

        Args:
            latex_model (str): LaTeX mathematical model (may include reformulation markings)
            problem_id (str): Problem identifier; currently unused, kept for interface compatibility
            parameters (dict): Pre-loaded parameters (optional). Numpy arrays/scalars and sets
                are converted to JSON lists/values for the prompt.
            additional_context (str): Additional context for code generation (e.g., validation feedback)
            parameter_context (str): Extra parameter description forwarded verbatim to the prompt template

        Returns:
            str: Clean Gurobi Python code, prefixed with gurobipy/numpy imports when the
                response does not start with a gurobipy import, and with a reformulation
                comment header when reformulation markers were found in ``latex_model``.
        """
        reformulation_info = self._extract_reformulation_info(latex_model)
        param_info, sets_info = self._format_parameter_info(parameters)

        reformulation_context = self._build_reformulation_context(reformulation_info)
        if additional_context:
            reformulation_context += f"\n\nADDITIONAL CONTEXT:\n{additional_context}"

        prompt = self.template_manager.format_template(
            "code_conversion_prompt",
            param_info=param_info,
            sets_info=sets_info,
            reformulation_context=reformulation_context,
            parameter_context=parameter_context,
            latex_model=latex_model,
        )
        response = self.llm.invoke(prompt)

        code = self._strip_code_fences(response.content)

        # Ensure the code starts with imports
        if not code.startswith('from gurobipy import'):
            code = 'from gurobipy import *\nimport numpy as np\n\n' + code

        if reformulation_info['has_reformulations']:
            header = self._build_reformulation_header(reformulation_info)
            code = self._insert_after_imports(code, header)

        return code

    @staticmethod
    def _json_default(obj):
        """
        JSON fallback serializer for values json cannot encode natively.

        Converts numpy arrays to lists, numpy scalars to Python scalars and sets to
        lists (sorted when the elements are orderable, so the prompt is stable).

        Raises:
            TypeError: For any other unsupported type, as json.dumps itself would.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, (set, frozenset)):
            try:
                return sorted(obj)
            except TypeError:
                return list(obj)
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @classmethod
    def _format_parameter_info(cls, parameters):
        """
        Build the parameter and sets sections of the prompt.

        Returns:
            tuple[str, str]: (param_info, sets_info); both empty when ``parameters`` is falsy.
        """
        if not parameters:
            return "", ""

        param_info = f"""
AVAILABLE PARAMETERS:
{json.dumps(parameters, indent=2, default=cls._json_default)}

Use these actual parameter values in your code instead of placeholders.
"""

        sets_info = ""
        if 'sets' in parameters:
            parts = ["\nSETS MAPPING:\n"]
            for set_name, set_values in parameters['sets'].items():
                parts.append(f"- Set {set_name}: {set_values}\n")
                parts.append(f"  Use: for {set_name.lower()} in {set_values}\n")
            # Plain (non-f) string inserted as a template *value*: braces are literal,
            # so they must not be doubled.
            parts.append("\nWhen you see LaTeX notation like \\forall_{i \\in SET1} or \\sum_{i \\in SET1}, map it to the concrete sets above.\n")
            sets_info = "".join(parts)

        return param_info, sets_info

    @staticmethod
    def _build_reformulation_context(reformulation_info):
        """Describe detected reformulations for the prompt (or state that there are none)."""
        if not reformulation_info['has_reformulations']:
            return "No reformulations applied to this model."

        bullets = "\n".join(
            f"- {ref['type'].upper()}: {ref['technique']} (Pattern: {ref['pattern']})"
            for ref in reformulation_info['reformulations']
        )
        return f"""
The LaTeX model contains the following MIP reformulations:
{bullets}

IMPORTANT: Include these reformulation details as comments in your Python code to document the linearization techniques applied.
"""

    @classmethod
    def _strip_code_fences(cls, text: str) -> str:
        """
        Remove markdown code fences from an LLM response.

        A response that is itself a fenced block has its opening fence (with any
        language tag such as ``python`` or ``py``) and closing fence removed. A response
        that wraps a fenced block in prose yields the first fenced block's content.
        Unfenced text is returned stripped.
        """
        text = text.strip()
        if not text.startswith("```"):
            match = cls._FENCED_BLOCK_RE.search(text)
            if match:
                return match.group(1).strip()
        # Opening fence together with its language tag; tolerates a missing closing fence.
        text = re.sub(r"\A```[\w+-]*", "", text)
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    @classmethod
    def _build_reformulation_header(cls, reformulation_info):
        """Build the comment header documenting reformulations (and post-processing notes)."""
        rule = cls._HEADER_RULE
        bullets = "\n".join(
            f"# - {ref['type'].upper()}: {ref['technique']} for pattern '{ref['pattern']}'"
            for ref in reformulation_info['reformulations']
        )
        header = (
            f"{rule}\n"
            "# MIP REFORMULATION APPLIED\n"
            f"{rule}\n"
            "# The following reformulation techniques have been applied to linearize non-linear terms:\n"
            f"{bullets}\n"
            f"{rule}\n"
        )
        instructions = reformulation_info.get('postprocessing_instructions')
        if instructions:
            header += "# POST-PROCESSING INSTRUCTIONS FOR MONOTONE TRANSFORMATIONS:\n"
            header += "".join(f"# {instr}\n" for instr in instructions)
        return header

    @staticmethod
    def _insert_after_imports(code, header):
        """
        Insert ``header`` after the leading run of import/blank lines of ``code``.

        Every line up to the first line that is neither blank nor starts with
        ``from ``/``import `` counts as the import section.
        """
        lines = code.split('\n')
        split_at = len(lines)
        for index, line in enumerate(lines):
            if not (line.startswith(('from ', 'import ')) or line.strip() == ''):
                split_at = index
                break
        return '\n'.join(lines[:split_at]) + '\n\n' + header + '\n'.join(lines[split_at:])

    def _extract_reformulation_info(self, latex_model: str):
        """
        Extract reformulation information from LaTeX model comments.

        A line containing ``REFORMULATION APPLIED:`` marks the model as reformulated.
        Lines with a known marker (see _REFORMULATION_MARKERS) are itemized; the text
        after the marker is the technique and a ``Pattern:`` on the next line is the
        pattern. For monotone transformations, the first line among the next five that
        mentions post-processing is recorded as an instruction.

        Args:
            latex_model (str): LaTeX model with potential reformulation markings

        Returns:
            dict: ``has_reformulations`` (bool), ``reformulations`` (list of dicts with
                ``type``/``technique``/``pattern``) and ``postprocessing_instructions`` (list of str).
        """
        reformulations = []
        postprocessing_instructions = []
        has_reformulations = False

        lines = latex_model.split('\n')
        for i, raw_line in enumerate(lines):
            line = raw_line.strip()
            if 'REFORMULATION APPLIED:' not in line:
                continue
            has_reformulations = True

            for marker, ref_type in self._REFORMULATION_MARKERS:
                if marker in line:
                    break
            else:
                continue  # unknown kind: flags the model but is not itemized

            technique = line.split(marker)[1].strip()
            pattern = ""
            if i + 1 < len(lines) and 'Pattern:' in lines[i + 1]:
                pattern = lines[i + 1].split('Pattern:')[1].strip()
            reformulations.append({
                'type': ref_type,
                'technique': technique,
                'pattern': pattern,
            })

            if ref_type == 'monotone_transformation':
                for follow in lines[i + 1:i + 6]:
                    if 'post-processing' in follow.lower():
                        postprocessing_instructions.append(follow.strip())
                        break

        return {
            'has_reformulations': has_reformulations,
            'reformulations': reformulations,
            'postprocessing_instructions': postprocessing_instructions,
        }

    def execute_code(self, generated_code: str, problem_id=None):
        """
        Execute the generated Gurobi code and return the execution result.

        Args:
            generated_code: The generated Gurobi Python code
            problem_id: Problem identifier; currently unused by execution

        Returns:
            dict: Execution result (see _execute_gurobi_code)
        """
        return self._execute_gurobi_code(generated_code, problem_id)

    def _execute_gurobi_code(self, code: str, problem_id: str) -> dict:
        """
        Execute the generated Gurobi code and capture the result.

        Returns:
            dict: ``success``, ``optimization_status``, ``model_status`` and
                ``error_message``. Status details come from a top-level variable named
                ``model`` in the generated code, when it has a ``Status`` attribute.
        """
        try:
            namespace = self._build_execution_namespace()
            # SECURITY NOTE: this runs LLM-generated code in-process with full privileges;
            # there is no sandbox. Only execute output from trusted model pipelines.
            # A single dict serves as globals *and* locals: with separate dicts the code
            # would run with class-body scoping, and generator expressions such as
            # quicksum(x[i] for i in I) or helper functions could not see top-level names.
            exec(code, namespace)
            return self._summarize_execution(namespace)
        except Exception as e:
            return {
                'success': False,
                'error_message': str(e),
                'optimization_status': None,
                'model_status': None,
            }

    @staticmethod
    def _build_execution_namespace():
        """Return a namespace exposing gurobipy's public names, numpy as ``np`` and builtins."""
        import gurobipy  # already a hard dependency via the module-level star import

        public_names = getattr(gurobipy, '__all__', None)
        if public_names is None:
            public_names = [name for name in dir(gurobipy) if not name.startswith('_')]
        namespace = {
            name: getattr(gurobipy, name)
            for name in public_names
            if hasattr(gurobipy, name)
        }
        namespace.update({
            'Model': Model,
            'GRB': GRB,
            'quicksum': quicksum,
            'LinExpr': LinExpr,
            'np': np,
            '__builtins__': __builtins__,
        })
        return namespace

    @classmethod
    def _summarize_execution(cls, namespace):
        """Build the success result from the ``model`` object left in ``namespace`` (if any)."""
        model = namespace.get('model')
        if hasattr(model, 'Status'):
            status = model.Status
            return {
                'success': True,
                'optimization_status': cls._STATUS_NAMES.get(status, 'unknown'),
                'model_status': status,
                'error_message': None,
            }
        return {
            'success': True,
            'optimization_status': 'unknown',
            'model_status': None,
            'error_message': None,
        }

    def _is_code_execution_error(self, error_message: str) -> bool:
        """
        Determine if the error is related to code generation issues that require regeneration.

        Code-error regexes take precedence over environment-error substrings; messages
        matching neither default to requiring regeneration.
        """
        error_lower = error_message.lower()
        if any(re.search(pattern, error_lower) for pattern in self._CODE_ERROR_PATTERNS):
            return True
        if any(pattern in error_lower for pattern in self._NON_CODE_ERROR_PATTERNS):
            return False
        return True  # Default to requiring regeneration for unknown errors

    def _extract_regeneration_reason_from_error(self, error_message: str) -> str:
        """
        Extract the main reason for regeneration from the actual error message.

        Unrecognised messages are quoted (truncated to 100 characters, with ``...``
        only when truncation happened).
        """
        error_lower = error_message.lower()

        if 'constraint has no bool value' in error_lower:
            return "Conditional expressions in quicksum not supported by Gurobi"
        if 'name \'' in error_lower and 'is not defined' in error_lower:
            return "Variable scope and indexing issues"
        if 'quicksum' in error_lower and re.search(r'\bif\b', error_lower):
            return "Conditional quicksum expressions not supported by Gurobi"
        if 'syntax error' in error_lower or 'invalid syntax' in error_lower:
            return "Python syntax errors in generated code"
        if any(marker in error_lower for marker in self._INDENTATION_MARKERS):
            return "Code indentation issues"
        if 'attributeerror' in error_lower:
            return "Object attribute access issues"
        if 'typeerror' in error_lower:
            return "Type mismatch issues"
        if 'indexerror' in error_lower:
            return "Array indexing issues"
        if 'keyerror' in error_lower:
            return "Dictionary key lookup issues"

        suffix = "..." if len(error_message) > 100 else ""
        return f"Code execution error: {error_message[:100]}{suffix}"