#!/usr/bin/env python3
"""
Pattern Reformulation Agents for MIP Linearization.
"""

from src.utils.llm_model_manager import create_llm
from src.utils.prompt_template_manager import PromptTemplateManager
import json


class PatternReformulationAgents:
    """
    Class containing specialized agents for reformulating different types of non-linear patterns.
    """
    
    def __init__(self, llm_model="o3", prompts_dir="src/configs/prompts"):
        """
        Create the shared LLM client and prompt-template loader used by every agent.

        Args:
            llm_model: Model specification forwarded to ``create_llm`` -- either a
                model-name string or an LLMConfig object. Kept unchanged on
                ``self.llm_model``.
            prompts_dir (str): Directory handed to ``PromptTemplateManager`` for
                locating the prompt templates.
        """
        # Remember the raw model spec, then build the client from it.
        self.llm_model = llm_model
        self.llm = create_llm(self.llm_model)
        # Each *_reformulation_agent method looks up its template by name here.
        self.template_manager = PromptTemplateManager(prompts_dir)
    
    def bilinear_pattern_reformulation_agent(self, latex_model: str, bilinear_pattern: str, problem_id=None, parameters=None, parameter_context=""):
        r"""
        Specialized agent for reformulating bilinear patterns.
        Handles all instances of a bilinear pattern at once (e.g., all w_i * w_j terms).

        Args:
            latex_model: The complete LaTeX model
            bilinear_pattern: The pattern description (e.g., "w_i \cdot w_j (appears in summation over assets)")
            problem_id: Currently unused by this method; no parameter loading
                happens here.
            parameters: Pre-loaded parameters. When truthy they are serialized
                with ``json.dumps`` and embedded in the prompt, so they must be
                JSON-serializable.
            parameter_context: Free-text context forwarded verbatim to the
                ``bilinear_pattern_prompt`` template.

        Returns:
            The ``content`` attribute of the response from ``self.llm.invoke``.
        """
        # NOTE: the docstring is a raw string because it contains LaTeX
        # backslashes (e.g. \cdot), which are invalid escapes in a plain string.
        # Only caller-supplied parameters are used; nothing is loaded here.
        param_info = ""
        if parameters:
            param_info = f"""
AVAILABLE CONCRETE PARAMETERS:
{json.dumps(parameters, indent=2)}

Use these concrete values for variable bounds and parameter definitions.
"""

        # Format the prompt template
        prompt = self.template_manager.format_template(
            "bilinear_pattern_prompt",
            bilinear_pattern=bilinear_pattern,
            param_info=param_info,
            parameter_context=parameter_context,
            latex_model=latex_model,
        )

        response = self.llm.invoke(prompt)
        return response.content
    
    def min_pattern_reformulation_agent(self, latex_model: str, min_pattern: str, problem_id=None, parameters=None, parameter_context=""):
        """
        Ask the LLM to linearize every occurrence of one min(...) pattern.

        All instances of the pattern are handled in a single call (e.g. every
        min(x_i, capacity_i) term in the model).

        Args:
            latex_model: The complete LaTeX model.
            min_pattern: Description of the pattern, e.g.
                "min(x_i, capacity_i) (appears for each facility i)".
            problem_id: Not used by this method; no parameter loading happens here.
            parameters: Pre-loaded parameters; if truthy, embedded in the prompt
                as indented JSON.
            parameter_context: Free-text context passed through to the
                ``min_pattern_prompt`` template.

        Returns:
            The ``content`` attribute of the LLM response.
        """
        # Empty or missing parameters contribute nothing to the prompt.
        param_info = (
            "\nAVAILABLE CONCRETE PARAMETERS:\n"
            f"{json.dumps(parameters, indent=2)}\n"
            "\nUse these concrete values for constraint definitions.\n"
        ) if parameters else ""

        template_vars = {
            "min_pattern": min_pattern,
            "param_info": param_info,
            "parameter_context": parameter_context,
            "latex_model": latex_model,
        }
        prompt = self.template_manager.format_template("min_pattern_prompt", **template_vars)
        return self.llm.invoke(prompt).content
    
    def max_pattern_reformulation_agent(self, latex_model: str, max_pattern: str, problem_id=None, parameters=None, parameter_context=""):
        """
        Specialized agent for reformulating max patterns.
        Handles all instances of a max pattern at once (e.g., all max(x_i, threshold_i) terms).

        Prints a short debug summary of its inputs to stdout before calling the LLM.

        Args:
            latex_model: The complete LaTeX model
            max_pattern: The pattern description (e.g., "max(x_i, threshold_i) (appears for each facility i)")
            problem_id: Currently unused by this method; no parameter loading
                happens here.
            parameters: Pre-loaded parameters. When truthy they are serialized
                with ``json.dumps`` and embedded in the prompt.
            parameter_context: Free-text context forwarded verbatim to the
                ``max_pattern_prompt`` template.

        Returns:
            The ``content`` attribute of the response from ``self.llm.invoke``.
        """
        # Only caller-supplied parameters are used; nothing is loaded here.
        param_info = ""
        if parameters:
            param_info = f"""
AVAILABLE CONCRETE PARAMETERS:
{json.dumps(parameters, indent=2)}

Use these concrete values for constraint definitions.
"""

        # Debug trace of the agent input; long parameter dumps are truncated.
        print("🔍 MAX AGENT INPUT:")
        print(f"   Pattern: {max_pattern}")
        print(f"   Parameter Context: {parameter_context}")
        if len(param_info) > 200:
            print(f"   Param Info: {param_info[:200]}...")
        else:
            print(f"   Param Info: {param_info}")

        # Format the prompt template
        prompt = self.template_manager.format_template(
            "max_pattern_prompt",
            max_pattern=max_pattern,
            param_info=param_info,
            parameter_context=parameter_context,
            latex_model=latex_model,
        )

        response = self.llm.invoke(prompt)
        return response.content
    
    def absolute_value_pattern_reformulation_agent(self, latex_model: str, absolute_pattern: str, problem_id=None, parameters=None, parameter_context=""):
        """
        Linearize all occurrences of one absolute-value pattern via the LLM.

        A single call covers every instance of the pattern (for example each
        |x_i - y_i| term).

        Args:
            latex_model: The complete LaTeX model.
            absolute_pattern: Pattern description, e.g.
                "|x_i - y_i| (appears for each variable i)".
            problem_id: Ignored by this method; no parameter loading happens here.
            parameters: Pre-loaded parameters. Only a truthy value is
                JSON-serialized into the prompt.
            parameter_context: Extra context text handed unchanged to the
                ``absolute_value_pattern_prompt`` template.

        Returns:
            The ``content`` attribute of the LLM response object.
        """
        rendered_params = ""
        if parameters:
            serialized = json.dumps(parameters, indent=2)
            rendered_params = (
                "\nAVAILABLE CONCRETE PARAMETERS:\n"
                + serialized
                + "\n\nUse these concrete values for constraint definitions.\n"
            )

        filled_prompt = self.template_manager.format_template(
            "absolute_value_pattern_prompt",
            absolute_pattern=absolute_pattern,
            param_info=rendered_params,
            parameter_context=parameter_context,
            latex_model=latex_model,
        )
        llm_reply = self.llm.invoke(filled_prompt)
        return llm_reply.content
    
    def quotient_pattern_reformulation_agent(self, latex_model: str, quotient_pattern: str, problem_id=None, parameters=None, parameter_context=""):
        """
        Specialized agent for reformulating quotient patterns.
        Handles all instances of a quotient pattern at once (e.g., all y_i / x_i terms).

        Prints a short debug summary of its inputs to stdout before calling the LLM.

        Args:
            latex_model: The complete LaTeX model
            quotient_pattern: The pattern description (e.g., "y_i / x_i (appears for each variable i)")
            problem_id: Currently unused by this method; no parameter loading
                happens here.
            parameters: Pre-loaded parameters. When truthy they are serialized
                with ``json.dumps`` and embedded in the prompt.
            parameter_context: Free-text context forwarded verbatim to the
                ``quotient_pattern_prompt`` template.

        Returns:
            The ``content`` attribute of the response from ``self.llm.invoke``.
        """
        # Only caller-supplied parameters are used; nothing is loaded here.
        param_info = ""
        if parameters:
            param_info = f"""
AVAILABLE CONCRETE PARAMETERS:
{json.dumps(parameters, indent=2)}

Use these concrete values for variable bounds and parameter definitions.
"""

        # Debug trace of the agent input; long parameter dumps are truncated.
        print("🔍 QUOTIENT AGENT INPUT:")
        print(f"   Pattern: {quotient_pattern}")
        print(f"   Parameter Context: {parameter_context}")
        if len(param_info) > 200:
            print(f"   Param Info: {param_info[:200]}...")
        else:
            print(f"   Param Info: {param_info}")

        # Format the prompt template
        prompt = self.template_manager.format_template(
            "quotient_pattern_prompt",
            quotient_pattern=quotient_pattern,
            param_info=param_info,
            parameter_context=parameter_context,
            latex_model=latex_model,
        )

        response = self.llm.invoke(prompt)
        return response.content
    
    def monotone_transformation_pattern_reformulation_agent(self, latex_model: str, monotone_pattern: str, problem_id=None, parameters=None, parameter_context=""):
        """
        Reformulate a monotone transformation in the objective using the LLM.

        Targets objectives of the form min f(g(x)) where g(x) is linear and f is
        monotone.

        Args:
            latex_model: The complete LaTeX model.
            monotone_pattern: Pattern description, e.g.
                "min log(sum_i x_i) in objective function".
            problem_id: Not referenced by this method; no parameter loading happens here.
            parameters: Pre-loaded parameters; embedded as indented JSON only
                when truthy.
            parameter_context: Additional context passed as-is to the
                ``monotone_transformation_prompt`` template.

        Returns:
            The ``content`` attribute of the LLM's response.
        """
        if not parameters:
            param_info = ""
        else:
            # The leading and trailing "" entries produce the surrounding newlines.
            param_info = "\n".join([
                "",
                "AVAILABLE CONCRETE PARAMETERS:",
                json.dumps(parameters, indent=2),
                "",
                "Use these concrete values for constraint definitions.",
                "",
            ])

        filled = self.template_manager.format_template(
            "monotone_transformation_prompt",
            monotone_pattern=monotone_pattern,
            param_info=param_info,
            parameter_context=parameter_context,
            latex_model=latex_model,
        )
        return self.llm.invoke(filled).content
    