#!/usr/bin/env python3

import os
import re
import time
import tempfile
import subprocess
import logging
import json
from pathlib import Path
from config import *
from api_client import call_llama_api

from prompts import (
    load_prompt_template,
    create_solver_code_prompt_with_template,
    create_solver_code_retry_prompt_with_template,
    create_solver_results_analysis_prompt
)

from utils import extract_code_from_response, check_solver_values_consistency

def load_solver_templates():
    """Read the per-solver template scripts used as in-context examples.

    Each template path is resolved relative to ``TEMPLATES_DIR``.

    Returns:
        dict: Maps solver name (``"gurobipy"``, ``"docplex"``, ``"pyomo"``)
        to the template file's text. A template that is missing or cannot
        be read maps to an empty string; the problem is logged, not raised.
    """
    relative_paths = (
        ("gurobipy", "solver_templates/gurobipy_template.py"),
        ("docplex", "solver_templates/docplex_template.py"),
        ("pyomo", "solver_templates/pyomo_template.py"),
    )

    loaded = {}
    for solver_name, rel_path in relative_paths:
        full_path = os.path.join(TEMPLATES_DIR, rel_path)
        # Start from the fallback; it is only replaced after a successful read.
        content = ""
        try:
            with open(full_path, "r", encoding="utf-8") as handle:
                content = handle.read()
            logging.info(f"    Template loaded successfully: {rel_path}")
        except FileNotFoundError:
            logging.warning(f"    WARNING: {rel_path} not found")
        except Exception as exc:
            logging.error(f"    ERROR loading {rel_path}: {exc}")
        loaded[solver_name] = content

    return loaded

def run_code_in_virtual_env(code: str):
    """Run Python source in the optimization virtual environment.

    The code is written to a temporary ``.py`` file and executed with the
    environment's ``bin/python`` in a subprocess limited to
    ``SOLVER_TIMEOUT`` seconds. The temporary file is always removed.

    The environment location is taken from the ``OPTIM_VENV_PATH``
    environment variable; when that is unset or empty, the historical
    default path is used.

    Args:
        code: Python source text to execute.

    Returns:
        tuple: ``(stdout, error)``. ``error`` is ``None`` on success.
        Otherwise it is the captured stderr text, a return-code message when
        the process failed without writing to stderr, or a timeout message
        (in which case ``stdout`` is ``""``).

    Raises:
        OSError: If the interpreter cannot be launched (for example the
            virtual environment does not exist).
    """
    # Honour the environment variable; fall back to the historical location.
    default_env_path = "/dccstor/nl2opt/miniforge3/envs/nl2opt_optim"
    virtual_env_path = os.getenv("OPTIM_VENV_PATH") or default_env_path
    python_executable = os.path.join(virtual_env_path, "bin", "python")

    # delete=False because the file must be closed before the child process
    # reads it; removal is handled in the ``finally`` below.
    temp_file = tempfile.NamedTemporaryFile(suffix=".py", delete=False)
    temp_file_path = temp_file.name

    try:
        # Writing happens inside the try so a failed write cannot leak the file.
        with temp_file:
            temp_file.write(code.encode("utf-8"))

        result = subprocess.run(
            [python_executable, temp_file_path],
            capture_output=True,
            text=True,
            timeout=SOLVER_TIMEOUT
        )
    except subprocess.TimeoutExpired:
        return "", f"Execution timeout after {SOLVER_TIMEOUT} seconds"
    finally:
        try:
            os.remove(temp_file_path)
        except FileNotFoundError:
            pass

    # Any stderr output is reported to the caller as an error.
    if result.stderr:
        return result.stdout, result.stderr
    # A crash or sys.exit(n) can fail without writing to stderr; do not
    # report that as a successful run.
    if result.returncode != 0:
        return result.stdout, f"Process exited with return code {result.returncode}"
    return result.stdout, None

# Matches decision-variable lines such as "x[3] = -1.5e-05". The number part
# accepts an optional sign, a decimal point and an exponent, so values printed
# in scientific notation are not truncated at the "e".
_DECISION_VAR_PATTERN = re.compile(
    r'x\[(\d+)\]\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
)


def _record_decision_variable(line, decision_variables):
    """Store the ``x[i] = value`` pair found in *line* as ``x_i``, if any."""
    var_match = _DECISION_VAR_PATTERN.search(line)
    if var_match:
        decision_variables[f"x_{var_match.group(1)}"] = float(var_match.group(2))


def _parse_solver_stdout(stdout, solver_type, solver_result):
    """Fill status, optimal value and decision variables from solver stdout.

    ``solver_result`` is updated in place. Scanning stops at the first
    infeasible/unbounded line. After an optimal value has been found, the
    remaining lines are scanned only for decision variables, so variables
    printed after the objective are still captured.
    """
    decision_variables = solver_result["decision_variables"]
    optimal_found = False

    for line in stdout.split('\n'):
        is_variable_line = "x[" in line and "=" in line

        if optimal_found:
            # The status is already decided; only collect remaining variables.
            if is_variable_line:
                _record_decision_variable(line, decision_variables)
            continue

        if "Optimal value:" in line:
            try:
                # Extract numerical value after "Optimal value:"
                value_str = line.split("Optimal value:")[-1].strip()
                solver_result["optimal_value"] = float(value_str)
            except ValueError:
                # Not a parseable number; keep looking at later lines.
                continue
            solver_result["status"] = "optimal"
            logging.info(f"    {solver_type.upper()} solved successfully: {solver_result['optimal_value']}")
            optimal_found = True
        elif is_variable_line:
            _record_decision_variable(line, decision_variables)
        elif "infeasible" in line.lower():
            solver_result["status"] = "infeasible"
            solver_result["error_message"] = "Problem is infeasible"
            logging.info(f"    {solver_type.upper()} result: INFEASIBLE")
            break
        elif "unbounded" in line.lower():
            solver_result["status"] = "unbounded"
            solver_result["error_message"] = "Problem is unbounded"
            logging.info(f"    {solver_type.upper()} result: UNBOUNDED")
            break


def execute_solver_code(code, solver_type, db_id):
    """Execute solver code in the optimization environment and capture results.

    Args:
        code: Python source of the generated solver script.
        solver_type: Solver name; used in log messages and in the result.
        db_id: Problem/database identifier (not used inside this function).

    Returns:
        dict: Keys ``solver_type``, ``execution_time``, ``return_code``,
        ``stdout``, ``stderr``, ``status``, ``optimal_value``,
        ``error_message`` and ``decision_variables``. ``status`` is one of
        ``"optimal"``, ``"infeasible"``, ``"unbounded"``,
        ``"solved_no_value"``, ``"timeout"``, ``"error"`` or
        ``"execution_error"`` (the last when running the code itself raised).
    """
    logging.info(f"    Executing {solver_type.upper()} code...")

    try:
        # Execute using the virtual environment function
        start_time = time.perf_counter()
        stdout, stderr = run_code_in_virtual_env(code)
        execution_time = time.perf_counter() - start_time

        solver_result = {
            "solver_type": solver_type,
            "execution_time": execution_time,
            # The runner signals failure through a non-None second element.
            "return_code": 0 if stderr is None else 1,
            "stdout": stdout,
            "stderr": stderr if stderr else "",
            "status": "unknown",
            "optimal_value": None,
            "error_message": None,
            "decision_variables": {}
        }

        if stderr is None:
            # Success - look for optimal value and decision variables
            _parse_solver_stdout(stdout, solver_type, solver_result)

            if solver_result["status"] == "unknown":
                solver_result["status"] = "solved_no_value"
                solver_result["error_message"] = "Solver completed but no optimal value found"
                logging.info(f"    {solver_type.upper()} completed but no optimal value detected")
        else:
            stderr_str = str(stderr)
            if "timeout" in stderr_str.lower():
                solver_result["status"] = "timeout"
                solver_result["error_message"] = f"Execution timeout after {SOLVER_TIMEOUT} seconds"
                logging.info(f"    {solver_type.upper()} execution TIMEOUT after {SOLVER_TIMEOUT} seconds")
            else:
                solver_result["status"] = "error"
                solver_result["error_message"] = stderr_str
                logging.info(f"    {solver_type.upper()} execution failed: {solver_result['error_message'][:100]}...")

        return solver_result

    except Exception as e:
        # Best-effort boundary: report the failure as a result, do not raise.
        logging.error(f"    {solver_type.upper()} execution ERROR: {e}")
        return {
            "solver_type": solver_type,
            "execution_time": 0,
            "return_code": -1,
            "stdout": "",
            "stderr": str(e),
            "status": "execution_error",
            "optimal_value": None,
            "error_message": str(e),
            "decision_variables": {}
        }

def save_solver_execution_log(output_dir, db_id, solver_type, attempt, code, result):
    """Write one solver run to ``<output_dir>/<db_id>/solver_logs/``.

    The file is named ``<solver_type>_attempt_<attempt>.log`` and holds the
    generated code followed by the execution result fields. The directory is
    created if needed and an existing file of the same name is overwritten.

    Args:
        output_dir: Base output directory.
        db_id: Problem/database identifier; used as a sub-directory name.
        solver_type: Solver name; used in the file name and the header.
        attempt: Attempt number; used in the file name and the header.
        code: Source code that was executed.
        result: Result dict; missing keys fall back to placeholder values.
    """
    target_dir = os.path.join(output_dir, db_id, "solver_logs")
    os.makedirs(target_dir, exist_ok=True)
    destination = os.path.join(target_dir, f"{solver_type}_attempt_{attempt}.log")

    # One entry per output line; the final empty entry gives the trailing newline.
    report_lines = [
        "Solver Execution Log",
        f"Solver: {solver_type.upper()}",
        f"Attempt: {attempt}",
        f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}",
        f"Database: {db_id}",
        "",
        "Generated Code:",
        f"{code}",
        "",
        "Execution Results:",
        f"Status: {result.get('status', 'unknown')}",
        f"Return Code: {result.get('return_code', 'unknown')}",
        f"Execution Time: {result.get('execution_time', 0):.3f} seconds",
        f"Optimal Value: {result.get('optimal_value', 'N/A')}",
        "",
        "STDOUT:",
        f"{result.get('stdout', '')}",
        "",
        "STDERR:",
        f"{result.get('stderr', '')}",
        "",
        "Error Message:",
        f"{result.get('error_message', 'None')}",
        "",
        "Decision Variables:",
        json.dumps(result.get('decision_variables', {}), indent=2),
        "",
    ]

    with open(destination, 'w', encoding='utf-8') as log_file:
        log_file.write("\n".join(report_lines))

# Solvers handled by the generation/retry pipeline, in processing order.
_PIPELINE_SOLVER_TYPES = ("gurobipy", "docplex", "pyomo")

# Result statuses that count as a failed solver run when explaining a retry.
_RETRY_ERROR_STATUSES = frozenset(
    ["error", "execution_error", "timeout", "code_generation_failed", "generation_error"]
)


def _summarize_retry_reasons(solver_results):
    """Return human-readable reasons why the current results need a retry.

    Used for logging only. Optimal values are compared using the raw floats
    (each against the first optimal solver) with
    ``SOLVER_CONSISTENCY_TOLERANCE``; the 3-decimal formatting is applied
    only to the text shown in the message.
    """
    error_solvers = []
    optimal_values = []  # (display label, raw value) pairs

    for solver in _PIPELINE_SOLVER_TYPES:
        result = solver_results.get(solver, {})
        status = result.get("status", "unknown")

        if status in _RETRY_ERROR_STATUSES:
            error_solvers.append(f"{solver}({status})")
        elif status == "optimal" and result.get("optimal_value") is not None:
            value = result["optimal_value"]
            optimal_values.append((f"{solver}({value:.3f})", value))

    retry_reasons = []
    if error_solvers:
        retry_reasons.append(f"Errors: {','.join(error_solvers)}")
    if len(optimal_values) > 1:
        base_value = optimal_values[0][1]
        if any(abs(value - base_value) > SOLVER_CONSISTENCY_TOLERANCE
               for _, value in optimal_values[1:]):
            labels = ','.join(label for label, _ in optimal_values)
            retry_reasons.append(f"Inconsistent values: {labels}")

    return retry_reasons


def generate_and_execute_solver_codes_with_templates(client, model, problem_description_text, mathematical_solution, db_id):
    """Generate and execute solver codes with template guidance and retries.

    For each of gurobipy, docplex and pyomo, an LLM is prompted (with the
    solver's template as an example) to produce code, which is executed and
    logged. While ``check_solver_values_consistency`` reports a problem, all
    solvers are regenerated with the current results as context. The loop
    stops once its counter reaches ``MAX_SOLVER_RETRY_ATTEMPTS``, so at most
    ``MAX_SOLVER_RETRY_ATTEMPTS - 1`` regeneration rounds run.

    Args:
        client: LLM client handle passed through to ``call_llama_api``.
        model: Model identifier passed through to ``call_llama_api``.
        problem_description_text: Problem statement used in the prompts.
        mathematical_solution: Mathematical formulation used in the prompts.
        db_id: Problem/database identifier; used in prompts, labels and log
            paths.

    Returns:
        tuple: ``(solver_results, solver_codes)``. Both are dicts keyed by
        solver name. ``solver_results`` holds the latest execution result
        (or a failure record) per solver; ``solver_codes`` holds the latest
        ``raw_response`` and ``extracted_code`` (plus ``retry_attempt`` for
        regenerated code).
    """
    logging.info("    Loading solver templates...")
    templates = load_solver_templates()

    logging.info("    Generating solver codes with template guidance...")

    solver_results = {}
    solver_codes = {}

    # Initial attempt - Generate codes for all available solvers
    for solver_type in _PIPELINE_SOLVER_TYPES:
        try:
            logging.info(f"    Generating {solver_type.upper()} code with template...")

            # Get template code for this solver
            template_code = templates.get(solver_type, "")

            # Generate solver-specific code with template
            solver_prompt = create_solver_code_prompt_with_template(
                problem_description_text, mathematical_solution, solver_type, db_id, template_code
            )

            # Use call_llama_api for code output
            solver_response = call_llama_api(
                client, model, solver_prompt,
                f"{solver_type.upper()} Code Generation with Template - {db_id}"
            )

            # Extract code from response
            code = extract_code_from_response(solver_response)
            solver_codes[solver_type] = {
                "raw_response": solver_response,
                "extracted_code": code
            }

            if not code:
                logging.warning(f"    WARNING: No code extracted for {solver_type.upper()}")
                solver_results[solver_type] = {
                    "solver_type": solver_type,
                    "status": "code_generation_failed",
                    "error_message": "Failed to extract code from LLM response",
                    "decision_variables": {}
                }
                continue

            time.sleep(1)  # Rate limiting

            # Execute the generated code
            result = execute_solver_code(code, solver_type, db_id)
            solver_results[solver_type] = result

            # Save individual solver log - use OUTPUT_BASE_DIR directly
            save_solver_execution_log(OUTPUT_BASE_DIR, db_id, solver_type, 1, code, result)

        except Exception as e:
            # Record the failure so every solver has an entry in both dicts.
            logging.error(f"    ERROR generating/executing {solver_type.upper()} code: {e}")
            solver_results[solver_type] = {
                "solver_type": solver_type,
                "status": "generation_error",
                "error_message": str(e),
                "decision_variables": {}
            }
            solver_codes[solver_type] = {
                "raw_response": "",
                "extracted_code": ""
            }

    # Check for consistency and retry if needed
    for retry_attempt in range(1, MAX_SOLVER_RETRY_ATTEMPTS + 1):
        logging.info(f"    Checking solver consistency (attempt {retry_attempt})...")

        if check_solver_values_consistency(solver_results):
            logging.info("    Solver values are consistent and no errors detected, no retry needed")
            break

        # Detailed retry reason analysis (for the log message only)
        retry_reasons = _summarize_retry_reasons(solver_results)
        logging.info(f"    Retry needed - {'; '.join(retry_reasons) if retry_reasons else 'Unknown issue'}")

        if retry_attempt >= MAX_SOLVER_RETRY_ATTEMPTS:
            logging.info(f"    Maximum retry attempts ({MAX_SOLVER_RETRY_ATTEMPTS}) reached, keeping final results")
            break

        logging.info(f"    Starting retry attempt {retry_attempt}...")

        # Retry all solvers with context of latest round only (not all history).
        # NOTE(review): solver_results is updated in place during this loop, so
        # later solvers see earlier solvers' results from this same round mixed
        # with the previous round's - confirm this is intended.
        for solver_type in _PIPELINE_SOLVER_TYPES:
            try:
                logging.info(f"    Retry {retry_attempt}: Regenerating {solver_type.upper()} code with latest context...")

                # Get template code for this solver
                template_code = templates.get(solver_type, "")

                # Generate solver-specific code with template and current results
                solver_retry_prompt = create_solver_code_retry_prompt_with_template(
                    problem_description_text, mathematical_solution, solver_type, db_id,
                    template_code, solver_results, retry_attempt
                )

                # Use call_llama_api for code output
                solver_response = call_llama_api(
                    client, model, solver_retry_prompt,
                    f"{solver_type.upper()} Code Retry {retry_attempt} - {db_id}"
                )

                # Extract code from response
                code = extract_code_from_response(solver_response)
                solver_codes[solver_type] = {
                    "raw_response": solver_response,
                    "extracted_code": code,
                    "retry_attempt": retry_attempt
                }

                if not code:
                    # The previous result for this solver is kept unchanged.
                    logging.warning(f"    WARNING: No code extracted for {solver_type.upper()} retry {retry_attempt}")
                    continue

                time.sleep(1)  # Rate limiting

                # Execute the regenerated code
                result = execute_solver_code(code, solver_type, db_id)
                result["retry_attempt"] = retry_attempt
                solver_results[solver_type] = result

                # Initial run is logged as attempt 1, so retry N is attempt N + 1.
                save_solver_execution_log(OUTPUT_BASE_DIR, db_id, solver_type, retry_attempt + 1, code, result)

            except Exception as e:
                logging.error(f"    ERROR in retry {retry_attempt} for {solver_type.upper()}: {e}")
                # Every solver already has an entry from the initial attempt.
                solver_results[solver_type]["retry_error"] = str(e)

    return solver_results, solver_codes