from gurobipy import *
import numpy as np
import json
from datetime import datetime
from pathlib import Path
from src.agents.mip_coordinator import MIPCoordinator
from src.core.nonlinear_detector import NonLinearPatternExtractor
from src.core.code_converter import CodeConverter
from src.core.optimization_executor import OptimizationExecutor

# Utility function to extract the LaTeX model from a LinearizeLLM .tex file

def extract_latex_model_from_linearizellm_tex(tex_path):
    """
    Extract the compact-formulation math block from a LinearizeLLM LaTeX file
    and load parameters and decisions from the JSON files next to it.

    Expected folder structure:
        LinearizeLLM_data/instances_linearizellm/problem_name/
        ├── problem_name.tex
        ├── parameters.json
        └── decisions.json

    Parameter JSON structure:
    - Keys directly correspond to LaTeX parameter names
    - Values carry the data in whatever shape the JSON provides
      (number = scalar, array = vector, nested array = matrix); values are
      stored unchanged regardless of their shape
    - "dims" key is a dictionary mapping parameter names to their dimensions
      * e.g., {"p": ["J"], "c": ["I", "J"]} means p is J-dimensional, c is I×J matrix

    Decisions JSON structure:
    - Keys are base decision variable names
    - Values are arrays of all forms/notations of that variable in the LaTeX model
      * e.g., {"x": ["x", "x_i", "x_{i,j}"]}

    Args:
        tex_path: Path (str or Path) to the ``.tex`` file.

    Returns:
        Tuple ``(latex_model, parameters_dict, decisions_dict)``.
        ``latex_model`` is the stripped text from the first line containing
        ``\\[`` after the section heading up to and including the first line
        containing ``\\]`` (or the end of file if the block is never closed).
        ``parameters_dict`` always contains a "dims" entry (possibly empty)
        when parameters.json was loaded; both dicts are empty when their file
        is missing or could not be loaded.

    Raises:
        ValueError: If the 'Compact formulation with explicit index ranges'
            heading is missing, or no math block follows it.
        OSError: If the LaTeX file itself cannot be read.
    """
    tex_path = Path(tex_path)

    # The JSON files live in the same folder as the .tex file.
    problem_dir = tex_path.parent
    parameters_file = problem_dir / "parameters.json"
    decisions_file = problem_dir / "decisions.json"

    # Load parameters from JSON file (best effort: failures degrade to {}).
    parameters = {}
    if parameters_file.exists():
        try:
            with open(parameters_file, 'r', encoding='utf-8') as f:
                raw_parameters = json.load(f)

            # "dims" is inserted first so it is always present, even when the
            # file does not define it; every other entry is copied unchanged.
            parameters["dims"] = raw_parameters.get("dims", {})
            parameters.update(
                (key, value) for key, value in raw_parameters.items() if key != "dims"
            )

            print(f"📊 Loaded parameters from: {parameters_file}")
            print(f"   Parameter types: {[(k, type(v).__name__) for k, v in parameters.items() if k != 'dims']}")
            print(f"   Dimensions mapping: {parameters['dims']}")

        except Exception as e:
            # Deliberately broad: a malformed side file must not abort the run.
            print(f"⚠️ Failed to load parameters from {parameters_file}: {str(e)}")
            parameters = {}
    else:
        print(f"⚠️ No parameters.json found in {problem_dir}")

    # Load decisions from JSON file (same best-effort policy).
    decisions = {}
    if decisions_file.exists():
        try:
            with open(decisions_file, 'r', encoding='utf-8') as f:
                decisions = json.load(f)

            print(f"📊 Loaded decisions from: {decisions_file}")
            print(f"   Decision variables: {list(decisions.keys())}")
            for var, forms in decisions.items():
                print(f"   {var}: {forms}")

        except Exception as e:
            print(f"⚠️ Failed to load decisions from {decisions_file}: {str(e)}")
            decisions = {}
    else:
        print(f"⚠️ No decisions.json found in {problem_dir}")

    # Read the LaTeX file
    with open(tex_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # Find the start of the compact formulation section
    start_idx = next(
        (i for i, line in enumerate(lines)
         if 'Compact formulation with explicit index ranges' in line),
        None,
    )
    if start_idx is None:
        raise ValueError('Could not find compact formulation section in LaTeX file.')

    # Collect the next display-math block: from the first line containing \[
    # through the first subsequent (or same) line containing \].
    model_lines = []
    in_math = False
    for line in lines[start_idx:]:
        if r'\[' in line:
            in_math = True
        if in_math:
            model_lines.append(line)
            if r'\]' in line:
                break

    # An empty model would silently propagate into every downstream LLM call,
    # so fail here the same way a missing section heading does.
    if not model_lines:
        raise ValueError(
            r'Could not find a \[ ... \] math block after the compact formulation section in LaTeX file.'
        )

    latex_model = ''.join(model_lines).strip()

    return latex_model, parameters, decisions


class LinearizeLLMWorkflow:
    """
    Workflow for processing LinearizeLLM files directly.

    ``run`` drives the pipeline end to end: load the LaTeX model and its JSON
    side files, detect non-linear patterns, linearize them when present,
    generate Gurobi code, execute it, and regenerate the code when execution
    fails with a code error. Intermediate artefacts are written below
    ``results_base_dir`` when ``save_results`` is true.
    """
    def __init__(self, tex_path, problem_id=None, save_results=True, results_base_dir="data/results", llm_model="o3"):
        """
        Initialize the LinearizeLLM workflow.

        Args:
            tex_path: Path to the LaTeX file
            problem_id: Problem identifier (defaults to the file stem of tex_path)
            save_results: Whether to save results to files
            results_base_dir: Base directory for saving results
            llm_model: Either a string (model name) or LLMConfig object
        """
        self.tex_path = tex_path
        self.problem_id = problem_id or Path(tex_path).stem
        self.save_results = save_results
        self.results_base_dir = Path(results_base_dir)
        self.llm_model = llm_model
        self.run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.results_dir = None
        self.results = {}
        self.parameters = {}  # Store extracted parameters
        self.decisions = {}   # Store extracted decisions
        if self.save_results:
            self._setup_results_directory()

    def _setup_results_directory(self):
        """Create the per-run results directory and its models/code/steps subfolders."""
        self.results_dir = self.results_base_dir / f"problem_{self.problem_id}" / f"run_{self.run_timestamp}"
        self.results_dir.mkdir(parents=True, exist_ok=True)
        for subfolder in ("models", "code", "steps"):
            (self.results_dir / subfolder).mkdir(exist_ok=True)
        print(f"📁 Results will be saved to: {self.results_dir}")

    def _save_step_result(self, step_name, content, file_extension="txt"):
        """
        Write one step artefact to the results directory.

        No-op when saving is disabled. Model steps go to ``models/``, the
        final Gurobi code to ``code/``, everything else to ``steps/``. The
        default "txt" extension is upgraded to "tex"/"py" for model/code steps.

        Args:
            step_name: Artefact name; also the file stem.
            content: Text to write.
            file_extension: File extension without the leading dot.
        """
        if not self.save_results or not self.results_dir:
            return
        is_model = step_name in ["latex_model", "linearized_model"]
        is_code = step_name == "gurobi_code"
        if is_model:
            subdir = "models"
        elif is_code:
            subdir = "code"
        else:
            subdir = "steps"
        if file_extension == "txt" and is_model:
            file_extension = "tex"
        if file_extension == "txt" and is_code:
            file_extension = "py"
        file_path = self.results_dir / subdir / f"{step_name}.{file_extension}"
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        print(f"💾 Saved {step_name} to: {file_path}")

    def _build_llm_context(self):
        """
        Build the two prompt fragments derived from the loaded JSON data.

        Returns:
            Tuple ``(param_info, parameter_context)``. ``param_info`` lists the
            concrete parameter values the detector should ignore;
            ``parameter_context`` names parameters and decision variables.
            Both are empty strings when neither parameters nor decisions exist.
        """
        param_info = ""
        parameter_context = ""
        if not (self.parameters or self.decisions):
            return param_info, parameter_context

        # Create parameter info (parameters to be ignored during detection)
        if self.parameters:
            param_info = f"""
AVAILABLE CONCRETE PARAMETERS (IGNORE THESE WHEN DETECTING NONLINEARITIES):
{json.dumps(self.parameters, indent=2)}

IMPORTANT: These are known parameter values. Only look for nonlinearities in DECISION VARIABLES, not in these parameters.
"""

        # Create parameter context with both parameters and decision variables
        param_names = list(self.parameters.keys()) if self.parameters else []
        decision_names = list(self.decisions.keys()) if self.decisions else []

        parameter_context = f"""
PARAMETER CONTEXT FOR VALIDATION:

PARAMETERS (known values): {', '.join(param_names) if param_names else 'None'}

DECISION VARIABLES: {', '.join(decision_names) if decision_names else 'None'}
"""

        # Add decision variable forms if available
        if self.decisions:
            parameter_context += "\nDECISION VARIABLE FORMS:\n"
            for var, forms in self.decisions.items():
                parameter_context += f"- {var}: {', '.join(forms)}\n"

        return param_info, parameter_context

    def _regenerate_until_executable(self, code_converter, optimizer, model_for_code, error_message,
                                     verbose=True, max_regeneration_attempts=3):
        """
        Regenerate the Gurobi code and re-run it until it executes or retries run out.

        Each round asks ``code_converter`` for new code with the latest error
        as context, stores it as the current ``gurobi_code``, and executes that
        regenerated code. The loop stops on success, on an error the converter
        does not classify as a code execution error, or after
        ``max_regeneration_attempts`` rounds.

        Args:
            code_converter: Object providing ``convert_with_parameters``,
                ``_is_code_execution_error`` and
                ``_extract_regeneration_reason_from_error``.
            optimizer: Object providing ``execute(code, problem_id)``.
            model_for_code: Model text the code is generated from.
            error_message: Error reported by the failed execution.
            verbose: Whether to print progress.
            max_regeneration_attempts: Upper bound on regeneration rounds.

        Returns:
            The result dict of the last execution. Also updates
            ``self.results`` ('gurobi_code', 'optimization_results',
            'code_validation' and one raw-output entry per attempt).
        """
        if verbose:
            print(f"\n🔍 STEP 6: Code Validation and Regeneration...")
            print(f"❌ Optimization failed with code error: {error_message}")

        optimization_results = self.results.get('optimization_results')
        regeneration_attempt = 0
        succeeded = False

        while regeneration_attempt < max_regeneration_attempts:
            regeneration_attempt += 1
            regeneration_reason = code_converter._extract_regeneration_reason_from_error(error_message)

            if verbose:
                print(f"🔄 Regenerating code (attempt {regeneration_attempt}/{max_regeneration_attempts})...")

            # Regenerate code with actual error message
            error_context = f"""
PREVIOUS OPTIMIZATION FAILED:
Error: {error_message}

REGENERATION REASON:
{regeneration_reason}

Please fix the code to address this specific error. The code must be executable without runtime errors.
"""
            # NOTE(review): unlike the initial generation, no parameter_context
            # is passed here — confirm whether that omission is intended.
            regenerated_code = code_converter.convert_with_parameters(
                model_for_code,
                self.problem_id,
                self.parameters,
                additional_context=error_context
            )

            # Store the raw regeneration agent output
            attempt_key = f'code_regeneration_agent_raw_output_attempt_{regeneration_attempt}'
            self.results[attempt_key] = regenerated_code
            self._save_step_result(attempt_key, regenerated_code, 'py')

            self.results['gurobi_code'] = regenerated_code
            self._save_step_result('gurobi_code', regenerated_code, 'py')

            if verbose:
                print(f"✅ Code regenerated\nPreview:\n{regenerated_code[:300]}...")
                print(f"🔄 Retrying optimization with regenerated code...")

            # Execute the regenerated code, not the original failing version.
            optimization_results = optimizer.execute(
                regenerated_code,
                self.problem_id
            )
            self.results['optimization_results'] = optimization_results

            if optimization_results.get('success', False):
                succeeded = True
                if verbose:
                    print(f"✅ Optimization succeeded with regenerated code!")
                break

            # Still failed, get new error for next iteration
            error_message = optimization_results.get('error', 'Unknown error')
            if not code_converter._is_code_execution_error(error_message):
                # Non-code error, stop regenerating
                if verbose:
                    print(f"⚠️ Non-code error detected: {error_message}")
                break
        else:
            # Reached only when every allowed attempt failed with a code error.
            if verbose:
                print(f"⚠️ Maximum regeneration attempts ({max_regeneration_attempts}) reached.")

        # Save validation info; only report an error when one is still outstanding.
        summary = f"Code regeneration performed {regeneration_attempt} times due to execution errors."
        if succeeded:
            self.results['code_validation'] = f"{summary} Final run succeeded."
        else:
            self.results['code_validation'] = f"{summary} Final error: {error_message}"

        return optimization_results

    def _save_run_summaries(self, optimization_results):
        """Write optimization_results.json and agent_outputs_summary.json for this run."""
        if not (self.save_results and self.results_dir):
            return

        opt_results_path = self.results_dir / "optimization_results.json"
        with open(opt_results_path, 'w', encoding='utf-8') as f:
            json.dump(optimization_results, f, indent=2)
        print(f"💾 Saved optimization results to: {opt_results_path}")

        # llm_model may be a config object rather than a name; such an object
        # is not JSON serializable, so record its string form instead.
        llm_model = self.llm_model
        if llm_model is not None and not isinstance(llm_model, str):
            llm_model = str(llm_model)

        # Save comprehensive agent outputs summary
        agent_outputs_summary = {
            "detection_agent_raw_output": self.results.get('detection_agent_raw_output', ''),
            "reformulation_agent_raw_output": self.results.get('reformulation_agent_raw_output', ''),
            "code_generation_agent_raw_output": self.results.get('code_generation_agent_raw_output', ''),
            "code_regeneration_attempts": {
                key: value for key, value in self.results.items()
                if key.startswith('code_regeneration_agent_raw_output_attempt_')
            },
            "optimization_results": optimization_results,
            "metadata": {
                "problem_id": self.problem_id,
                "timestamp": datetime.now().isoformat(),
                "llm_model": llm_model
            }
        }

        agent_outputs_path = self.results_dir / "agent_outputs_summary.json"
        with open(agent_outputs_path, 'w', encoding='utf-8') as f:
            json.dump(agent_outputs_summary, f, indent=2, ensure_ascii=False)
        print(f"💾 Saved agent outputs summary to: {agent_outputs_path}")

    def run(self, verbose=True):
        """
        Execute the full workflow for ``self.tex_path``.

        Args:
            verbose: Whether to print step-by-step progress.

        Returns:
            ``self.results``: dict holding the LaTeX model, agent outputs,
            generated code, 'code_validation' note and 'optimization_results'.

        Raises:
            ValueError: Propagated from the LaTeX extraction when the model
                section cannot be located.
        """
        if verbose:
            print("="*80)
            print("LINEARIZELLM LATEX WORKFLOW")
            print("="*80)

        # Step 1: Load and extract LaTeX model, parameters, and decisions
        if verbose:
            print("\n📄 STEP 1: Loading LaTeX model, parameters, and decisions from file...")
        latex_model, parameters, decisions = extract_latex_model_from_linearizellm_tex(self.tex_path)
        self.results['latex_model'] = latex_model
        self.parameters = parameters
        self.decisions = decisions

        # Process parameters for better code generation
        if parameters:
            self._save_step_result('extracted_parameters', json.dumps(parameters, indent=2), 'json')
            if verbose:
                print(f"📊 Parameters extracted: {list(parameters.keys())}")

        self._save_step_result('latex_model', latex_model, 'tex')
        if verbose:
            print(f"✅ LaTeX model extracted\nPreview:\n{latex_model[:300]}...")

        # Step 2: Extract Non-linear Patterns
        if verbose:
            print("\n🔍 STEP 2: Extracting Non-linear Patterns...")
        pattern_extractor = NonLinearPatternExtractor(self.llm_model)
        param_info, parameter_context = self._build_llm_context()

        # Store the LLM input for pattern detection
        pattern_detection_input = {
            "latex_model": latex_model,
            "parameter_context": parameter_context,
            "param_info": param_info
        }
        self._save_step_result('pattern_detection_llm_input', json.dumps(pattern_detection_input, indent=2), 'json')

        # Extract and parse patterns with validation
        parsed_patterns = pattern_extractor.extract_and_parse(latex_model, parameter_context=parameter_context, param_info=param_info)

        # Store the raw detection agent output (we need to get this separately for backward compatibility)
        # NOTE(review): this is a second extractor call on the same model;
        # presumably a second LLM request whose output may differ from the
        # parsed result above — confirm.
        extracted_patterns = pattern_extractor.extract_patterns(latex_model, param_info=param_info)
        self.results['detection_agent_raw_output'] = extracted_patterns
        self._save_step_result('detection_agent_raw_output', extracted_patterns)

        # Store parsed patterns for internal use
        self.results['extracted_patterns'] = extracted_patterns
        self._save_step_result('extracted_patterns', extracted_patterns)

        if verbose:
            print(f"✅ Non-linear pattern extraction completed")
            if parsed_patterns['has_nonlinearities']:
                print(
                    f"   Found: {len(parsed_patterns['bilinear_patterns'])} bilinear, "
                    f"{len(parsed_patterns['min_patterns'])} min, "
                    f"{len(parsed_patterns['max_patterns'])} max, "
                    f"{len(parsed_patterns['absolute_patterns'])} absolute value, "
                    f"{len(parsed_patterns['quotient_patterns'])} quotient patterns, "
                    f"{len(parsed_patterns['monotone_transformation_patterns'])} monotone transformation patterns"
                )
            else:
                print(f"   No non-linearities detected")

        # Step 3: Pattern-based MIP Linearization (if needed)
        if verbose:
            print("\n⚙️ STEP 3: Pattern-based MIP Linearization...")
        if parsed_patterns['has_nonlinearities']:
            coordinator = MIPCoordinator(llm_model=self.llm_model, verbose=True)

            # Store the reformulation input
            reformulation_input = {
                "latex_model": latex_model,
                "extracted_patterns": extracted_patterns,
                "problem_id": self.problem_id,
                "parameters": self.parameters,
                "parameter_context": parameter_context
            }
            self._save_step_result('reformulation_llm_input', json.dumps(reformulation_input, indent=2), 'json')

            # Store the raw reformulation agent output
            reformulation_summary = coordinator.coordinate_reformulation(
                latex_model, extracted_patterns, self.problem_id, self.parameters, parameter_context
            )
            self.results['reformulation_agent_raw_output'] = reformulation_summary
            self._save_step_result('reformulation_agent_raw_output', reformulation_summary, 'tex')

            # Store the processed linearized model
            self.results['linearized_model'] = reformulation_summary
            self._save_step_result('linearized_model', reformulation_summary, 'tex')

            # Extract the final model for code generation: keep only the text
            # after the first recognised marker, else use the whole summary.
            if "FINAL LINEARIZED MODEL WITH REFORMULATION MARKINGS:" in reformulation_summary:
                model_for_code = reformulation_summary.split("FINAL LINEARIZED MODEL WITH REFORMULATION MARKINGS:")[1].strip()
            elif "FINAL LINEARIZED MODEL:" in reformulation_summary:
                model_for_code = reformulation_summary.split("FINAL LINEARIZED MODEL:")[1].strip()
            else:
                model_for_code = reformulation_summary
        else:
            self.results['reformulation_agent_raw_output'] = "No reformulation needed - model is already linear"
            self._save_step_result('reformulation_agent_raw_output', self.results['reformulation_agent_raw_output'])
            self.results['linearized_model'] = "No linearization needed - model is already linear"
            self._save_step_result('linearized_model', self.results['linearized_model'])
            model_for_code = latex_model
            if verbose:
                print(f"✅ No linearization needed - model is already linear")

        # Step 4: Code Generation
        if verbose:
            print("\n💻 STEP 4: Generating Gurobi Code...")
        code_converter = CodeConverter(self.llm_model)

        # Extract reformulation info for potential regeneration
        # NOTE(review): the returned value is not used anywhere in this
        # method; the call is kept in case it has side effects — confirm.
        reformulation_info = code_converter._extract_reformulation_info(model_for_code)

        # Store the code generation input
        code_generation_input = {
            "model_for_code": model_for_code,
            "problem_id": self.problem_id,
            "parameters": self.parameters,
            "parameter_context": parameter_context
        }
        self._save_step_result('code_generation_llm_input', json.dumps(code_generation_input, indent=2), 'json')

        # Store the raw code generation agent output
        gurobi_code = code_converter.convert_with_parameters(
            model_for_code,
            self.problem_id,
            self.parameters,  # Pass extracted parameters
            parameter_context=parameter_context
        )
        self.results['code_generation_agent_raw_output'] = gurobi_code
        self._save_step_result('code_generation_agent_raw_output', gurobi_code, 'py')

        # Store the processed code
        self.results['gurobi_code'] = gurobi_code
        self._save_step_result('gurobi_code', gurobi_code, 'py')
        if verbose:
            print(f"✅ Gurobi code generated\nPreview:\n{gurobi_code[:300]}...")

        # Step 5: Optimization Execution
        if verbose:
            print(f"\n🚀 STEP 5: Executing Optimization...")

        optimizer = OptimizationExecutor(save_error_log=True)
        optimization_results = optimizer.execute(
            gurobi_code,
            self.problem_id
        )
        self.results['optimization_results'] = optimization_results

        # Step 6: Code Validation and Regeneration (only if optimization failed with runtime errors)
        if optimization_results.get('success', False):
            # Optimization succeeded
            self.results['code_validation'] = "Code executed successfully - no validation needed"
            if verbose:
                print(f"✅ Optimization completed successfully")
        else:
            # Check if it's a code execution error that needs regeneration
            error_message = optimization_results.get('error', 'Unknown error')
            if code_converter._is_code_execution_error(error_message):
                optimization_results = self._regenerate_until_executable(
                    code_converter, optimizer, model_for_code, error_message, verbose=verbose
                )
            else:
                # Non-code related error (e.g., Gurobi license, system issues)
                self.results['code_validation'] = f"Non-code optimization error: {error_message}"
                if verbose:
                    print(f"⚠️ Non-code optimization error: {error_message}")

        # Save optimization results and the agent outputs summary as JSON
        self._save_run_summaries(optimization_results)

        if verbose:
            if optimization_results.get('success'):
                opt_res = optimization_results['optimization_results']
                print(f"✅ Optimization completed successfully")
                print(f"Status: {opt_res['status']}")
                if opt_res['objective_value'] is not None:
                    print(f"Objective Value: {opt_res['objective_value']}")
                print(f"Variables: {len(opt_res['variables'])} variables solved")
            else:
                print(f"❌ Optimization failed: {optimization_results.get('error')}")

        return self.results
