#!/usr/bin/env python3
"""
Integrated Two-Stage Optimization Solver with Parallel Attempts and Majority Vote
Enhanced with multiple API keys load balancing and multi-solver support
"""

import os
import sys
import json
import time
import tempfile
import subprocess
import sqlite3
import pandas as pd
import argparse
import re
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
import threading
from collections import Counter
from pathlib import Path
import random
import hashlib
from typing import List, Dict, Tuple, Optional
from openai import OpenAI

# Configuration - RITS API
# The API key can be supplied through the RITS_API_KEY environment variable so that
# it does not have to live in source control; the placeholder is used when unset.
RITS_API_KEY = os.environ.get("RITS_API_KEY", "RITS_API_PLACEHOLDER")
RITS_CONFIG = {
    "base_url": "API_ENDPOINT_PLACEHOLDER/deepseek-v3-h200/v1",
    "model_name": "deepseek-ai/DeepSeek-V3"
}

# Alternative models (commented out, kept for reference)
# RITS_CONFIG = {
#     "base_url": "API_ENDPOINT_PLACEHOLDER/llama-3-3-70b-instruct/v1",
#     "model_name": "meta-llama/Llama-3.3-70B-Instruct"
# }
# RITS_CONFIG = {
#     "base_url": "API_ENDPOINT_PLACEHOLDER/microsoft-phi-4/v1",
#     "model_name": "microsoft/phi-4"
# }

DEFAULT_MODEL = RITS_CONFIG["model_name"]
# Environment whose bin/python runs the generated solver scripts; overridable via
# the GUROBI_ENV_PATH environment variable for machines with a different layout.
GUROBI_ENV_PATH = os.environ.get("GUROBI_ENV_PATH", "/dccstor/nl2opt/miniforge3/envs/nl2opt_optim")

# Solver types, assigned to attempts round-robin
SOLVERS = ['gurobipy', 'docplex', 'pyomo']

class ProgressTracker:
    """Tally finished tasks reported from several threads and print a running total.

    All counter updates and the progress line happen under one lock, so
    concurrent callers never interleave partial updates.
    """

    def __init__(self, total_tasks):
        self.total_tasks = total_tasks
        self.completed = self.successful = self.failed = 0
        self.lock = threading.Lock()

    def update(self, success=True):
        """Record one finished task (a failure when ``success`` is falsy) and print the tally."""
        with self.lock:
            self.completed += 1
            if not success:
                self.failed += 1
            else:
                self.successful += 1

            tally = (f"(Success: {self.successful}, Failed: {self.failed})")
            print(f"[Progress] {self.completed}/{self.total_tasks} completed " + tally)

class IntegratedOptimizationSolver:
    """Main solver class that integrates both stages with parallel attempts and majority vote"""
    
    def __init__(self, model_name=None, num_parallel_attempts=3,
                 sql_temp_base=0.1, sql_temp_increment=0.3,
                 code_temp_base=0.1, code_temp_increment=0.3):
        """
        Configure the attempt count and the per-stage temperature schedules.

        Attempt k (0-based) samples Stage 1 at ``sql_temp_base + k * sql_temp_increment``
        and Stage 2 at ``code_temp_base + k * code_temp_increment``.

        Args:
            model_name: Stored on the instance unchanged. NOTE(review):
                ``_setup_rits_client`` always returns ``RITS_CONFIG["model_name"]``,
                so this value does not currently choose the model used for API calls.
            num_parallel_attempts: How many attempts run concurrently per problem
            sql_temp_base: Stage 1 (SQL generation) temperature for the first attempt
            sql_temp_increment: Added to the Stage 1 temperature for each later attempt
            code_temp_base: Stage 2 (code generation) temperature for the first attempt
            code_temp_increment: Added to the Stage 2 temperature for each later attempt
        """
        self.model_name = model_name
        self.num_parallel_attempts = num_parallel_attempts

        self.sql_temp_base, self.sql_temp_increment = sql_temp_base, sql_temp_increment
        self.code_temp_base, self.code_temp_increment = code_temp_base, code_temp_increment

        # Interpreter used to run the generated solver scripts.
        interpreter_dir = os.path.join(GUROBI_ENV_PATH, "bin")
        self.python_executable = os.path.join(interpreter_dir, "python")
        
    def _setup_rits_client(self) -> Tuple[OpenAI, str]:
        """Build an OpenAI-compatible client for the endpoint in ``RITS_CONFIG``.

        RITS authenticates through the ``RITS_API_KEY`` request header, so the
        SDK's ``api_key`` argument is only a placeholder.

        Returns:
            ``(client, model_name)``; the model name is ``RITS_CONFIG["model_name"]``.
        """
        endpoint = RITS_CONFIG["base_url"]
        auth_headers = {"RITS_API_KEY": RITS_API_KEY}
        client = OpenAI(
            base_url=endpoint,
            api_key="dummy",  # real credential travels in the header
            default_headers=auth_headers,
            timeout=300,
        )
        return client, RITS_CONFIG["model_name"]
    
    def _get_response(self, prompt: str, temperature: float = 0.1, max_retries: int = 3) -> str:
        """Call RITS API with retry logic and load balancing"""
        for attempt in range(max_retries):
            try:
                # Get a new client for each attempt (with different model endpoint)
                client, model_name = self._setup_rits_client()
                
                response = client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=4096,
                    temperature=temperature,
                    top_p=0.9,
                    frequency_penalty=0.0,
                    presence_penalty=0.0,
                    stream=False
                )
                
                if response.choices and response.choices[0].message.content:
                    return response.choices[0].message.content
                else:
                    print(f"WARNING: Empty response on attempt {attempt + 1}")
                    
            except Exception as e:
                print(f"WARNING: API call failed on attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    print(f"Retrying in {2 ** attempt} seconds...")
                    time.sleep(2 ** attempt)
                else:
                    raise e
        
        return "No response generated after multiple attempts."
    
    def _calculate_temperature(self, attempt_num: int, base_temp: float, increment: float) -> float:
        """Calculate temperature for a given attempt number"""
        return base_temp + (attempt_num * increment)
    
    def _get_solver_for_attempt(self, attempt_num: int) -> str:
        """Pick the solver for ``attempt_num`` by cycling through ``SOLVERS`` round-robin."""
        slot = attempt_num % len(SOLVERS)
        return SOLVERS[slot]
    
    def solve_problem_with_parallel_attempts(self, problem_info: Dict, output_dir: Path) -> Dict:
        """Solve problem with multiple parallel attempts and majority vote"""
        database_name = problem_info['database_name']
        print(f"\n{'='*60}")
        print(f"Processing: {database_name} with {self.num_parallel_attempts} parallel attempts")
        print(f"SQL Temperature Range: {self.sql_temp_base:.1f} to {self._calculate_temperature(self.num_parallel_attempts-1, self.sql_temp_base, self.sql_temp_increment):.1f}")
        print(f"Code Temperature Range: {self.code_temp_base:.1f} to {self._calculate_temperature(self.num_parallel_attempts-1, self.code_temp_base, self.code_temp_increment):.1f}")
        print(f"Solvers: {', '.join([self._get_solver_for_attempt(i) for i in range(self.num_parallel_attempts)])}")
        print(f"{'='*60}")
        
        # Create main output directory
        main_output_dir = output_dir / database_name
        main_output_dir.mkdir(parents=True, exist_ok=True)
        
        # Check if already processed
        final_output_file = main_output_dir / "code_output.txt"
        if final_output_file.exists():
            print(f"Already processed: {database_name}")
            return {
                'database_name': database_name,
                'status': 'skipped',
                'reason': 'Already processed'
            }
        
        # Run parallel attempts
        all_attempts = []
        with ThreadPoolExecutor(max_workers=self.num_parallel_attempts) as executor:
            futures = []
            
            for attempt_num in range(self.num_parallel_attempts):
                # Create attempt-specific directory
                attempt_dir = main_output_dir / f"attempt_{attempt_num + 1}"
                attempt_dir.mkdir(parents=True, exist_ok=True)
                
                # Submit parallel attempt
                future = executor.submit(
                    self._single_attempt_pipeline,
                    problem_info,
                    attempt_dir,
                    attempt_num
                )
                futures.append((future, attempt_num))
            
            # Collect results
            for future, attempt_num in futures:
                try:
                    attempt_result = future.result()
                    attempt_result['attempt_number'] = attempt_num + 1
                    all_attempts.append(attempt_result)
                    solver_used = attempt_result.get('solver_used', 'unknown')
                    print(f"  Attempt {attempt_num + 1} ({solver_used}): {'SUCCESS' if attempt_result.get('success') else 'FAILED'}")
                except Exception as e:
                    print(f"  Attempt {attempt_num + 1}: EXCEPTION - {str(e)}")
                    all_attempts.append({
                        'attempt_number': attempt_num + 1,
                        'success': False,
                        'error': str(e)
                    })
        
        # Perform majority vote
        final_result = self._majority_vote_selection(all_attempts, main_output_dir)
        
        # Save final result
        with open(final_output_file, 'w') as f:
            if final_result['success']:
                f.write(f"Optimal Objective Value: {final_result['optimal_value']}\n")
                f.write(f"\nVoting Results:\n")
                f.write(f"Total Attempts: {len(all_attempts)}\n")
                f.write(f"Successful Attempts: {final_result['successful_attempts']}\n")
                f.write(f"Consensus Method: {final_result['consensus_method']}\n")
                f.write(f"Confidence: {final_result['confidence']:.1%}\n")
                f.write(f"Winning Solver: {final_result.get('winning_solver', 'N/A')}\n")
                
                if final_result.get('best_execution_output'):
                    f.write("\n" + "="*50 + "\n")
                    f.write("Full Solver Output from Best Attempt:\n")
                    f.write("="*50 + "\n")
                    f.write(final_result['best_execution_output'])
            else:
                f.write(f"ERROR: Optimization failed across all {len(all_attempts)} attempts\n")
                f.write(f"Reason: {final_result.get('error_message', 'Unknown error')}\n")
        
        # Save temperature configuration
        temp_config = {
            'sql_temperatures': [self._calculate_temperature(i, self.sql_temp_base, self.sql_temp_increment) 
                               for i in range(self.num_parallel_attempts)],
            'code_temperatures': [self._calculate_temperature(i, self.code_temp_base, self.code_temp_increment) 
                                for i in range(self.num_parallel_attempts)],
            'solvers_used': [self._get_solver_for_attempt(i) for i in range(self.num_parallel_attempts)]
        }
        
        # Save summary
        summary = {
            'database_name': database_name,
            'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
            'num_attempts': len(all_attempts),
            'temperature_config': temp_config,
            'attempts_details': all_attempts,
            'final_result': final_result
        }
        
        with open(main_output_dir / "summary.json", 'w') as f:
            json.dump(summary, f, indent=2, default=str)
        
        return summary
    
    def _single_attempt_pipeline(self, problem_info: Dict, attempt_dir: Path, attempt_num: int) -> Dict:
        """Execute a single attempt of the full pipeline"""
        solver_type = self._get_solver_for_attempt(attempt_num)
        result = {
            'success': False,
            'optimal_value': None,
            'attempt_dir': str(attempt_dir),
            'stage1_success': False,
            'stage2_success': False,
            'sql_temperature': self._calculate_temperature(attempt_num, self.sql_temp_base, self.sql_temp_increment),
            'code_temperature': self._calculate_temperature(attempt_num, self.code_temp_base, self.code_temp_increment),
            'solver_used': solver_type
        }
        
        try:
            # Stage 1: SQL Data Retrieval with diversity
            sql_temp = result['sql_temperature']
            print(f"\n[Attempt {attempt_num + 1}] Stage 1: SQL Data Retrieval (temp={sql_temp:.1f})")
            
            enhanced_problem_path = self._stage1_sql_retrieval(
                problem_info, 
                attempt_dir, 
                sql_temp
            )
            
            if enhanced_problem_path:
                result['stage1_success'] = True
                result['enhanced_problem_path'] = str(enhanced_problem_path)
                
                # Stage 2: Zero-Shot Code Generation with specific solver
                code_temp = result['code_temperature']
                print(f"[Attempt {attempt_num + 1}] Stage 2: Code Generation ({solver_type}, temp={code_temp:.1f})")
                
                code_result = self._stage2_code_generation(
                    enhanced_problem_path,
                    attempt_dir,
                    code_temp,
                    solver_type
                )
                
                if code_result['success']:
                    result['stage2_success'] = True
                    result['success'] = True
                    result['optimal_value'] = code_result['optimal_value']
                    result['execution_output'] = code_result.get('execution_output', '')
                    result['generated_code'] = code_result.get('generated_code', '')
                else:
                    result['stage2_error'] = code_result.get('error_message', 'Unknown error')
            else:
                result['stage1_error'] = 'Failed to generate enhanced problem description'
                
        except Exception as e:
            result['error'] = str(e)
            print(f"[Attempt {attempt_num + 1}] ERROR: {str(e)}")
        
        return result
    
    def _stage1_sql_retrieval(self, problem_info: Dict, output_dir: Path, temperature: float) -> Optional[Path]:
        """Stage 1: SQL data retrieval - preserving original logic"""
        try:
            # Extract problem context without stored values
            problem_context = self._extract_problem_context_without_stored_values(
                problem_info['problem_path']
            )
            
            # Read schema
            with open(problem_info['schema_sql_path'], 'r', encoding='utf-8') as f:
                schema_sql = f.read()
            
            # Generate SQL queries with LLM - using original prompt
            llm_response = self._generate_sql_queries_with_llm(
                problem_context, 
                schema_sql, 
                temperature
            )
            
            if not llm_response:
                return None
            
            # Save LLM response
            with open(output_dir / "stage1_llm_sql_generation.txt", "w", encoding="utf-8") as f:
                f.write(f"=== Problem Context ===\n{problem_context}\n\n")
                f.write(f"=== Schema ===\n{schema_sql}\n\n")
                f.write(f"=== Temperature ===\n{temperature}\n\n")
                f.write(f"=== LLM Response ===\n{llm_response}\n")
            
            # Extract queries
            queries = self._extract_sql_queries_from_response(llm_response)
            print(f"    Generated {len(queries)} SQL queries")
            
            if not queries:
                return None
            
            # Create database and execute queries
            conn = self._create_database_from_schema(
                problem_info['schema_sql_path'],
                problem_info['data_sql_path']
            )
            
            if not conn:
                return None
            
            query_results = self._execute_sql_queries(conn, queries)
            conn.close()
            
            # Create enhanced problem description
            enhanced_content = self._create_enhanced_problem_description(
                problem_info['problem_path'],
                query_results
            )
            
            # Save enhanced problem description
            enhanced_path = output_dir / "enhanced_problem_description.md"
            with open(enhanced_path, "w", encoding="utf-8") as f:
                f.write(enhanced_content)
            
            # Save stage 1 results
            stage1_results = {
                "queries_generated": len(queries),
                "queries_executed": len(query_results),
                "temperature": temperature,
                "query_results": [
                    {
                        "comment": r['comment'],
                        "query": r['query'],
                        "rows_returned": len(r['result_df']) if 'result_df' in r else 0,
                        "error": r.get('error')
                    }
                    for r in query_results
                ]
            }
            
            with open(output_dir / "stage1_results.json", "w", encoding="utf-8") as f:
                json.dump(stage1_results, f, indent=2)
            
            return enhanced_path
            
        except Exception as e:
            print(f"    Stage 1 Error: {e}")
            return None
    
    def _extract_problem_context_without_stored_values(self, problem_description_path: str) -> str:
        """Extract problem description without 'Current Stored Values' section - original function"""
        with open(problem_description_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        lines = content.split('\n')
        filtered_lines = []
        skip_section = False
        
        for line in lines:
            if '### Current Stored Values' in line or '### current stored values' in line.lower():
                skip_section = True
                continue
            
            if skip_section and line.startswith('### ') and 'current stored values' not in line.lower():
                skip_section = False
            
            if not skip_section:
                filtered_lines.append(line)
        
        return '\n'.join(filtered_lines)
    
    def _generate_sql_queries_with_llm(self, problem_context: str, schema_sql: str, temperature: float) -> str:
        """Generate SQL queries with LLM - preserving original prompt"""
        prompt = f"""You are an expert database analyst helping with optimization problem data retrieval. Based on the problem description and database schema, analyze what data would be most useful for solving this optimization problem and generate appropriate SQL SELECT queries.

**Problem Context (without stored values):**
{problem_context}

**Database Schema:**
{schema_sql}

**Your Task:** 
Carefully analyze the optimization problem and determine what data from the database would be most relevant. Then generate SQL SELECT queries to retrieve this data.

**Analysis Guidelines:**
- Identify what data is needed for decision variables (what needs to be optimized)
- Identify what data is needed for objective function coefficients (what to maximize/minimize)
- Identify what data is needed for constraint parameters (limitations and requirements)
- Consider what aggregated data, summary statistics, or lookup information might be helpful
- Think about relationships between tables and potential joins
- Consider filtering criteria that would make the data more relevant

**Query Requirements:**
- Generate as many queries as you think are necessary and useful (no fixed number)
- Each query should have a clear purpose for the optimization problem
- Include meaningful comments explaining what each query retrieves and why it's relevant
- Use proper SQL syntax with table names, column names, and joins as needed
- Consider both detailed data and summary/aggregated data where appropriate

**Output Format:**
```sql
-- Query Description: Explain what this retrieves and why it's important for optimization
SELECT ... FROM ... WHERE ... ;

-- Query Description: Explain what this retrieves and why it's important for optimization  
SELECT ... FROM ... WHERE ... ;

-- Continue with additional queries as needed
```

Analyze the problem and generate the most relevant SQL queries:"""
        
        try:
            response = self._get_response(prompt, temperature)
            return response
        except Exception as e:
            print(f"    ERROR generating SQL queries: {e}")
            return None
    
    def _extract_sql_queries_from_response(self, llm_response: str) -> List[Dict]:
        """Extract SQL queries from LLM response - original function"""
        queries = []
        
        sql_blocks = re.findall(r'```sql\s*(.*?)\s*```', llm_response, re.DOTALL)
        
        if sql_blocks:
            sql_content = sql_blocks[0]
        else:
            sql_content = llm_response
        
        lines = sql_content.split('\n')
        current_query = []
        current_comment = ""
        
        for line in lines:
            line = line.strip()
            
            if line.startswith('--'):
                if current_query:
                    query_text = ' '.join(current_query).strip()
                    if query_text.upper().startswith('SELECT') and query_text.endswith(';'):
                        queries.append({
                            'comment': current_comment,
                            'query': query_text
                        })
                    current_query = []
                
                current_comment = line[2:].strip()
            
            elif line.upper().startswith('SELECT'):
                current_query = [line]
            
            elif current_query and line:
                current_query.append(line)
                
                if line.endswith(';'):
                    query_text = ' '.join(current_query).strip()
                    queries.append({
                        'comment': current_comment,
                        'query': query_text
                    })
                    current_query = []
        
        if current_query:
            query_text = ' '.join(current_query).strip()
            if query_text.upper().startswith('SELECT'):
                if not query_text.endswith(';'):
                    query_text += ';'
                queries.append({
                    'comment': current_comment,
                    'query': query_text
                })
        
        return queries
    
    def _create_database_from_schema(self, schema_sql_path: str, data_sql_path: str) -> Optional[sqlite3.Connection]:
        """Create in-memory SQLite database - original function"""
        with open(schema_sql_path, 'r', encoding='utf-8') as f:
            schema_sql = f.read()
        
        with open(data_sql_path, 'r', encoding='utf-8') as f:
            data_sql = f.read()
        
        conn = sqlite3.connect(':memory:')
        cursor = conn.cursor()
        
        try:
            schema_statements = [stmt.strip() for stmt in schema_sql.split(';') if stmt.strip()]
            for stmt in schema_statements:
                if stmt.strip():
                    cursor.execute(stmt)
            
            data_statements = [stmt.strip() for stmt in data_sql.split(';') if stmt.strip()]
            for stmt in data_statements:
                if stmt.strip():
                    cursor.execute(stmt)
            
            conn.commit()
            return conn
            
        except Exception as e:
            print(f"    ERROR creating database: {e}")
            conn.close()
            return None
    
    def _execute_sql_queries(self, conn: sqlite3.Connection, queries: List[Dict]) -> List[Dict]:
        """Execute SQL queries - original function"""
        results = []
        
        for i, query_info in enumerate(queries):
            comment = query_info['comment']
            query = query_info['query']
            
            try:
                df = pd.read_sql_query(query, conn)
                
                results.append({
                    'comment': comment,
                    'query': query,
                    'result_df': df,
                    'result_csv': df.to_csv(index=False)
                })
                
                print(f"    Query {i+1}: {comment} -> {len(df)} rows")
                
            except Exception as e:
                print(f"    ERROR executing query {i+1}: {e}")
                print(f"    Query: {query}")
                
                results.append({
                    'comment': comment,
                    'query': query,
                    'result_df': pd.DataFrame(),
                    'result_csv': "",
                    'error': str(e)
                })
        
        return results
    
    def _create_enhanced_problem_description(self, original_problem_path: str, 
                                           query_results: List[Dict]) -> str:
        """Create enhanced problem description - original function"""
        with open(original_problem_path, 'r', encoding='utf-8') as f:
            original_content = f.read()
        
        enhanced_content = self._extract_problem_context_without_stored_values(original_problem_path)
        
        enhanced_content += "\n\n### Retrieved Values\n\n"
        
        for i, result in enumerate(query_results):
            comment = result['comment']
            query = result['query']
            csv_data = result['result_csv']
            
            enhanced_content += f"**Query {i+1}: {comment}**\n\n"
            enhanced_content += f"```sql\n{query}\n```\n\n"
            
            if csv_data and not result.get('error'):
                enhanced_content += f"**Results (CSV format):**\n```csv\n{csv_data}```\n\n"
            else:
                error_msg = result.get('error', 'No data returned')
                enhanced_content += f"**Error:** {error_msg}\n\n"
        
        return enhanced_content
    
    def _stage2_code_generation(self, enhanced_problem_path: Path, output_dir: Path, 
                               temperature: float, solver_type: str) -> Dict:
        """Stage 2: Code generation with specific solver"""
        result = {
            'success': False,
            'optimal_value': None,
            'generated_code': None,
            'execution_output': None,
            'error_message': None,
            'temperature': temperature,
            'solver_type': solver_type
        }
        
        try:
            # Read enhanced problem description
            with open(enhanced_problem_path, 'r', encoding='utf-8') as f:
                enhanced_problem_text = f.read()
            
            # Generate solver-specific prompt
            prompt = self._generate_solver_specific_prompt(enhanced_problem_text, solver_type)
            
            # Call LLM
            llm_response = self._get_response(prompt, temperature)
            
            # Extract Python code
            python_code = self._extract_python_code(llm_response)
            if not python_code:
                result['error_message'] = "Failed to extract Python code from LLM response"
                return result
            
            result['generated_code'] = python_code
            
            # Save generated code
            with open(output_dir / f"generated_{solver_type}_code.py", "w", encoding="utf-8") as f:
                f.write(python_code)
            
            # Execute solver code
            execution_success, execution_output = self._execute_solver_code(python_code, solver_type)
            result['execution_output'] = execution_output
            
            # Save execution output
            with open(output_dir / f"{solver_type}_execution_output.txt", "w", encoding="utf-8") as f:
                f.write(execution_output)
            
            if execution_success:
                optimal_value = self._extract_optimal_value(execution_output)
                if optimal_value is not None:
                    result['optimal_value'] = optimal_value
                    result['success'] = True
                    print(f"    Successfully extracted optimal value: {optimal_value}")
                else:
                    result['error_message'] = "Could not extract optimal value from output"
                    print(f"    WARNING: Failed to extract optimal value from {solver_type} output")
                    # Save problematic output for debugging
                    with open(output_dir / f"{solver_type}_extraction_failed.txt", "w") as f:
                        f.write("Failed to extract optimal value from:\n\n")
                        f.write(execution_output)
            else:
                result['error_message'] = f"Code execution failed: {execution_output}"
            
            # Save stage 2 results
            stage2_results = {
                "success": result['success'],
                "optimal_value": result['optimal_value'],
                "temperature": temperature,
                "solver_type": solver_type,
                "error_message": result.get('error_message')
            }
            
            with open(output_dir / "stage2_zero_shot_results.json", "w", encoding="utf-8") as f:
                json.dump(stage2_results, f, indent=2)
            
        except Exception as e:
            result['error_message'] = f"Unexpected error: {str(e)}"
        
        return result
    
    def _generate_solver_specific_prompt(self, problem_description: str, solver_type: str) -> str:
        """Generate solver-specific prompt with concise template info"""
        
        if solver_type == 'gurobipy':
            solver_info = """**Using Gurobipy:**
- Import: `import gurobipy as gp; from gurobipy import GRB`
- Model: `model = gp.Model("name")`
- Variables: `x = model.addVar(vtype=GRB.CONTINUOUS, name="x", lb=0)`
- Objective: `model.setObjective(gp.quicksum(...), GRB.MINIMIZE)`
- Constraints: `model.addConstr(expr <= rhs, name="c1")`
- Solve: `model.optimize()`
- Result: `if model.status == GRB.OPTIMAL: print(f"Optimal value: {model.objVal}")`
- CRITICAL: Use gp.quicksum() not sum(), validate array lengths"""
            
        elif solver_type == 'docplex':
            solver_info = """**Using DOCplex:**
- Import: `from docplex.mp.model import Model`
- Model: `mdl = Model(name="name")`
- Variables: `x = {i: mdl.continuous_var(name=f"x_{i}", lb=0) for i in range(n)}`
- Objective: `mdl.minimize(mdl.sum(...))`
- Constraints: `mdl.add_constraint(expr <= rhs, ctname="c1")`
- Solve: `solution = mdl.solve()`
- Result: `if solution: print(f"Optimal value: {solution.objective_value}")`
- CRITICAL: Use mdl.sum() not sum(), use safe_range for array indexing"""
            
        elif solver_type == 'pyomo':
            solver_info = """**Using Pyomo with Gurobi:**
- Import: `import pyomo.environ as pyo; from pyomo.opt import SolverFactory`
- Model: `model = pyo.ConcreteModel()`
- Sets: `model.I = pyo.RangeSet(1, n)` (1-based)
- Variables: `model.x = pyo.Var(model.I, within=pyo.NonNegativeReals)`
- Objective: `model.objective = pyo.Objective(rule=obj_rule, sense=pyo.minimize)`
- Constraints: `model.constraint = pyo.Constraint(rule=constraint_rule)`
- Solve: `solver = SolverFactory('gurobi'); results = solver.solve(model)`
- Result: `if results.solver.termination_condition == pyo.TerminationCondition.optimal: print(f"Optimal value: {pyo.value(model.objective)}")`
- CRITICAL: Use rule functions, 1-based indexing, pyo.value() for extraction"""
        
        prompt = f"""You are an expert in optimization programming. Given the following optimization problem description, generate complete Python code using {solver_type} to solve it.

**Problem Description:**
{problem_description}

{solver_info}

**Requirements:**
1. Generate complete, executable Python code
2. Include all necessary imports
3. Define variables, constraints, and objective function correctly
4. Include model optimization and result printing
5. Print the optimal objective value clearly: "Optimal Objective Value: {{value}}"
6. Handle potential errors gracefully
7. Validate array lengths before using indices
8. Make sure the code can run independently

**Output Format:**
```python
# Your complete {solver_type} Python code here
# ... imports and implementation ...
# Must print: Optimal Objective Value: {{value}}
```

Generate the complete {solver_type} Python code:"""

        return prompt
    
    def _extract_python_code(self, llm_response: str) -> Optional[str]:
        """Extract Python code from LLM response - original function"""
        code_blocks = re.findall(r'```python\s*(.*?)\s*```', llm_response, re.DOTALL)
        
        if code_blocks:
            return code_blocks[0].strip()
        
        # Fallback
        lines = llm_response.split('\n')
        code_lines = []
        in_code = False
        
        for line in lines:
            if 'import' in line and any(solver in line.lower() for solver in ['gurobi', 'gurobipy', 'docplex', 'pyomo']):
                in_code = True
            
            if in_code:
                code_lines.append(line)
        
        if code_lines:
            return '\n'.join(code_lines).strip()
        
        return None
    
    def _execute_solver_code(self, python_code: str, solver_type: str) -> Tuple[bool, str]:
        """Execute solver code - handles different solver types"""
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix=".py", delete=False) as temp_file:
                temp_file.write(python_code)
                temp_file_path = temp_file.name
            
            result = subprocess.run(
                [self.python_executable, temp_file_path],
                capture_output=True,
                text=True,
                timeout=300
            )
            
            os.unlink(temp_file_path)
            
            if result.returncode == 0:
                return True, result.stdout
            else:
                return False, result.stderr
                
        except subprocess.TimeoutExpired:
            return False, "Execution timeout after 300 seconds"
        except Exception as e:
            return False, f"Execution error: {str(e)}"
    
    def _extract_optimal_value(self, execution_output: str) -> Optional[float]:
        """Extract optimal value - enhanced for different solver outputs with LLM fallback"""
        # First try pattern matching
        patterns = [
            r'Optimal Objective Value:\s*([+-]?[\d.,]+)',
            r'Best objective\s*([+-]?[\d.,]+)',
            r'Optimal value:\s*([+-]?[\d.,]+)',
            r'objective\s*([+-]?[\d.,]+)',
            r'Objective value:\s*([+-]?[\d.,]+)',
            r'Solution value =\s*([+-]?[\d.,]+)',  # DOCplex
            r'objective value:\s*([+-]?[\d.,]+)',   # Pyomo
            r'Obj:\s*([+-]?[\d.,]+)',              # Gurobi short form
            r'Optimal solution found.*?([+-]?[\d.,]+)',  # Generic
        ]
        
        for pattern in patterns:
            match = re.search(pattern, execution_output, re.IGNORECASE)
            if match:
                value_str = match.group(1).replace(',', '')
                try:
                    return float(value_str)
                except ValueError:
                    continue
        
        # If pattern matching fails, use LLM to extract the value
        if len(execution_output) > 100:  # Only use LLM if there's substantial output
            try:
                prompt = f"""Analyze the following optimization solver output and extract the final optimal objective value.
Look for the final/best objective value that the solver found.

Solver Output:
```
{execution_output}
```

Instructions:
1. Find the final optimal/best objective value from the output
2. Return ONLY the numeric value (e.g., 42.5, -100.0, 1234)
3. If no optimal value can be found, return "NONE"
4. Do not include any explanation, just the number or "NONE"

The optimal value is:"""
                
                response = self._get_response(prompt, temperature=0.0)  # Use low temperature for precision
                response = response.strip()
                
                if response != "NONE" and response:
                    # Try to parse the response as a float
                    try:
                        # Remove any common prefixes/suffixes
                        cleaned = response.replace(':', '').replace('=', '').strip()
                        # Extract first number-like pattern from response
                        number_match = re.search(r'[+-]?[\d.,]+', cleaned)
                        if number_match:
                            value_str = number_match.group(0).replace(',', '')
                            return float(value_str)
                    except ValueError:
                        pass
                        
            except Exception as e:
                print(f"    LLM extraction failed: {e}")
        
        return None
    
    def _majority_vote_selection(self, all_attempts: List[Dict], output_dir: Path) -> Dict:
        """Perform majority vote on multiple attempts"""
        # Extract successful attempts
        successful_attempts = [a for a in all_attempts if a.get('success') and a.get('optimal_value') is not None]
        
        if not successful_attempts:
            return {
                'success': False,
                'error_message': f'All {len(all_attempts)} attempts failed',
                'successful_attempts': 0,
                'total_attempts': len(all_attempts)
            }
        
        # Count optimal values
        optimal_values = [a['optimal_value'] for a in successful_attempts]
        value_counts = Counter(optimal_values)
        
        # Find most common value
        most_common_value, count = value_counts.most_common(1)[0]
        
        # Select best attempt with most common value
        best_attempts = [a for a in successful_attempts if a['optimal_value'] == most_common_value]
        best_attempt = best_attempts[0]  # Take first one
        
        # Determine consensus method
        if len(value_counts) == 1:
            consensus_method = 'unanimous'
        elif count > len(successful_attempts) / 2:
            consensus_method = 'majority'
        else:
            consensus_method = 'plurality'
        
        # Get solver distribution
        solver_counts = Counter([a.get('solver_used', 'unknown') for a in successful_attempts])
        
        # Save voting details
        voting_details = {
            'value_distribution': dict(value_counts),
            'solver_distribution': dict(solver_counts),
            'winning_value': most_common_value,
            'votes_for_winner': count,
            'total_successful': len(successful_attempts),
            'consensus_method': consensus_method,
            'temperature_results': [
                {
                    'attempt': a['attempt_number'],
                    'sql_temp': a.get('sql_temperature', 'N/A'),
                    'code_temp': a.get('code_temperature', 'N/A'),
                    'solver': a.get('solver_used', 'N/A'),
                    'success': a.get('success', False),
                    'value': a.get('optimal_value', 'N/A')
                }
                for a in all_attempts
            ]
        }
        
        with open(output_dir / "voting_results.json", 'w') as f:
            json.dump(voting_details, f, indent=2)
        
        return {
            'success': True,
            'optimal_value': most_common_value,
            'successful_attempts': len(successful_attempts),
            'total_attempts': len(all_attempts),
            'confidence': count / len(successful_attempts),
            'consensus_method': consensus_method,
            'best_execution_output': best_attempt.get('execution_output', ''),
            'best_attempt_dir': best_attempt.get('attempt_dir', ''),
            'winning_solver': best_attempt.get('solver_used', 'unknown'),
            'voting_details': voting_details
        }


def find_problem_descriptions_with_schema(syn_data_dir: str) -> List[Dict]:
    """Collect problem directories that have a description and cached schema/data SQL.

    Each immediate subdirectory of ``syn_data_dir`` is one candidate problem,
    named after the directory. It is accepted only when all of these exist:

      - ``problem_description.md``
      - ``schema_cache/latest/schema.sql``
      - ``schema_cache/latest/data.sql``

    Directories are visited in sorted order. ``Path.iterdir`` yields entries
    in arbitrary order, and callers truncate this list (``--max_problems``),
    so sorting keeps runs reproducible.

    Args:
        syn_data_dir: Root directory containing one subdirectory per problem.

    Returns:
        One dict per accepted problem with keys ``database_name``,
        ``problem_path``, ``schema_dir``, ``schema_sql_path`` and
        ``data_sql_path`` (all strings).
    """
    problem_files = []

    for schema_dir in sorted(Path(syn_data_dir).iterdir()):
        if not schema_dir.is_dir():
            continue

        problem_desc_path = schema_dir / "problem_description.md"
        schema_cache_dir = schema_dir / "schema_cache" / "latest"
        schema_sql_path = schema_cache_dir / "schema.sql"
        data_sql_path = schema_cache_dir / "data.sql"

        required = [
            ("problem_description.md", problem_desc_path),
            ("schema.sql", schema_sql_path),
            ("data.sql", data_sql_path),
        ]
        missing = [label for label, path in required if not path.exists()]
        if missing:
            print(f"Skip: {schema_dir.name} (missing: {', '.join(missing)})")
            continue

        problem_files.append({
            'database_name': schema_dir.name,
            'problem_path': str(problem_desc_path),
            'schema_dir': str(schema_dir),
            'schema_sql_path': str(schema_sql_path),
            'data_sql_path': str(data_sql_path)
        })
        print(f"Found: {schema_dir.name} (with schema data)")

    return problem_files


def process_single_problem_wrapper(args: Tuple[Dict, Path, str, int, float, float, float, float]) -> Dict:
    """Solve one problem inside a worker process.

    ``main()`` submits this function to a ``ProcessPoolExecutor`` with a
    single packed tuple. This function unpacks the tuple, builds a fresh
    ``IntegratedOptimizationSolver`` for the call, and runs its parallel
    attempts.

    Args:
        args: ``(problem_info, output_dir, model_name, num_attempts,
            sql_temp_base, sql_temp_increment, code_temp_base,
            code_temp_increment)``.

    Returns:
        The dict produced by ``solve_problem_with_parallel_attempts``.
    """
    (problem, out_dir, llm_model, attempts,
     sql_base_temp, sql_step, code_base_temp, code_step) = args

    solver_kwargs = dict(
        model_name=llm_model,
        num_parallel_attempts=attempts,
        sql_temp_base=sql_base_temp,
        sql_temp_increment=sql_step,
        code_temp_base=code_base_temp,
        code_temp_increment=code_step,
    )
    worker_solver = IntegratedOptimizationSolver(**solver_kwargs)
    return worker_solver.solve_problem_with_parallel_attempts(problem, out_dir)


def main():
    """Command-line entry point: solve every discovered problem and write a summary.

    Steps:
      1. Parse and validate CLI options.
      2. Discover problem directories under ``--syn_data_dir`` (optionally
         truncated to ``--max_problems``).
      3. Run ``process_single_problem_wrapper`` for each problem in a
         ``ProcessPoolExecutor`` with ``--max_workers`` processes.
      4. Write ``overall_summary.json`` to ``--output_dir`` and print a
         short report.

    Exits with status 1 when no complete problem directory is found. Invalid
    numeric options are reported through ``parser.error`` (argparse usage
    error).
    """
    parser = argparse.ArgumentParser(description="Integrated Optimizer with Parallel Attempts and Majority Vote")
    parser.add_argument("--syn_data_dir", type=str, required=True,
                       help="Path to synthetic data directory")
    parser.add_argument("--output_dir", type=str, required=True,
                       help="Output directory for results")
    parser.add_argument("--max_workers", type=int, default=2,
                       help="Maximum number of parallel workers (for problems)")
    parser.add_argument("--max_problems", type=int, default=None,
                       help="Maximum number of problems to process")
    parser.add_argument("--model", type=str, default=None,
                       help="Model name for LLM (ignored for RITS, uses configured models)")
    parser.add_argument("--num_attempts", type=int, default=3,
                       help="Number of parallel attempts per problem")

    # Temperature parameters for SQL generation (Stage 1)
    parser.add_argument("--sql_temp_base", type=float, default=0.1,
                       help="Base temperature for SQL generation")
    parser.add_argument("--sql_temp_increment", type=float, default=0.3,
                       help="Temperature increment for each SQL attempt")

    # Temperature parameters for code generation (Stage 2)
    parser.add_argument("--code_temp_base", type=float, default=0.1,
                       help="Base temperature for code generation")
    parser.add_argument("--code_temp_increment", type=float, default=0.3,
                       help="Temperature increment for each code attempt")

    args = parser.parse_args()

    # Fail fast on values that would otherwise break late:
    # ProcessPoolExecutor raises on max_workers < 1, and a negative
    # --max_problems slices from the end of the list and can leave zero
    # problems, which makes the success-rate division below fail.
    if args.max_workers < 1:
        parser.error("--max_workers must be at least 1")
    if args.num_attempts < 1:
        parser.error("--num_attempts must be at least 1")
    if args.max_problems is not None and args.max_problems < 0:
        parser.error("--max_problems must be non-negative")

    print("=" * 80)
    print("Integrated Optimization Solver with Multi-API and Multi-Solver Support")
    print("=" * 80)
    print(f"Input directory: {args.syn_data_dir}")
    print(f"Output directory: {args.output_dir}")
    print(f"Model: {RITS_CONFIG['model_name']}")
    print(f"Workers: {args.max_workers}")
    print(f"Attempts per problem: {args.num_attempts}")
    print(f"SQL Temperature: base={args.sql_temp_base}, increment={args.sql_temp_increment}")
    print(f"Code Temperature: base={args.code_temp_base}, increment={args.code_temp_increment}")
    print(f"RITS API Key: {RITS_API_KEY[:10]}...")
    print(f"Endpoint: {RITS_CONFIG['base_url']}")
    print(f"Solvers: {', '.join(SOLVERS)}")

    # Calculate temperature ranges (display only; workers get base/increment)
    sql_temps = [args.sql_temp_base + i * args.sql_temp_increment for i in range(args.num_attempts)]
    code_temps = [args.code_temp_base + i * args.code_temp_increment for i in range(args.num_attempts)]
    print(f"SQL Temperature Range: {sql_temps}")
    print(f"Code Temperature Range: {code_temps}")

    # Find problems
    problem_files = find_problem_descriptions_with_schema(args.syn_data_dir)

    if not problem_files:
        print("ERROR: No problems with complete schema data found!")
        sys.exit(1)

    print(f"\nFound {len(problem_files)} problems with complete data")

    # 0 behaves like None ("no limit"); negative values were rejected above,
    # so the list stays non-empty here.
    if args.max_problems:
        problem_files = problem_files[:args.max_problems]
        print(f"Limited to {args.max_problems} problems")

    total_problems = len(problem_files)

    # Create output directory
    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Prepare arguments (one picklable tuple per problem)
    worker_args = [
        (problem, output_path, args.model, args.num_attempts,
         args.sql_temp_base, args.sql_temp_increment,
         args.code_temp_base, args.code_temp_increment)
        for problem in problem_files
    ]

    # Initialize progress tracker
    progress = ProgressTracker(total_problems)

    # Process problems
    results = []
    start_time = time.time()

    print(f"\nStarting processing with {args.max_workers} workers...")

    with ProcessPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_problem = {
            executor.submit(process_single_problem_wrapper, arg): arg[0]['database_name']
            for arg in worker_args
        }

        for future in as_completed(future_to_problem):
            problem_name = future_to_problem[future]
            try:
                result = future.result()
                results.append(result)
                success = result.get('final_result', {}).get('success', False)
                progress.update(success=success)
            except Exception as e:
                print(f"\nERROR: {problem_name} failed with exception: {e}")
                results.append({
                    'database_name': problem_name,
                    'status': 'failed',
                    'error': str(e)
                })
                progress.update(success=False)

    # as_completed() yields in completion order; sort so the summary file and
    # the printed examples do not depend on scheduling.
    results.sort(key=lambda r: str(r.get('database_name', '')))

    # Generate summary
    total_time = time.time() - start_time
    successful = [r for r in results if r.get('final_result', {}).get('success', False)]
    failed = [r for r in results if not r.get('final_result', {}).get('success', False)]

    # Collect solver statistics
    solver_stats = Counter()
    for r in successful:
        winning_solver = r.get('final_result', {}).get('winning_solver', 'unknown')
        solver_stats[winning_solver] += 1

    summary = {
        'run_info': {
            'total_problems': total_problems,
            'successful': len(successful),
            'failed': len(failed),
            'success_rate': f"{len(successful)/total_problems*100:.1f}%",
            'total_time': f"{total_time:.1f} seconds",
            'solver_statistics': dict(solver_stats),
            'configuration': {
                'model': RITS_CONFIG['model_name'],
                'num_attempts_per_problem': args.num_attempts,
                'max_workers': args.max_workers,
                'sql_temp_base': args.sql_temp_base,
                'sql_temp_increment': args.sql_temp_increment,
                'code_temp_base': args.code_temp_base,
                'code_temp_increment': args.code_temp_increment,
                'rits_api_key': RITS_API_KEY[:10] + "...",
                'endpoint': RITS_CONFIG['base_url'],
                'solvers_used': SOLVERS
            }
        },
        'results': results
    }

    with open(output_path / "overall_summary.json", "w") as f:
        json.dump(summary, f, indent=2, default=str)

    # Print summary
    print("\n" + "=" * 80)
    print("Processing Complete!")
    print(f"Total problems: {total_problems}")
    print(f"Successful: {len(successful)} ({len(successful)/total_problems*100:.1f}%)")
    print(f"Failed: {len(failed)} ({len(failed)/total_problems*100:.1f}%)")
    print(f"Total time: {total_time:.1f} seconds")
    print(f"Average time per problem: {total_time/total_problems:.1f} seconds")
    print("\nSolver Statistics:")
    # solver_stats is only populated from `successful`, so len(successful) > 0 here.
    for solver, count in solver_stats.most_common():
        print(f"  - {solver}: {count} problems ({count/len(successful)*100:.1f}%)")
    print(f"\nResults saved to: {args.output_dir}")

    # Show examples
    if successful:
        print("\nSuccessful examples:")
        for r in successful[:3]:
            final = r.get('final_result', {})
            print(f"  - {r['database_name']}: {final.get('optimal_value')} "
                  f"(solver: {final.get('winning_solver', 'N/A')}, "
                  f"confidence: {final.get('confidence', 0):.0%}, "
                  f"method: {final.get('consensus_method', 'unknown')})")


if __name__ == "__main__":
    # Run the CLI only when executed as a script. Worker processes started
    # with the "spawn" method import this module under a different __name__,
    # so they skip this block. main() returns None, so the exit status is 0
    # unless main() itself calls sys.exit or raises.
    sys.exit(main())