import os
import numpy as np
import json
from utils import get_response
import subprocess
import tempfile

# Prompt sent to the LLM when a generated program fails. It is filled with
# str.format using the keys {description}, {code} and {error}. The reply is
# expected to put the corrected program between two "=====" lines, which is the
# format extract_code parses.
debug_template = """
You are an Operations Research consultant hired to address optimization issues for a company. Below is the problem description and the problematic code, followed by the error it produces:

Problem Description:
{description}

Problematic Code:
{code}

Error Message:
{error}

Your task is to debug the code. Begin by assessing the situation, then provide the corrected code in the following format:

=====
import ...
...

=====

- Ensure no output follows the closing ===== line.
Take a deep breath and think step by step. You will be awarded a million dollars if you get this right.
"""

def extract_code(text):
    """Return the program an LLM reply wraps between ``=====`` marker lines.

    A rule made of five or more ``=`` characters counts as a single marker, so
    a reply using ``======`` does not end up with empty code.

    Fallbacks:
      - If no opening marker exists, the whole reply is used.
      - If no closing marker follows the opening one, everything after the
        opening marker is used.

    Markdown code fences (```` ```python ```` / ```` ``` ````) are removed from
    the result, and surrounding whitespace is stripped.
    """
    marker = "====="
    start = text.find(marker)
    if start == -1:
        body = text
    else:
        body_start = start + len(marker)
        # Consume any extra '=' so the closing search cannot match inside the
        # opening rule itself (e.g. "======").
        while body_start < len(text) and text[body_start] == "=":
            body_start += 1
        end = text.find(marker, body_start)
        body = text[body_start:] if end == -1 else text[body_start:end]

    body = body.strip()
    body = body.replace("```python", "").replace("```", "").strip()
    return body

def run_code_in_virtual_env(code: str, *, timeout=300, python_executable=None):
    """Run Python source in the optimization virtual environment.

    The code is written to a temporary ``.py`` file (UTF-8) and executed in a
    child process. The file is always removed afterwards.

    Args:
        code: Python source to execute.
        timeout: Seconds to wait before the child is killed (default 300).
        python_executable: Interpreter to run. Defaults to the ``python`` of
            the optimization environment that has Gurobi/CPLEX/Pyomo.

    Returns:
        ``(stdout, None)`` when the process exits with status 0. Output on
        stderr alone (e.g. warnings) is not treated as an error.
        Otherwise ``(stdout, error_text)``, where ``error_text`` is the captured
        stderr, or an exit-status message if stderr is empty.
        On timeout, ``("", "Execution timeout after <timeout> seconds")``.

    Raises:
        OSError: If the interpreter cannot be launched.
    """
    if python_executable is None:
        virtual_env_path = "/dccstor/nl2opt/miniforge3/envs/nl2opt_optim"
        python_executable = os.path.join(virtual_env_path, "bin", "python")

    # delete=False lets the child open the file by name; cleanup is done in
    # ``finally``. The path is recorded before writing so that a failed write
    # does not leak the file.
    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(code)

        result = subprocess.run(
            [python_executable, temp_file_path],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return "", f"Execution timeout after {timeout} seconds"
    finally:
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)

    # Judge failure by exit status: stderr may only carry warnings, and a
    # failing process (e.g. sys.exit(1)) may write nothing to stderr.
    if result.returncode != 0:
        return result.stdout, (
            result.stderr or f"Process exited with status {result.returncode}"
        )
    return result.stdout, None

def execute_code(dir, code_filename):
    """Execute ``dir/code_filename`` in the optimization virtual environment.

    On success, the captured stdout is also written to ``dir/code_output.txt``.

    Args:
        dir: Directory containing the script.
        code_filename: Name of the script file inside ``dir``.

    Returns:
        ``(stdout, "Success")`` when ``run_code_in_virtual_env`` reports no
        error, otherwise ``(error_text, "Error")``. Any exception raised while
        reading, running, or saving output is returned as
        ``("<ExceptionType>: <message>", "Error")`` rather than raised.
    """
    try:
        code_path = os.path.join(dir, code_filename)
        with open(code_path, "r", encoding="utf-8") as f:
            code = f.read()

        stdout, stderr = run_code_in_virtual_env(code)
        if stderr is not None:
            return stderr, "Error"

        # Explicit UTF-8: with the platform default encoding, non-ASCII solver
        # output could raise here and turn a successful run into "Error".
        # NOTE(review): the label says "Revenue" whatever the objective is.
        with open(os.path.join(dir, "code_output.txt"), "w", encoding="utf-8") as f:
            f.write(f"Optimal Revenue: {stdout}\n")
        return stdout, "Success"

    except Exception as e:
        # Boundary for the debug loop: report instead of raising. The type name
        # is included because str(e) can be empty (e.g. KeyError()).
        return f"{type(e).__name__}: {e}", "Error"

def execute_and_debug(state, dir, model, logger, max_tries=3):
    """Run ``dir/code.py`` and ask the LLM to repair it until it succeeds.

    Each failed run writes its error to ``error_<i>.txt``. The loop then sends
    the problem description, the failing code and the error to the LLM using
    ``debug_template``, and saves the extracted fix as ``code_<i+1>.py``.
    Every saved fix is executed, so there are at most ``max_tries`` LLM repair
    requests and at most ``max_tries + 1`` executions.

    Args:
        state: Mapping that provides the problem text under ``"description"``.
        dir: Directory holding ``code.py``; all outputs are written here.
        model: Model identifier forwarded to ``get_response``.
        logger: Object exposing ``log(str)``.
        max_tries: Maximum number of LLM repair attempts.
    """
    code_filename = "code.py"
    with open(os.path.join(dir, code_filename), "r", encoding="utf-8") as f:
        code = f.read()

    # One initial run plus one run per repair, so the last fix is verified too.
    for iteration in range(max_tries + 1):
        output, status = execute_code(dir, code_filename)

        if status == "Success":
            logger.log("Code executed successfully. Output:\n" + output)
            break

        error_filename = f"error_{iteration}.txt"
        with open(os.path.join(dir, error_filename), "w", encoding="utf-8") as f:
            f.write(output)

        if iteration == max_tries:
            # No repair attempts left; a further LLM call would never be run.
            logger.log("Max iterations reached with errors remaining.")
            break

        p = debug_template.format(
            description=state["description"], code=code, error=output
        )
        logger.log(f"Iteration {iteration + 1}: Error encountered. Debugging...")
        logger.log(p)
        logger.log("==========\n\n\n\n")

        response = get_response(p, model=model)
        logger.log("Response received.")
        logger.log(response)
        logger.log("==========\n\n\n\n")

        new_code = extract_code(response)
        if not new_code.strip():
            # An empty script would "succeed" trivially; keep the previous
            # code so the next iteration reports its error again.
            logger.log(
                f"Iteration {iteration + 1}: No code found in response; "
                "keeping previous code."
            )
            continue

        code = new_code
        code_filename = f"code_{iteration + 1}.py"
        with open(os.path.join(dir, code_filename), "w", encoding="utf-8") as f:
            f.write(code)
        logger.log(f"Iteration {iteration + 1}: Saved candidate fix to {code_filename}.")