#!/usr/bin/env python3
"""
Integrated Two-Stage Optimization Solver with Parallel Attempts and Majority Vote
Gurobi-Only Version with RITS API Support
"""

import os
import sys
import json
import time
import tempfile
import subprocess
import sqlite3
import pandas as pd
import argparse
import re
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
import threading
from collections import Counter
from pathlib import Path
import random
import hashlib
from typing import List, Dict, Tuple, Optional
from openai import OpenAI

# Configuration - RITS API with multiple model options
# The key is taken from the RITS_API_KEY environment variable so that no
# credential has to be written into this file; it is "" when the variable is unset.
RITS_API_KEY = os.environ.get("RITS_API_KEY", "")

# Host prefix shared by every RITS inference endpoint listed below.
_RITS_BASE_URL = "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com"

# Maps a short model key to the endpoint URL and the model identifier sent
# in the chat-completion request.
MODEL_CONFIGS = {
    "deepseek-v3": {
        "base_url": f"{_RITS_BASE_URL}/deepseek-v3-h200/v1",
        "model_name": "deepseek-ai/DeepSeek-V3",
    },
    "phi-4": {
        "base_url": f"{_RITS_BASE_URL}/microsoft-phi-4/v1",
        "model_name": "microsoft/phi-4",
    },
    "qwen2-5-72b": {
        "base_url": f"{_RITS_BASE_URL}/qwen2-5-72b-instruct/v1",
        "model_name": "Qwen/Qwen2.5-72B-Instruct",
    },
    "llama-3-3-70b": {
        "base_url": f"{_RITS_BASE_URL}/llama-3-3-70b-instruct/v1",
        "model_name": "meta-llama/llama-3-3-70b-instruct",
    },
    "llama-4-scout": {
        "base_url": f"{_RITS_BASE_URL}/llama-4-scout-17b-16e-instruct/v1",
        "model_name": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    },
}


# Default model configuration (can be changed via command line)
DEFAULT_MODEL_KEY = "deepseek-v3"
RITS_CONFIG = MODEL_CONFIGS[DEFAULT_MODEL_KEY]

DEFAULT_MODEL = RITS_CONFIG["model_name"]
# Environment prefix; the solver class runs generated scripts with <prefix>/bin/python.
GUROBI_ENV_PATH = "/dccstor/nl2opt/miniforge3/envs/nl2opt_optim"

# Solver types - Only Gurobi for this version
SOLVERS = ['gurobipy']  # Only Gurobi

class ProgressTracker:
    """Running tally of finished tasks; every update happens under one lock."""

    def __init__(self, total_tasks):
        # The lock serializes both the counter changes and the progress line.
        self.lock = threading.Lock()
        self.total_tasks = total_tasks
        self.completed = 0
        self.successful = 0
        self.failed = 0

    def update(self, success=True):
        """Count one finished task as a success or a failure and print the totals."""
        with self.lock:
            self.completed += 1
            if not success:
                self.failed += 1
            else:
                self.successful += 1

            status_line = (
                f"[Progress] {self.completed}/{self.total_tasks} completed "
                f"(Success: {self.successful}, Failed: {self.failed})"
            )
            print(status_line)

class IntegratedOptimizationSolver:
    """Main solver class that integrates both stages with parallel attempts and majority vote"""
    
    def __init__(self, model_name=None, temperature=0.1):
        """
        Bind the solver to one model configuration and one sampling temperature.

        Args:
            model_name: Key into MODEL_CONFIGS (e.g. 'deepseek-v3', 'phi-4').
                A missing or unrecognized key falls back to DEFAULT_MODEL_KEY.
            temperature: Temperature used for both SQL and code generation
        """
        if model_name and model_name in MODEL_CONFIGS:
            self.model_key = model_name
        else:
            self.model_key = DEFAULT_MODEL_KEY
        self.temperature = temperature
        # Interpreter that runs the generated solver scripts.
        self.python_executable = os.path.join(GUROBI_ENV_PATH, "bin", "python")
        
    def _setup_rits_client(self) -> Tuple[OpenAI, str]:
        """Build a client for the selected model's endpoint and return it with the model name."""
        selected = MODEL_CONFIGS[self.model_key]

        rits_client = OpenAI(
            base_url=selected["base_url"],
            # Placeholder only: the real credential travels in the RITS_API_KEY header.
            api_key="dummy",
            default_headers={"RITS_API_KEY": RITS_API_KEY},
            timeout=3000,
        )

        return rits_client, selected["model_name"]
    
    def _get_response(self, prompt: str, temperature: float = 0.1, max_retries: int = 3) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        The call is attempted up to ``max_retries`` times with exponential
        backoff (1s, 2s, 4s, ...) between attempts, whether the previous
        attempt raised or merely came back empty.

        Args:
            prompt: Full text of the user message.
            temperature: Sampling temperature forwarded to the API.
            max_retries: Maximum number of attempts.

        Returns:
            The content of the first choice, or the placeholder string
            "No response generated after multiple attempts." when every
            attempt returned an empty reply.

        Raises:
            Exception: whatever the final attempt raised, re-raised unchanged.
        """
        # One client serves every attempt; rebuilding it per retry gains nothing.
        client, model_name = self._setup_rits_client()

        for attempt in range(max_retries):
            is_last_attempt = attempt >= max_retries - 1
            try:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=4096,
                    temperature=temperature,
                    top_p=0.9,
                    frequency_penalty=0.0,
                    presence_penalty=0.0,
                    stream=False
                )

                if response.choices and response.choices[0].message.content:
                    return response.choices[0].message.content
                print(f"WARNING: Empty response on attempt {attempt + 1}")

            except Exception as e:
                print(f"WARNING: API call failed on attempt {attempt + 1}: {e}")
                if is_last_attempt:
                    # Bare raise keeps the original traceback intact.
                    raise

            if not is_last_attempt:
                # Back off after an empty reply too, instead of hammering the endpoint.
                print(f"Retrying in {2 ** attempt} seconds...")
                time.sleep(2 ** attempt)

        return "No response generated after multiple attempts."
    

    
    def solve_problem_single_attempt(self, problem_info: Dict, output_dir: Path) -> Dict:
        """Solve one problem with a single pipeline attempt and persist the outcome.

        Results go to ``output_dir/<database_name>/``: ``code_output.txt``
        (optimal value plus solver output, or the failure reason) and
        ``summary.json``. A problem whose ``code_output.txt`` already exists
        is skipped.

        Args:
            problem_info: Problem record; ``'database_name'`` is read here and
                the whole dict is handed to the pipeline.
            output_dir: Root directory for per-problem output folders.

        Returns:
            The summary dict (``database_name``, ``timestamp``, ``temperature``,
            ``attempt_result``) for a completed attempt, or a dict with
            ``'status'`` set to ``'skipped'`` or ``'failed'``.
        """
        database_name = problem_info['database_name']
        banner = '=' * 60
        print(f"\n{banner}")
        print(f"Processing: {database_name} with single attempt")
        print(f"Temperature: {self.temperature}")
        print(f"Solver: gurobipy")
        print(banner)

        # Create main output directory
        main_output_dir = output_dir / database_name
        main_output_dir.mkdir(parents=True, exist_ok=True)

        # Check if already processed
        final_output_file = main_output_dir / "code_output.txt"
        if final_output_file.exists():
            print(f"Already processed: {database_name}")
            return {
                'database_name': database_name,
                'status': 'skipped',
                'reason': 'Already processed'
            }

        # Run single attempt
        try:
            attempt_result = self._single_attempt_pipeline(
                problem_info,
                main_output_dir,
                0  # single attempt number
            )

            # Save final result
            with open(final_output_file, 'w', encoding='utf-8') as f:
                if attempt_result['success']:
                    f.write(f"Optimal Objective Value: {attempt_result['optimal_value']}\n")

                    if attempt_result.get('execution_output'):
                        f.write("\n" + "="*50 + "\n")
                        f.write("Full Gurobi Output:\n")
                        f.write("="*50 + "\n")
                        f.write(attempt_result['execution_output'])
                else:
                    # The pipeline reports an exception under 'error' and a
                    # stage failure under 'stage2_error' / 'stage1_error';
                    # consult all three so the real reason is recorded.
                    reason = (
                        attempt_result.get('error')
                        or attempt_result.get('stage2_error')
                        or attempt_result.get('stage1_error')
                        or 'Unknown error'
                    )
                    f.write(f"ERROR: Optimization failed\n")
                    f.write(f"Reason: {reason}\n")

            # Save summary
            summary = {
                'database_name': database_name,
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
                'temperature': self.temperature,
                'attempt_result': attempt_result
            }

            with open(main_output_dir / "summary.json", 'w', encoding='utf-8') as f:
                json.dump(summary, f, indent=2, default=str)

            return summary

        except Exception as e:
            print(f"ERROR: {database_name} failed with exception: {e}")

            # Save error result
            with open(final_output_file, 'w', encoding='utf-8') as f:
                f.write(f"ERROR: {str(e)}\n")

            return {
                'database_name': database_name,
                'status': 'failed',
                'error': str(e),
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
            }
    
    def _single_attempt_pipeline(self, problem_info: Dict, output_dir: Path, attempt_num: int) -> Dict:
        """Run stage 1 (SQL retrieval) and then stage 2 (code generation) once.

        Returns a dict that always carries ``success``, ``optimal_value``,
        ``stage1_success``, ``stage2_success``, ``temperature`` and
        ``solver_used``; a stage failure adds ``stage1_error`` or
        ``stage2_error``, and a caught exception adds ``error``.
        ``attempt_num`` is accepted but not used by this method.
        """
        outcome = {
            'success': False,
            'optimal_value': None,
            'stage1_success': False,
            'stage2_success': False,
            'temperature': self.temperature,
            'solver_used': 'gurobipy'
        }

        try:
            # Stage 1: SQL Data Retrieval
            print(f"Stage 1: SQL Data Retrieval (temp={self.temperature:.1f})")
            enhanced = self._stage1_sql_retrieval(problem_info, output_dir, self.temperature)

            if not enhanced:
                outcome['stage1_error'] = 'Failed to generate enhanced problem description'
                return outcome

            outcome['stage1_success'] = True
            outcome['enhanced_problem_path'] = str(enhanced)

            # Stage 2: Code Generation with Gurobi
            print(f"Stage 2: Code Generation (gurobipy, temp={self.temperature:.1f})")
            stage2 = self._stage2_code_generation(
                enhanced, output_dir, self.temperature, 'gurobipy'
            )

            if not stage2['success']:
                outcome['stage2_error'] = stage2.get('error_message', 'Unknown error')
                return outcome

            outcome['stage2_success'] = True
            outcome['success'] = True
            outcome['optimal_value'] = stage2['optimal_value']
            outcome['execution_output'] = stage2.get('execution_output', '')
            outcome['generated_code'] = stage2.get('generated_code', '')

        except Exception as e:
            outcome['error'] = str(e)
            print(f"ERROR: {str(e)}")

        return outcome
    
    def _stage1_sql_retrieval(self, problem_info: Dict, output_dir: Path, temperature: float) -> Optional[Path]:
        """Stage 1: have the LLM write SELECT queries, run them, and save an enriched problem file.

        Writes ``stage1_llm_sql_generation.txt``,
        ``enhanced_problem_description.md`` and ``stage1_results.json`` into
        ``output_dir``. Returns the path of the enhanced description, or
        ``None`` when the LLM gives no reply, no query can be extracted, the
        database cannot be built, or any exception is raised.
        """
        try:
            # Problem text with the stored-values section removed
            context = self._extract_problem_context_without_stored_values(
                problem_info['problem_path']
            )

            with open(problem_info['schema_sql_path'], 'r', encoding='utf-8') as f:
                schema_sql = f.read()

            llm_response = self._generate_sql_queries_with_llm(context, schema_sql, temperature)
            if not llm_response:
                return None

            # Keep the full exchange on disk before parsing it
            exchange_log = (
                f"=== Problem Context ===\n{context}\n\n"
                f"=== Schema ===\n{schema_sql}\n\n"
                f"=== Temperature ===\n{temperature}\n\n"
                f"=== LLM Response ===\n{llm_response}\n"
            )
            (output_dir / "stage1_llm_sql_generation.txt").write_text(exchange_log, encoding="utf-8")

            queries = self._extract_sql_queries_from_response(llm_response)
            print(f"    Generated {len(queries)} SQL queries")
            if not queries:
                return None

            conn = self._create_database_from_schema(
                problem_info['schema_sql_path'],
                problem_info['data_sql_path']
            )
            if not conn:
                return None

            query_results = self._execute_sql_queries(conn, queries)
            conn.close()

            enhanced_path = output_dir / "enhanced_problem_description.md"
            enhanced_path.write_text(
                self._create_enhanced_problem_description(
                    problem_info['problem_path'], query_results
                ),
                encoding="utf-8",
            )

            # Per-query bookkeeping for the JSON report
            per_query = []
            for item in query_results:
                per_query.append({
                    "comment": item['comment'],
                    "query": item['query'],
                    "rows_returned": len(item['result_df']) if 'result_df' in item else 0,
                    "error": item.get('error')
                })

            report = {
                "queries_generated": len(queries),
                "queries_executed": len(query_results),
                "temperature": temperature,
                "query_results": per_query
            }
            (output_dir / "stage1_results.json").write_text(
                json.dumps(report, indent=2), encoding="utf-8"
            )

            return enhanced_path

        except Exception as e:
            print(f"    Stage 1 Error: {e}")
            return None
    
    def _extract_problem_context_without_stored_values(self, problem_description_path: str) -> str:
        """Return the problem file's text minus its 'Current Stored Values' section.

        The section starts at a line containing '### current stored values'
        (any letter case) and runs until the next line that begins with
        '### ' and does not mention 'current stored values'.
        """
        with open(problem_description_path, 'r', encoding='utf-8') as f:
            text = f.read()

        kept = []
        skipping = False

        for raw_line in text.split('\n'):
            lowered = raw_line.lower()

            # Section header: drop it and everything up to the next heading.
            if '### current stored values' in lowered:
                skipping = True
                continue

            if skipping:
                ends_section = (
                    raw_line.startswith('### ')
                    and 'current stored values' not in lowered
                )
                if not ends_section:
                    continue
                skipping = False

            kept.append(raw_line)

        return '\n'.join(kept)
    
    def _generate_sql_queries_with_llm(self, problem_context: str, schema_sql: str, temperature: float) -> str:
        """Ask the LLM for commented SELECT queries relevant to the problem.

        Returns the raw LLM reply. If the API call raises, the error is
        printed and ``None`` is returned instead (despite the ``str``
        annotation).
        """
        sql_prompt = f"""You are an expert database analyst helping with optimization problem data retrieval. Based on the problem description and database schema, analyze what data would be most useful for solving this optimization problem and generate appropriate SQL SELECT queries.

**Problem Context (without stored values):**
{problem_context}

**Database Schema:**
{schema_sql}

**Your Task:** 
Carefully analyze the optimization problem and determine what data from the database would be most relevant. Then generate SQL SELECT queries to retrieve this data.

**Analysis Guidelines:**
- Identify what data is needed for decision variables (what needs to be optimized)
- Identify what data is needed for objective function coefficients (what to maximize/minimize)
- Identify what data is needed for constraint parameters (limitations and requirements)
- Consider what aggregated data, summary statistics, or lookup information might be helpful
- Think about relationships between tables and potential joins
- Consider filtering criteria that would make the data more relevant

**Query Requirements:**
- Generate as many queries as you think are necessary and useful (no fixed number)
- Each query should have a clear purpose for the optimization problem
- Include meaningful comments explaining what each query retrieves and why it's relevant
- Use proper SQL syntax with table names, column names, and joins as needed
- Consider both detailed data and summary/aggregated data where appropriate

**Output Format:**
```sql
-- Query Description: Explain what this retrieves and why it's important for optimization
SELECT ... FROM ... WHERE ... ;

-- Query Description: Explain what this retrieves and why it's important for optimization  
SELECT ... FROM ... WHERE ... ;

-- Continue with additional queries as needed
```

Analyze the problem and generate the most relevant SQL queries:"""

        try:
            return self._get_response(sql_prompt, temperature)
        except Exception as e:
            print(f"    ERROR generating SQL queries: {e}")
            return None
    
    def _extract_sql_queries_from_response(self, llm_response: str) -> List[Dict]:
        """Pull commented SELECT statements out of an LLM reply.

        Every ```sql fenced block in the reply is parsed (the fence tag is
        matched case-insensitively); when there is no fenced block the whole
        reply is parsed instead. A ``--`` line sets the description attached
        to the queries that follow it. A query starts on a line beginning
        with SELECT and ends on the first line ending with ';'. Its lines
        are joined with single spaces.

        Args:
            llm_response: Raw text returned by the LLM.

        Returns:
            A list of ``{'comment': str, 'query': str}`` dicts in order of
            appearance. A trailing query without ';' gets one appended.
        """
        fenced = re.findall(r'```sql\s*(.*?)\s*```', llm_response, re.DOTALL | re.IGNORECASE)
        sql_text = '\n'.join(fenced) if fenced else llm_response

        queries = []
        pending = []  # lines of the SELECT currently being assembled
        comment = ""

        for raw_line in sql_text.split('\n'):
            line = raw_line.strip()
            if not line:
                continue

            if line.startswith('--'):
                # A comment line opens the next query's description; an
                # unterminated fragment collected so far is discarded.
                pending = []
                comment = line[2:].strip()
                continue

            if line.upper().startswith('SELECT'):
                pending = [line]
            elif pending:
                pending.append(line)
            else:
                # Stray text outside any query.
                continue

            # Emit as soon as the terminator is seen, so a one-line query is
            # not lost when another SELECT follows directly.
            if line.endswith(';'):
                queries.append({
                    'comment': comment,
                    'query': ' '.join(pending).strip()
                })
                pending = []

        if pending:
            # Reply ended mid-query: keep it and supply the missing terminator.
            queries.append({
                'comment': comment,
                'query': ' '.join(pending).strip() + ';'
            })

        return queries
    
    def _create_database_from_schema(self, schema_sql_path: str, data_sql_path: str) -> Optional[sqlite3.Connection]:
        """Build an in-memory SQLite database from a schema script and a data script.

        Both scripts are run with ``executescript`` so that SQLite itself
        splits the statements; a naive split on ';' would cut a statement in
        two whenever a string literal contains a semicolon.

        Args:
            schema_sql_path: Path of the SQL file that creates the tables.
            data_sql_path: Path of the SQL file that inserts the rows.

        Returns:
            An open connection, or ``None`` (after printing the error) when
            either script fails to execute.

        Raises:
            OSError: if either file cannot be read.
        """
        with open(schema_sql_path, 'r', encoding='utf-8') as f:
            schema_sql = f.read()

        with open(data_sql_path, 'r', encoding='utf-8') as f:
            data_sql = f.read()

        conn = sqlite3.connect(':memory:')

        try:
            conn.executescript(schema_sql)
            conn.executescript(data_sql)
            conn.commit()
            return conn

        except Exception as e:
            print(f"    ERROR creating database: {e}")
            conn.close()
            return None
    
    def _execute_sql_queries(self, conn: sqlite3.Connection, queries: List[Dict]) -> List[Dict]:
        """Run each query and collect one result record per query, in order.

        A successful record holds ``comment``, ``query``, ``result_df`` and
        ``result_csv``. A failed one holds an empty frame, an empty CSV
        string and the exception text under ``error``.
        """
        outcomes = []

        for number, spec in enumerate(queries, start=1):
            label = spec['comment']
            sql = spec['query']

            try:
                frame = pd.read_sql_query(sql, conn)

                outcomes.append({
                    'comment': label,
                    'query': sql,
                    'result_df': frame,
                    'result_csv': frame.to_csv(index=False)
                })

                print(f"    Query {number}: {label} -> {len(frame)} rows")

            except Exception as exc:
                print(f"    ERROR executing query {number}: {exc}")
                print(f"    Query: {sql}")

                # Failed queries still get a record so numbering stays aligned.
                outcomes.append({
                    'comment': label,
                    'query': sql,
                    'result_df': pd.DataFrame(),
                    'result_csv': "",
                    'error': str(exc)
                })

        return outcomes
    
    def _create_enhanced_problem_description(self, original_problem_path: str, 
                                           query_results: List[Dict]) -> str:
        """Append the executed queries and their results to the stripped problem text.

        The problem file is read once, through
        ``_extract_problem_context_without_stored_values``; a
        '### Retrieved Values' section follows with one entry per query.

        Args:
            original_problem_path: Path of the original problem description.
            query_results: Records with ``comment``, ``query``, ``result_csv``
                and, for failed queries, ``error``.

        Returns:
            The enhanced problem description as one string.
        """
        sections = [
            self._extract_problem_context_without_stored_values(original_problem_path),
            "\n\n### Retrieved Values\n\n",
        ]

        for number, result in enumerate(query_results, start=1):
            csv_data = result['result_csv']

            sections.append(f"**Query {number}: {result['comment']}**\n\n")
            sections.append(f"```sql\n{result['query']}\n```\n\n")

            if csv_data and not result.get('error'):
                sections.append(f"**Results (CSV format):**\n```csv\n{csv_data}```\n\n")
            else:
                # Covers both a failed query and one that produced no CSV text.
                error_msg = result.get('error', 'No data returned')
                sections.append(f"**Error:** {error_msg}\n\n")

        return ''.join(sections)
    
    def _stage2_code_generation(self, enhanced_problem_path: Path, output_dir: Path, 
                               temperature: float, solver_type: str) -> Dict:
        """Stage 2: generate solver code with the LLM, run it, and read off the optimum.

        Writes the generated script, its execution output and
        ``stage2_zero_shot_results.json`` into ``output_dir``. The returned
        dict always carries ``success``, ``optimal_value``,
        ``generated_code``, ``execution_output``, ``error_message``,
        ``temperature`` and ``solver_type``. Exceptions are captured into
        ``error_message`` rather than raised.
        """
        outcome = {
            'success': False,
            'optimal_value': None,
            'generated_code': None,
            'execution_output': None,
            'error_message': None,
            'temperature': temperature,
            'solver_type': solver_type
        }

        try:
            problem_text = Path(enhanced_problem_path).read_text(encoding='utf-8')

            # Build the prompt and query the LLM
            prompt = self._generate_gurobi_specific_prompt(problem_text)
            reply = self._get_response(prompt, temperature)

            code = self._extract_python_code(reply)
            if not code:
                outcome['error_message'] = "Failed to extract Python code from LLM response"
                return outcome

            outcome['generated_code'] = code
            (output_dir / f"generated_{solver_type}_code.py").write_text(code, encoding="utf-8")

            # Run the script and keep whatever it printed
            ran_ok, run_output = self._execute_solver_code(code, solver_type)
            outcome['execution_output'] = run_output
            (output_dir / f"{solver_type}_execution_output.txt").write_text(run_output, encoding="utf-8")

            if not ran_ok:
                outcome['error_message'] = f"Code execution failed: {run_output}"
            else:
                value = self._extract_optimal_value(run_output)
                if value is None:
                    outcome['error_message'] = "Could not extract optimal value from output"
                    print(f"    WARNING: Failed to extract optimal value from {solver_type} output")
                    # Keep the unparseable output around for debugging
                    (output_dir / f"{solver_type}_extraction_failed.txt").write_text(
                        "Failed to extract optimal value from:\n\n" + run_output
                    )
                else:
                    outcome['optimal_value'] = value
                    outcome['success'] = True
                    print(f"    Successfully extracted optimal value: {value}")

            stage2_report = {
                "success": outcome['success'],
                "optimal_value": outcome['optimal_value'],
                "temperature": temperature,
                "solver_type": solver_type,
                "error_message": outcome.get('error_message')
            }
            (output_dir / "stage2_zero_shot_results.json").write_text(
                json.dumps(stage2_report, indent=2), encoding="utf-8"
            )

        except Exception as e:
            outcome['error_message'] = f"Unexpected error: {str(e)}"

        return outcome
    
    def _generate_gurobi_specific_prompt(self, problem_description: str) -> str:
        """Wrap the problem description in the zero-shot Gurobi code-generation prompt."""
        return f"""You are an expert in optimization and Gurobi programming. Given the following optimization problem description, generate complete Python code using Gurobi to solve it.

**Problem Description:**
{problem_description}

**Requirements:**
1. Generate complete, executable Python code using Gurobi (gurobipy)
2. Include all necessary imports
3. Define variables, constraints, and objective function correctly
4. Include model optimization and result printing
5. Print the optimal objective value clearly at the end
6. Handle potential errors gracefully
7. Make sure the code can run independently

**Output Format:**
```python
# Your complete Gurobi Python code here
import gurobipy as gp
from gurobipy import GRB

# Model creation and solving code
# ...

# Print results clearly
print(f"Optimal Objective Value: {{optimal_value}}")
```

**Important Notes:**
- Focus on creating a working solution
- Use appropriate variable types (continuous, integer, binary)
- Include proper constraint formulations
- Ensure the objective function is correctly defined
- Add comments for clarity

Generate the complete Gurobi Python code:"""
    
    def _extract_python_code(self, llm_response: str) -> Optional[str]:
        """Return the code in the first ```python fence of the reply.

        Without such a fence, fall back to everything from the first line
        that mentions both 'import' and 'gurobi' to the end of the reply.
        Returns ``None`` when neither is found.
        """
        fence = re.search(r'```python\s*(.*?)\s*```', llm_response, re.DOTALL)
        if fence:
            return fence.group(1).strip()

        # Fallback: unfenced reply, so locate where the solver import begins.
        reply_lines = llm_response.split('\n')
        for index, text in enumerate(reply_lines):
            if 'import' in text and 'gurobi' in text.lower():
                return '\n'.join(reply_lines[index:]).strip()

        return None
    
    def _execute_solver_code(self, python_code: str, solver_type: str) -> Tuple[bool, str]:
        """Run generated solver code in a subprocess using ``self.python_executable``.

        The code is written to a temporary ``.py`` file (as UTF-8, Python's
        default source encoding) which is removed afterwards on every path,
        including timeout and launch failure.

        Args:
            python_code: Source text of the script to run.
            solver_type: Accepted for interface symmetry; not used here.

        Returns:
            ``(True, stdout)`` when the script exits with status 0,
            ``(False, stderr)`` on a non-zero exit, and ``(False, message)``
            on timeout (300 seconds) or any other failure to run it.
        """
        temp_file_path = None
        try:
            with tempfile.NamedTemporaryFile(
                mode='w', suffix=".py", delete=False, encoding='utf-8'
            ) as temp_file:
                # Record the name before writing so a failed write is cleaned up too.
                temp_file_path = temp_file.name
                temp_file.write(python_code)

            result = subprocess.run(
                [self.python_executable, temp_file_path],
                capture_output=True,
                text=True,
                timeout=300
            )

            if result.returncode == 0:
                return True, result.stdout
            return False, result.stderr

        except subprocess.TimeoutExpired:
            return False, "Execution timeout after 300 seconds"
        except Exception as e:
            return False, f"Execution error: {str(e)}"
        finally:
            if temp_file_path is not None:
                try:
                    os.unlink(temp_file_path)
                except OSError:
                    # Best-effort cleanup: a leftover temp file must not mask the result.
                    pass
    
    def _extract_optimal_value(self, execution_output: str) -> Optional[float]:
        """Extract the optimal objective value from solver output.

        A list of label patterns is tried in order; the first one followed by
        a parseable number wins. Numbers may carry a sign, thousands
        separators, a decimal part and an exponent, so values such as
        ``1.2e+01`` are read in full rather than truncated at the 'e'.
        If no pattern matches and the output is longer than 100 characters,
        the LLM is asked to pick the value out.

        Args:
            execution_output: Text printed by the executed solver script.

        Returns:
            The value as a float, or ``None`` when nothing could be extracted.
        """
        # sign, digits with optional ',' separators, optional fraction, optional exponent
        number = r'([+-]?(?:\d[\d,]*(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)'

        # First try pattern matching
        label_patterns = [
            r'Optimal Objective Value:\s*',
            r'Best objective\s*',
            r'Optimal value:\s*',
            r'objective\s*',
            r'Objective value:\s*',
            r'Obj:\s*',              # Gurobi short form
            r'Optimal solution found.*?',  # Generic
        ]

        for label in label_patterns:
            match = re.search(label + number, execution_output, re.IGNORECASE)
            if match:
                value_str = match.group(1).replace(',', '')
                try:
                    return float(value_str)
                except ValueError:
                    continue

        # If pattern matching fails, use LLM to extract the value
        if len(execution_output) > 100:  # Only use LLM if there's substantial output
            try:
                prompt = f"""Analyze the following Gurobi optimization solver output and extract the final optimal objective value.
Look for the final/best objective value that the solver found.

Gurobi Output:
```
{execution_output}
```

Instructions:
1. Find the final optimal/best objective value from the output
2. Return ONLY the numeric value (e.g., 42.5, -100.0, 1234)
3. If no optimal value can be found, return "NONE"
4. Do not include any explanation, just the number or "NONE"

The optimal value is:"""

                response = self._get_response(prompt, temperature=0.0)  # Use low temperature for precision
                response = response.strip()

                if response != "NONE" and response:
                    # Try to parse the response as a float
                    try:
                        # Remove any common prefixes/suffixes
                        cleaned = response.replace(':', '').replace('=', '').strip()
                        # Extract first number-like pattern from response
                        number_match = re.search(number, cleaned)
                        if number_match:
                            value_str = number_match.group(1).replace(',', '')
                            return float(value_str)
                    except ValueError:
                        pass

            except Exception as e:
                print(f"    LLM extraction failed: {e}")

        return None
    



def find_problem_descriptions_with_schema(syn_data_dir: str) -> List[Dict]:
    """Find all problem directories that carry a description and schema data.

    Every immediate subdirectory of ``syn_data_dir`` is inspected. A
    subdirectory qualifies when it contains all of:

    * ``problem_description.md``
    * ``schema_cache/latest/schema.sql``
    * ``schema_cache/latest/data.sql``

    Subdirectories are visited in name order so the returned list (and any
    prefix of it a caller slices off) is the same on every run;
    ``Path.iterdir()`` alone makes no ordering guarantee.

    Args:
        syn_data_dir: Path to the directory holding one subdirectory per
            problem.

    Returns:
        One dict per qualifying subdirectory with the keys
        ``database_name``, ``problem_path``, ``schema_dir``,
        ``schema_sql_path`` and ``data_sql_path`` (all path values are
        strings). Empty when ``syn_data_dir`` is not an existing directory
        or nothing qualifies.
    """
    root = Path(syn_data_dir)
    if not root.is_dir():
        # Report and return nothing instead of letting iterdir() raise, so the
        # caller's "no problems found" handling applies.
        print(f"ERROR: Input directory not found or not a directory: {syn_data_dir}")
        return []

    problem_files = []

    for schema_dir in sorted(root.iterdir(), key=lambda entry: entry.name):
        if not schema_dir.is_dir():
            continue

        problem_desc_path = schema_dir / "problem_description.md"
        schema_cache_dir = schema_dir / "schema_cache" / "latest"
        schema_sql_path = schema_cache_dir / "schema.sql"
        data_sql_path = schema_cache_dir / "data.sql"

        # Label/path pairs, in the order the labels appear in the skip message.
        required = (
            ("problem_description.md", problem_desc_path),
            ("schema.sql", schema_sql_path),
            ("data.sql", data_sql_path),
        )
        missing = [label for label, path in required if not path.exists()]
        if missing:
            print(f"Skip: {schema_dir.name} (missing: {', '.join(missing)})")
            continue

        problem_files.append({
            'database_name': schema_dir.name,
            'problem_path': str(problem_desc_path),
            'schema_dir': str(schema_dir),
            'schema_sql_path': str(schema_sql_path),
            'data_sql_path': str(data_sql_path)
        })
        print(f"Found: {schema_dir.name} (with schema data)")

    return problem_files


def process_single_problem_wrapper(args: Tuple[Dict, Path, str, float]) -> Dict:
    """Solve one problem from a single packed argument tuple.

    Takes everything as one tuple so it can be submitted to an executor as a
    single positional argument.

    Args:
        args: ``(problem_info, output_dir, model_key, temperature)``.

    Returns:
        Whatever ``solve_problem_single_attempt`` returns for the problem.
    """
    problem, destination, model, temp = args
    # A fresh solver is built on every call, configured from the tuple.
    return IntegratedOptimizationSolver(
        model_name=model,
        temperature=temp,
    ).solve_problem_single_attempt(problem, destination)


def _attempt_succeeded(result: object) -> bool:
    """Return True when a worker result records a successful attempt.

    A result counts as successful only if it is a dict, its ``status`` is not
    ``'failed'``, and its ``attempt_result`` is a dict with a truthy
    ``success`` entry. Anything else (including ``None`` or a missing /
    ``None`` ``attempt_result``) is treated as a failure instead of raising.
    """
    if not isinstance(result, dict) or result.get('status') == 'failed':
        return False
    attempt = result.get('attempt_result')
    if not isinstance(attempt, dict):
        return False
    return bool(attempt.get('success', False))


def main():
    """Command-line entry point: solve every discovered problem in parallel.

    Parses the CLI options, discovers problems under ``--syn_data_dir``,
    submits one task per problem to a ``ProcessPoolExecutor``, collects the
    results as they complete, writes ``overall_summary.json`` into
    ``--output_dir`` and prints a short report.

    Exits with status 1 when no complete problem is found; invalid
    ``--max_workers`` / ``--max_problems`` values are rejected through
    ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="Simple Zero-Shot Integrated Optimizer - Gurobi Only with RITS API")
    parser.add_argument("--syn_data_dir", type=str, required=True,
                       help="Path to synthetic data directory")
    parser.add_argument("--output_dir", type=str, required=True,
                       help="Output directory for results")
    parser.add_argument("--max_workers", type=int, default=2,
                       help="Maximum number of parallel workers (for problems)")
    parser.add_argument("--max_problems", type=int, default=None,
                       help="Maximum number of problems to process")
    parser.add_argument("--model", type=str, default=DEFAULT_MODEL_KEY,
                       choices=list(MODEL_CONFIGS.keys()),
                       help=f"Model to use. Options: {', '.join(MODEL_CONFIGS.keys())}")
    parser.add_argument("--temperature", type=float, default=0.1,
                       help="Temperature for both SQL and code generation")
    
    args = parser.parse_args()
    
    # Fail fast on values that would otherwise blow up later:
    # ProcessPoolExecutor raises ValueError for max_workers <= 0, and a
    # negative max_problems slices from the end, which can leave an empty list
    # and a division by zero in the summary.
    if args.max_workers < 1:
        parser.error("--max_workers must be at least 1")
    if args.max_problems is not None and args.max_problems < 0:
        parser.error("--max_problems must not be negative")
    
    # Get model config for display
    model_config = MODEL_CONFIGS[args.model]
    
    print("=" * 80)
    print("Simple Zero-Shot Integrated Optimization Solver - Gurobi Only with RITS API")
    print("=" * 80)
    print(f"Input directory: {args.syn_data_dir}")
    print(f"Output directory: {args.output_dir}")
    print(f"Model: {args.model} ({model_config['model_name']})")
    print(f"Workers: {args.max_workers}")
    print(f"Temperature: {args.temperature}")
    # NOTE(review): the first 10 characters of the API key are printed here
    # and persisted in overall_summary.json below - confirm that is acceptable.
    print(f"RITS API Key: {RITS_API_KEY[:10]}...")
    print(f"Endpoint: {model_config['base_url']}")
    print("Solver: gurobipy ONLY")
    
    # Find problems
    problem_files = find_problem_descriptions_with_schema(args.syn_data_dir)
    
    if not problem_files:
        print("ERROR: No problems with complete schema data found!")
        sys.exit(1)
    
    print(f"\nFound {len(problem_files)} problems with complete data")
    
    # A value of 0 (like the default None) means "no limit".
    if args.max_problems:
        problem_files = problem_files[:args.max_problems]
        print(f"Limited to {args.max_problems} problems")
    
    # Create output directory
    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    # Prepare arguments
    worker_args = [
        (problem, output_path, args.model, args.temperature)
        for problem in problem_files
    ]
    
    # Initialize progress tracker
    progress = ProgressTracker(len(problem_files))
    
    # Process problems
    results = []
    start_time = time.time()
    
    print(f"\nStarting processing with {args.max_workers} workers...")
    
    with ProcessPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_problem = {
            executor.submit(process_single_problem_wrapper, arg): arg[0]['database_name']
            for arg in worker_args
        }
        
        for future in as_completed(future_to_problem):
            problem_name = future_to_problem[future]
            try:
                result = future.result()
            except Exception as e:
                print(f"\nERROR: {problem_name} failed with exception: {e}")
                result = {
                    'database_name': problem_name,
                    'status': 'failed',
                    'error': str(e)
                }
            
            if not isinstance(result, dict):
                # Normalise unexpected worker output into a failure record so
                # the summary code below can rely on dict results.
                print(f"\nERROR: {problem_name} returned an unexpected result: {result!r}")
                result = {
                    'database_name': problem_name,
                    'status': 'failed',
                    'error': f"unexpected result type: {type(result).__name__}"
                }
            
            # Exactly one record and one progress tick per problem.
            results.append(result)
            progress.update(success=_attempt_succeeded(result))
    
    # Generate summary
    total_time = time.time() - start_time
    successful = [r for r in results if _attempt_succeeded(r)]
    failed = [r for r in results if not _attempt_succeeded(r)]
    
    summary = {
        'run_info': {
            'total_problems': len(problem_files),
            'successful': len(successful),
            'failed': len(failed),
            'success_rate': f"{len(successful)/len(problem_files)*100:.1f}%",
            'total_time': f"{total_time:.1f} seconds",
            'configuration': {
                'study_type': 'simple_zero_shot_gurobi_only_with_rits_api',
                'model_key': args.model,
                'model_name': model_config['model_name'],
                'base_url': model_config['base_url'],
                'temperature': args.temperature,
                'max_workers': args.max_workers,
                'rits_api_key': RITS_API_KEY[:10] + "...",
                'solvers_used': ['gurobipy']
            }
        },
        'results': results
    }
    
    # default=str stringifies any value json cannot serialise natively.
    with open(output_path / "overall_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, default=str)
    
    # Print summary
    print("\n" + "=" * 80)
    print("Processing Complete! (Simple Zero-Shot Gurobi Only with RITS API)")
    print(f"Total problems: {len(problem_files)}")
    print(f"Successful: {len(successful)} ({len(successful)/len(problem_files)*100:.1f}%)")
    print(f"Failed: {len(failed)} ({len(failed)/len(problem_files)*100:.1f}%)")
    print(f"Total time: {total_time:.1f} seconds")
    print(f"Average time per problem: {total_time/len(problem_files):.1f} seconds")
    print(f"Model used: {args.model} ({model_config['model_name']})")
    print(f"\nResults saved to: {args.output_dir}")
    
    # Show examples
    if successful:
        print("\nSuccessful examples:")
        for r in successful[:3]:
            attempt_result = r.get('attempt_result', {})
            print(f"  - {r['database_name']}: {attempt_result.get('optimal_value', 'N/A')}")


# Run the CLI only when this file is executed as a script. main() starts a
# ProcessPoolExecutor, so the guard also keeps main() from running again if
# worker processes re-import this module (as the "spawn" start method does).
if __name__ == "__main__":
    main()