#!/usr/bin/env python3
"""
Simple Zero-Shot Gurobi Code Generator and Solver
"""

import os
import re
import time
import tempfile
import subprocess
from openai import OpenAI

class SimpleZeroShotSolver:
    def __init__(self, model_name="llama3-70b", temperature=0.1, log_file=None):
        self.model_name = model_name
        self.temperature = temperature
        self.log_file = log_file
        
        # Setup RITS API client
        self.setup_api_client()
        
        # Gurobi virtual environment path
        self.gurobi_env_path = "/dccstor/nl2opt/miniforge3/envs/nl2opt_optim"
        self.python_executable = os.path.join(self.gurobi_env_path, "bin", "python")
    
    def setup_api_client(self):
        """Setup RITS API client"""
        os.environ['RITS_API_KEY'] = 'RITS_API_PLACEHOLDER'
        api_key = os.environ.get("RITS_API_KEY")
        
        if not api_key:
            raise ValueError("Please set RITS_API_KEY environment variable")
        
        self.client = OpenAI(
            api_key="dummy",
            base_url="API_ENDPOINT_PLACEHOLDER/deepseek-v3-h200/v1",
            default_headers={"RITS_API_KEY": api_key},
            timeout=300
        )
    
    def log(self, message):
        """Log message to file and console"""
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        log_msg = f"[{timestamp}] {message}"
        print(log_msg)
        
        if self.log_file:
            with open(self.log_file, "a", encoding="utf-8") as f:
                f.write(log_msg + "\n")
    
    def call_llm(self, prompt, max_retries=3):
        """Call LLM API with retry logic"""
        for attempt in range(max_retries):
            try:
                response = self.client.chat.completions.create(
                    model="deepseek-ai/DeepSeek-V3",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=4096,
                    temperature=self.temperature,
                    top_p=0.9,
                    frequency_penalty=0.0,
                    presence_penalty=0.0,
                    stream=False
                )
                
                if response.choices and response.choices[0].message.content:
                    return response.choices[0].message.content
                else:
                    self.log(f"WARNING: Empty response on attempt {attempt + 1}")
                    
            except Exception as e:
                self.log(f"WARNING: API call failed on attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                else:
                    raise e
        
        return "No response generated after multiple attempts."
    
    def generate_zero_shot_prompt(self, problem_description):
        """Build the zero-shot instruction prompt around a problem description.

        Args:
            problem_description: Text inserted verbatim under the
                "Problem Description" heading.

        Returns:
            The full prompt string. The doubled braces in the template
            produce a literal ``{optimal_value}`` in the example print line.
        """
        return f"""You are an expert in optimization and Gurobi programming. Given the following optimization problem description, generate complete Python code using Gurobi to solve it.

**Problem Description:**
{problem_description}

**Requirements:**
1. Generate complete, executable Python code using Gurobi (gurobipy)
2. Include all necessary imports
3. Define variables, constraints, and objective function correctly
4. Include model optimization and result printing
5. Print the optimal objective value clearly at the end
6. Handle potential errors gracefully
7. Make sure the code can run independently

**Output Format:**
```python
# Your complete Gurobi Python code here
import gurobipy as gp
from gurobipy import GRB

# Model creation and solving code
# ...

# Print results clearly
print(f"Optimal Objective Value: {{optimal_value}}")
```

**Important Notes:**
- Focus on creating a working solution
- Use appropriate variable types (continuous, integer, binary)
- Include proper constraint formulations
- Ensure the objective function is correctly defined
- Add comments for clarity

Generate the complete Gurobi Python code:"""
    
    def extract_python_code(self, llm_response):
        """Pull the generated Python source out of an LLM reply.

        Strategy:
          1. Return the contents of the first fenced block labelled as
             Python. The label is matched case-insensitively and may be
             ``python``, ``python3``, ``py`` or ``py3``.
          2. Otherwise collect lines starting at the first line that
             mentions both ``import`` and ``gurobi``, stopping at the next
             fence marker so a closing fence of an unlabelled block (and
             the prose after it) does not end up in the code.

        Args:
            llm_response: Raw text returned by the model; may be None or empty.

        Returns:
            The extracted code with surrounding whitespace stripped, or
            None when nothing code-like was found.
        """
        if not llm_response:
            return None

        # Look for labelled code blocks
        labelled_blocks = re.findall(
            r'```(?:python3?|py3?)\s*(.*?)\s*```',
            llm_response,
            re.DOTALL | re.IGNORECASE,
        )
        if labelled_blocks:
            return labelled_blocks[0].strip()

        # Fallback: take everything from the first Gurobi import onwards
        code_lines = []
        in_code = False

        for line in llm_response.split('\n'):
            # A fence after the code has started marks the end of the block.
            if in_code and line.lstrip().startswith('```'):
                break

            if not in_code and 'import' in line and 'gurobi' in line.lower():
                in_code = True

            if in_code:
                code_lines.append(line)

        if code_lines:
            return '\n'.join(code_lines).strip()

        return None
    
    def execute_gurobi_code(self, python_code):
        """Execute Gurobi code in virtual environment"""
        try:
            # Create temporary file
            with tempfile.NamedTemporaryFile(mode='w', suffix=".py", delete=False) as temp_file:
                temp_file.write(python_code)
                temp_file_path = temp_file.name
            
            # Execute code using Gurobi virtual environment
            result = subprocess.run(
                [self.python_executable, temp_file_path],
                capture_output=True,
                text=True,
                timeout=300  # 5 minutes timeout
            )
            
            # Clean up temporary file
            os.unlink(temp_file_path)
            
            if result.returncode == 0:
                return True, result.stdout
            else:
                return False, result.stderr
        
        except subprocess.TimeoutExpired:
            return False, "Execution timeout after 300 seconds"
        except Exception as e:
            return False, f"Execution error: {str(e)}"
    
    def extract_optimal_value(self, execution_output):
        """Parse the objective value from the text printed by the executed code.

        Labels are tried in priority order (case-insensitively), and the
        first one followed by a parseable number wins. Numbers may carry a
        sign, thousands separators, a decimal part and a scientific-notation
        exponent, so a line such as
        ``Best objective 3.400000000000e+01, best bound ...`` yields 34.0
        rather than the truncated mantissa.

        Args:
            execution_output: Captured stdout of the run; may be None or empty.

        Returns:
            The value as a float, or None when no labelled number is found.
        """
        if not execution_output:
            return None

        # Digits with optional thousands separators and fraction (or a bare
        # ".5" form), followed by an optional exponent.
        number = r'([+-]?(?:\d[\d,]*(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)'

        # Most specific label first; the bare "objective" is a last resort.
        patterns = [
            r'Optimal Objective Value:\s*' + number,
            r'Best objective\s*' + number,
            r'Optimal value:\s*' + number,
            r'objective\s*' + number,
        ]

        for pattern in patterns:
            match = re.search(pattern, execution_output, re.IGNORECASE)
            if match:
                value_str = match.group(1).replace(',', '')
                try:
                    return float(value_str)
                except ValueError:
                    continue

        return None
    
    def solve_problem(self, problem_description):
        """Main solving function"""
        result = {
            'success': False,
            'optimal_value': None,
            'generated_code': None,
            'execution_output': None,
            'error_message': None
        }
        
        try:
            self.log("Starting Simple Zero-Shot optimization...")
            
            # Generate zero-shot prompt
            prompt = self.generate_zero_shot_prompt(problem_description)
            self.log(f"Generated prompt ({len(prompt)} chars)")
            
            # Call LLM to generate code
            self.log("Calling LLM for code generation...")
            llm_response = self.call_llm(prompt)
            self.log(f"Received LLM response ({len(llm_response)} chars)")
            
            # Extract Python code
            python_code = self.extract_python_code(llm_response)
            if not python_code:
                result['error_message'] = "Failed to extract Python code from LLM response"
                self.log(f"ERROR: {result['error_message']}")
                return result
            
            result['generated_code'] = python_code
            self.log(f"Extracted Python code ({len(python_code)} chars)")
            
            # Execute Gurobi code
            self.log("Executing Gurobi code...")
            execution_success, execution_output = self.execute_gurobi_code(python_code)
            result['execution_output'] = execution_output
            
            if execution_success:
                self.log("Code execution successful")
                
                # Extract optimal value
                optimal_value = self.extract_optimal_value(execution_output)
                if optimal_value is not None:
                    result['optimal_value'] = optimal_value
                    result['success'] = True
                    self.log(f"SUCCESS: Optimal value = {optimal_value}")
                else:
                    result['error_message'] = "Could not extract optimal value from output"
                    self.log(f"WARNING: {result['error_message']}")
                    # Still consider it a success if code ran without errors
                    result['success'] = True
            else:
                result['error_message'] = f"Code execution failed: {execution_output}"
                self.log(f"ERROR: {result['error_message']}")
            
        except Exception as e:
            result['error_message'] = f"Unexpected error: {str(e)}"
            self.log(f"ERROR: {result['error_message']}")
        
        return result