#!/usr/bin/env python3
"""
Simplified SQL generator that uses pre-generated system prompts.
This replaces the complex artifact-based system with direct prompt usage.
"""

import json
import argparse
from typing import Dict, List, Optional, Tuple
import anthropic
import logging
from pathlib import Path
import os
import re
from datetime import datetime
import sys

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))
from utilities.cached_sql_executor import CachedSQLExecutor, SQLTimeoutError

# Handle both module and script execution
try:
    from .config import (
        MAX_TOKENS, 
        API_KEY_ENV_VAR,
        SUPPORTED_MODELS,
        DEFAULT_MODEL
    )
except ImportError:
    # When run as a script, import from RoboPhD
    from RoboPhD.config import (
        MAX_TOKENS, 
        API_KEY_ENV_VAR,
        SUPPORTED_MODELS,
        DEFAULT_MODEL
    )

class PromptBasedSQLGenerator:
    def __init__(self, prompt_path: str, model: str = DEFAULT_MODEL, api_key: Optional[str] = None, use_evidence: bool = True, db_path: Optional[str] = None, sql_validation_timeout: int = 30, verification_retries: int = 2, temperature_strategy: str = "progressive"):
        """
        Initialize SQL generator with a pre-generated system prompt.

        Args:
            prompt_path: Path to the system prompt file generated by subagent
                (read as UTF-8)
            model: Key into SUPPORTED_MODELS selecting the model to use
            api_key: Optional API key; when omitted or empty, the value of the
                API_KEY_ENV_VAR environment variable is used instead
            use_evidence: Whether to include evidence in prompts (default True)
            db_path: Path to database file for SQL validation (optional); a SQL
                executor is only created when this is truthy
            sql_validation_timeout: Timeout in seconds for SQL validation (default 30)
            verification_retries: Number of verification attempts (default 2);
                0 skips verification and uses the validate-and-retry-once path
            temperature_strategy: Temperature strategy for retries
                ("progressive", "fixed", "adaptive")

        Raises:
            ValueError: If no API key can be resolved, or if `model` is not a
                key of SUPPORTED_MODELS.
        """
        # Load pre-generated system prompt. The encoding is pinned so the text
        # sent to the model does not depend on the platform's locale default.
        with open(prompt_path, 'r', encoding='utf-8') as f:
            self.system_prompt = f.read()

        # Initialize Anthropic client
        api_key = api_key or os.getenv(API_KEY_ENV_VAR)
        if not api_key:
            raise ValueError(f"API key required. Set {API_KEY_ENV_VAR} or pass --api_key")

        self.client = anthropic.Anthropic(api_key=api_key)

        # Model configuration
        if model not in SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {model}. Choose from: {list(SUPPORTED_MODELS.keys())}")

        self.model_key = model
        self.model_config = SUPPORTED_MODELS[model]
        self.model_name = self.model_config['name']
        self.pricing = self.model_config['pricing']

        # Running token and cost totals, accumulated across every API call
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.cache_creation_tokens = 0
        self.cache_read_tokens = 0
        self.total_cost = 0.0
        self.use_evidence = use_evidence

        # SQL validation setup (validation is skipped when no db_path is given)
        self.db_path = db_path
        self.sql_validation_timeout = sql_validation_timeout
        self.sql_executor = CachedSQLExecutor() if db_path else None

        # Verification settings
        self.verification_retries = verification_retries
        self.temperature_strategy = temperature_strategy

        # Counters updated by the validation, retry and verification paths;
        # returned to callers as part of the generate_sql() metadata.
        self.validation_stats = {
            'total_generated': 0,
            'validation_attempted': 0,
            'validation_passed': 0,
            'validation_failed': 0,
            'validation_failed_empty': 0,
            'retry_attempted': 0,
            'retry_succeeded': 0,
            'timeout_errors': 0,
            # Verification stats
            'verification_attempted': 0,
            'verification_succeeded': 0,
            'verification_failed': 0,
            'verification_attempts_total': 0,
            # Enhanced tracking
            'verification_passed_immediately': 0,
            'verification_improved': 0,
            'total_api_calls': 0,
            'generation_api_calls': 0,
            'verification_api_calls': 0
        }
        
    def generate_sql(self, questions: List[Dict], db_name: str) -> Dict:
        """
        Generate SQL for questions using the pre-generated system prompt.
        
        Args:
            questions: List of question dictionaries
            db_name: Name of the database
            
        Returns:
            Dictionary with predictions and metadata
        """
        results = []
        # Track non-critical failures (critical ones now raise exceptions immediately)
        api_failures = {
            'other_errors': 0
        }
        
        print(f"Processing {len(questions)} questions for database {db_name}")
        print(f"Using model: {self.model_key} ({self.model_name})")
        
        for i, question in enumerate(questions, 1):
            # Handle both dev format (with question_id) and train format (without)
            # For train data, use the original index in the full dataset
            question_id = question.get('question_id', question.get('_original_index', i-1))
            print(f"  Question {i}/{len(questions)}: {question_id}")
            
            sql, verification_info = self._generate_single_sql(question)

            # Check for critical API failures that should stop processing
            if "API_CREDIT_EXHAUSTION" in sql:
                print(f"\n❌ CRITICAL ERROR: API Credit Exhaustion")
                print(f"   Cannot continue - please add credits to your API key")
                raise RuntimeError("API_CREDIT_EXHAUSTION: Cannot continue without API credits")
            elif "API_RATE_LIMIT" in sql:
                print(f"\n❌ CRITICAL ERROR: API Rate Limit Hit")
                print(f"   Cannot continue - please wait before retrying")
                raise RuntimeError("API_RATE_LIMIT: Rate limit exceeded, cannot continue")
            elif "API_AUTH_ERROR" in sql:
                print(f"\n❌ CRITICAL ERROR: API Authentication Failed")
                print(f"   Cannot continue - please check your API key configuration")
                raise RuntimeError("API_AUTH_ERROR: Authentication failed, cannot continue")
            elif "ERROR:" in sql:
                # Other errors we just track but don't fail immediately
                api_failures['other_errors'] += 1

            result_entry = {
                'question_id': question_id,
                'question': question['question'],
                'evidence': question.get('evidence', ''),
                'predicted_sql': sql,
                'db_id': db_name
            }

            # Add verification info if present
            if verification_info:
                result_entry['verification_info'] = verification_info

            results.append(result_entry)
        
        # Create BIRD format output
        bird_predictions = {}
        for result in results:
            bird_sql = f"{result['predicted_sql']}\t----- bird -----\t{db_name}"
            bird_predictions[str(result['question_id'])] = bird_sql
        
        # Print non-critical failure summary if any occurred
        # Note: Critical API failures (credit/rate/auth) now raise exceptions immediately
        if api_failures['other_errors'] > 0:
            print(f"\n⚠️  Non-critical failures: {api_failures['other_errors']}/{len(questions)} queries had errors")
            print(f"    These queries returned error messages but processing continued")
        
        # Calculate cost metrics
        cost_without_caching = (
            (self.total_input_tokens + self.cache_creation_tokens + self.cache_read_tokens) / 1000000
        ) * self.pricing['input'] + (self.total_output_tokens / 1000000) * self.pricing['output']
        
        cost_savings = cost_without_caching - self.total_cost
        savings_percentage = (cost_savings / cost_without_caching * 100) if cost_without_caching > 0 else 0
        
        return {
            'predictions': bird_predictions,
            'detailed_results': results,
            'metadata': {
                'database': db_name,
                'prompt_source': 'subagent-generated',
                'generator_version': 'prompt-based-v1',
                'model': {
                    'name': self.model_name,
                    'key': self.model_key,
                    'pricing': self.pricing
                },
                'total_questions': len(questions),
                'token_usage': {
                    'input_tokens': self.total_input_tokens,
                    'output_tokens': self.total_output_tokens,
                    'cache_creation_tokens': self.cache_creation_tokens,
                    'cache_read_tokens': self.cache_read_tokens,
                    'total_cost': self.total_cost,
                    'cost_savings': {
                        'amount': cost_savings,
                        'percentage': savings_percentage
                    }
                },
                'api_failures': api_failures,
                'validation_stats': self.validation_stats,
                'sql_validation_enabled': self.db_path is not None,
                'verification_settings': {
                    'verification_retries': self.verification_retries,
                    'temperature_strategy': self.temperature_strategy,
                    'verification_enabled': self.verification_retries > 0
                },
                'timestamp': str(datetime.now()),
            }
        }
    
    def _validate_sql(self, sql: str) -> Tuple[bool, Optional[str]]:
        """
        Validate SQL by attempting to execute it against the database.
        
        Args:
            sql: SQL query to validate
            
        Returns:
            Tuple of (is_valid, error_message)
        """
        if not self.sql_executor or not self.db_path:
            return True, None  # Skip validation if not configured
        
        self.validation_stats['validation_attempted'] += 1
        
        try:
            # Attempt to execute the SQL with timeout
            result = self.sql_executor.execute_sql(
                sql=sql,
                db_path=self.db_path,
                is_ground_truth=False,
                timeout_seconds=self.sql_validation_timeout
            )

            # Check if the result is empty (no rows returned)
            if self._is_result_empty(result):
                self.validation_stats['validation_failed_empty'] += 1
                return False, "empty"

            self.validation_stats['validation_passed'] += 1
            return True, None
            
        except SQLTimeoutError as e:
            self.validation_stats['timeout_errors'] += 1
            self.validation_stats['validation_failed'] += 1
            return False, f"SQL execution timed out after {self.sql_validation_timeout}s: {str(e)}"
            
        except Exception as e:
            self.validation_stats['validation_failed'] += 1
            return False, f"SQL validation error: {str(e)}"

    def _is_result_empty(self, result):
        if result is None or len(result) == 0:
            return True
        if result == [[None]] or result == [[]] or result == [[0]]:
            return True
        
        return False

    def _summarize_results(self, results, error_msg=None):
        """
        Summarize SQL execution results for verification prompt.

        Args:
            results: SQL execution results (list of rows)
            error_msg: SQL error message if execution failed

        Returns:
            String summary of results for verification
        """
        if error_msg:
            return f"SQL Error: {error_msg}"

        if results is None or len(results) == 0:
            return "Empty result set (0 rows)"

        # Preserve original order - no set() deduplication!
        if len(results) <= 10:
            return f"Complete results ({len(results)} rows):\n{results}"
        else:
            sample = results[:10]
            return f"Result summary: {len(results)} total rows\nFirst 10 rows:\n{sample}\n[...{len(results)-10} more rows]"

    def _build_verification_prompt(self, question, sql, summary, attempts):
        """
        Build verification prompt with complete history.

        Args:
            question: Question dictionary
            sql: Current SQL query
            summary: Results summary from _summarize_results()
            attempts: List of previous attempt dictionaries

        Returns:
            String prompt for verification
        """
        history = ""
        if attempts:
            history = "Previous attempts:\n\n"
            for i, attempt in enumerate(attempts, 1):
                history += f"Attempt {i}:\n"
                history += f"SQL: {attempt['sql']}\n"
                history += f"Results: {attempt['summary']}\n"
                if attempt.get('feedback'):
                    history += f"Issues identified: {attempt['feedback']}\n"
                history += "\n"

        maybe_evidence = ""
        if self.use_evidence and question.get('evidence'):
            maybe_evidence = f"\nEvidence: {question.get('evidence', 'None')}"

        prompt = f"""{history}Question: {question['question']}{maybe_evidence}

Current SQL: {sql}
Results: {summary}

Review the SQL and results. Does this correctly answer the question?

Respond with EXACTLY one of the following:
1. The single word: CORRECT
2. A new SQL query with no additional text

Do not include explanations, prefixes, or combine both responses.

Your response:"""
        return prompt

    def _get_verification_temperature(self, attempt_num):
        """Get temperature for verification attempt based on strategy."""
        if self.temperature_strategy == "progressive":
            return [0.0, 0.2, 0.3][min(attempt_num, 2)]
        elif self.temperature_strategy == "fixed":
            return 0.0
        elif self.temperature_strategy == "adaptive":
            # Could implement more complex logic later
            return 0.0 if attempt_num == 0 else 0.2
        else:
            return 0.0

    def _verify_and_improve(self, question, sql, summary, attempts):
        """
        Single API call to verify current SQL and potentially get improvement.

        The model is asked to answer either with the word CORRECT or with a
        replacement query. Token usage and cost of the call are added to the
        running totals.

        Args:
            question: Question dictionary
            sql: Current SQL query
            summary: Results summary
            attempts: List of previous attempts; its length also selects the
                sampling temperature

        Returns:
            Tuple of (is_correct: bool, new_sql: str, feedback: str).
            When the model confirms the query, `new_sql` is the unchanged
            input `sql`. When the call itself fails, the result is
            (False, sql, "Verification failed: ...") and no exception escapes.
        """
        verification_prompt = self._build_verification_prompt(question, sql, summary, attempts)

        try:
            response = self.client.messages.create(
                model=self.model_name,
                max_tokens=MAX_TOKENS,
                temperature=self._get_verification_temperature(len(attempts)),
                messages=[{
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": self.system_prompt,
                            "cache_control": {"type": "ephemeral"}
                        },
                        {
                            "type": "text",
                            "text": verification_prompt
                        }
                    ]
                }]
            )

            # Track token usage
            if hasattr(response, 'usage'):
                usage = response.usage
                # `or 0`: a counter can be present on the usage object but set
                # to None (the cache counters are optional in the API). Adding
                # None would raise inside this try block and turn a successful
                # call into a reported verification failure.
                input_tokens = getattr(usage, 'input_tokens', 0) or 0
                output_tokens = getattr(usage, 'output_tokens', 0) or 0
                cache_creation_tokens = getattr(usage, 'cache_creation_input_tokens', 0) or 0
                cache_read_tokens = getattr(usage, 'cache_read_input_tokens', 0) or 0

                self.total_input_tokens += input_tokens
                self.total_output_tokens += output_tokens
                self.cache_creation_tokens += cache_creation_tokens
                self.cache_read_tokens += cache_read_tokens

                # Calculate cost (pricing is per million tokens)
                input_cost = (input_tokens / 1000000) * self.pricing['input']
                output_cost = (output_tokens / 1000000) * self.pricing['output']
                cache_creation_cost = (cache_creation_tokens / 1000000) * self.pricing['cache_write']
                cache_read_cost = (cache_read_tokens / 1000000) * self.pricing['cache_read']

                verification_cost = input_cost + output_cost + cache_creation_cost + cache_read_cost
                self.total_cost += verification_cost

            # Parse response
            response_text = response.content[0].text.strip()

            # Any reply that begins with CORRECT counts as confirmation.
            if response_text.upper().startswith("CORRECT"):
                return True, sql, "Model confirmed SQL is correct"

            # Otherwise the reply should be new SQL, possibly inside a fence.
            new_sql = self._clean_sql(response_text)

            if new_sql.upper().startswith("CORRECT"):
                # The verdict was hidden behind a code fence. Drop the word
                # and an optional colon, then see whether any SQL remains.
                remainder = new_sql[len("CORRECT"):].strip()
                if remainder.startswith(':'):
                    remainder = remainder[1:].strip()
                if not remainder.strip('; \t\r\n'):
                    # Nothing but the verdict: this is a confirmation, not a
                    # new query consisting of a lone semicolon.
                    return True, sql, "Model confirmed SQL is correct"
                new_sql = self._clean_sql(remainder)

            return False, new_sql, "Model provided improved SQL"

        except Exception as e:
            # Best effort: a failed verification call keeps the current SQL.
            error_str = str(e)
            print(f"    Error in verification: {error_str}")
            return False, sql, f"Verification failed: {error_str}"

    def _generate_with_verification(self, question: Dict) -> Tuple[str, Dict]:
        """
        Generate SQL, then run up to `verification_retries` verification rounds.

        Each round executes the current SQL (when an executor and database
        path are configured), summarizes the outcome and asks the model either
        to confirm the query or to return a replacement. The first
        confirmation ends the loop. If no round confirms the query, the last
        candidate is validated and, when validation fails with an error or an
        empty result, regenerated once more with that context.

        Args:
            question: Question dictionary

        Returns:
            Tuple of (sql_query, verification_info) where verification_info contains
            details about all verification attempts
        """
        self.validation_stats['verification_attempted'] += 1
        # One record per verification round; also passed back into the prompt as history
        attempts = []

        # Initial SQL generation (counts as an API call)
        sql = self._generate_sql_attempt(question, previous_sql=None, error_msg=None)
        self.validation_stats['generation_api_calls'] += 1
        self.validation_stats['total_api_calls'] += 1

        # Verification loop for k attempts
        for attempt_num in range(self.verification_retries):
            self.validation_stats['verification_attempts_total'] += 1

            # Execute SQL to get results for verification
            if self.sql_executor and self.db_path:
                try:
                    results = self.sql_executor.execute_sql(
                        sql=sql,
                        db_path=self.db_path,
                        is_ground_truth=False,
                        timeout_seconds=self.sql_validation_timeout
                    )
                    summary = self._summarize_results(results)
                except Exception as e:
                    # Execution failures are not fatal here: the error text is
                    # shown to the model so it can propose a fix
                    summary = self._summarize_results(None, str(e))
            else:
                # No database configured: verification still runs, but the model
                # only sees this placeholder instead of real results
                summary = "SQL validation not available (no database path provided)"

            print(f"    Verification attempt {attempt_num + 1}/{self.verification_retries}: {summary[:100]}...")

            # Verify and potentially improve (counts as an API call)
            is_correct, new_sql, feedback = self._verify_and_improve(question, sql, summary, attempts)
            self.validation_stats['verification_api_calls'] += 1
            self.validation_stats['total_api_calls'] += 1

            # Record this attempt (the SQL that was judged, not the replacement)
            attempts.append({
                'sql': sql,
                'summary': summary,
                'feedback': feedback,
                'attempt_num': attempt_num,
                'is_correct': is_correct
            })

            if is_correct:
                if attempt_num == 0:
                    # Passed on first verification attempt
                    print(f"    Verification succeeded immediately")
                    self.validation_stats['verification_passed_immediately'] += 1
                    verification_outcome = 'passed_immediately'
                else:
                    # Confirmed only after at least one replacement query
                    print(f"    Verification succeeded after {attempt_num} improvement(s)")
                    self.validation_stats['verification_improved'] += 1
                    verification_outcome = 'improved'
                self.validation_stats['verification_succeeded'] += 1

                # Build verification info and return
                verification_info = {
                    'verification_attempts': attempt_num + 1,
                    'verification_outcome': verification_outcome,
                    'final_retry_used': False,
                    'verification_details': attempts
                }
                return sql, verification_info
            else:
                # Not confirmed: continue with whatever the verifier returned
                # (the same SQL again if the verification call itself failed)
                print(f"    Verification attempt {attempt_num + 1} suggested improvement: {feedback}")
                sql = new_sql

        # No round confirmed the query: fall back to the validate-and-retry-once logic
        print(f"    All verification attempts failed, using final error/null retry logic")
        self.validation_stats['verification_failed'] += 1

        final_retry_used = False
        final_retry_outcome = None

        # Generate final attempt using current error/null logic (if validation enabled)
        if self.sql_executor and self.db_path:
            is_valid, error_msg = self._validate_sql(sql)
            if not is_valid and error_msg and "non-retryable" not in error_msg:
                print(f"    Final validation failed: {error_msg}, attempting error retry")
                final_retry_used = True

                # "empty" is the marker _validate_sql uses for a query that ran
                # but returned nothing; it is passed on as a flag, not as an error
                empty_result = False
                if error_msg == "empty":
                    error_msg = None
                    empty_result = True

                final_sql = self._generate_sql_attempt(question, previous_sql=sql, error_msg=error_msg, empty_result=empty_result)
                self.validation_stats['total_api_calls'] += 1  # Track fallback API call

                # Validate final attempt; the regenerated SQL is returned either
                # way, only the recorded outcome differs
                final_is_valid, final_error_msg = self._validate_sql(final_sql)
                if final_is_valid:
                    print(f"    Final error retry succeeded")
                    final_retry_outcome = 'succeeded'
                    sql = final_sql
                else:
                    print(f"    Final error retry failed: {final_error_msg}")
                    final_retry_outcome = 'failed'
                    sql = final_sql

        # Build verification info for failed case
        verification_info = {
            'verification_attempts': self.verification_retries,
            'verification_outcome': 'failed',
            'final_retry_used': final_retry_used,
            'final_retry_outcome': final_retry_outcome,
            'verification_details': attempts
        }
        return sql, verification_info

    def _generate_single_sql(self, question: Dict) -> Tuple[str, Dict]:
        """Generate SQL for a single question with optional verification or validation and retry.

        Returns:
            Tuple of (sql_query, verification_info)
            where verification_info contains details about verification/retry attempts
        """
        self.validation_stats['total_generated'] += 1

        # Use verification if enabled (verification_retries > 0)
        if self.verification_retries > 0:
            return self._generate_with_verification(question)

        # Otherwise, use current validation and retry logic
        # First attempt - generate SQL
        sql = self._generate_sql_attempt(question, previous_sql=None, error_msg=None)
        self.validation_stats['generation_api_calls'] += 1
        self.validation_stats['total_api_calls'] += 1

        # Initialize verification info for legacy path
        verification_info = {
            'verification_attempts': 0,
            'verification_outcome': 'no_verification',
            'legacy_retry_used': False,
            'verification_details': []
        }

        # Validate SQL if validation is enabled
        if self.sql_executor and self.db_path:
            is_valid, error_msg = self._validate_sql(sql)
            print(f"    SQL valid? {is_valid} with {error_msg}")
            if not is_valid and error_msg and "non-retryable" not in error_msg:
                # SQL validation failed with a potentially fixable error - retry once
                print(f"    SQL validation failed: {error_msg} given {sql}")
                print(f"    Retrying with error context...")
                self.validation_stats['retry_attempted'] += 1
                verification_info['legacy_retry_used'] = True
                verification_info['legacy_error'] = error_msg

                empty_result = False
                if error_msg == "empty":
                    error_msg = None
                    empty_result = True

                # Retry with both the failed SQL and error message
                retry_sql = self._generate_sql_attempt(question, previous_sql=sql, error_msg=error_msg, empty_result=empty_result)
                self.validation_stats['total_api_calls'] += 1  # Track retry API call

                # Validate the retry attempt
                retry_is_valid, retry_error_msg = self._validate_sql(retry_sql)

                if retry_is_valid:
                    print(f"    Retry succeeded - SQL is now valid, corrected to {retry_sql}")
                    self.validation_stats['retry_succeeded'] += 1
                    verification_info['legacy_retry_outcome'] = 'succeeded'
                    return retry_sql, verification_info
                else:
                    print(f"    Retry failed: {retry_error_msg} with {retry_sql}")
                    verification_info['legacy_retry_outcome'] = 'failed'
                    verification_info['legacy_retry_error'] = retry_error_msg
                    # Return the retry attempt even if it failed validation
                    return retry_sql, verification_info
            elif is_valid:
                print(f"    SQL validation passed")

        return sql, verification_info
    
    def _generate_sql_attempt(self, question: Dict, previous_sql: Optional[str] = None, error_msg: Optional[str] = None, empty_result: bool = False) -> str:
        """Generate a single SQL attempt, optionally with retry context.

        Args:
            question: Question dictionary ('question' required, 'evidence' optional)
            previous_sql: The earlier query, when this call is a retry
            error_msg: Error produced by `previous_sql`; adds a "fix this
                error" section to the prompt
            empty_result: True when `previous_sql` ran but returned nothing;
                adds a "produced no results" section (only used when
                `error_msg` is not set)

        Returns:
            The cleaned SQL text. When the API call fails, no exception is
            raised; instead a placeholder SELECT is returned whose literal is
            API_CREDIT_EXHAUSTION, API_RATE_LIMIT, API_AUTH_ERROR or
            "ERROR: Failed to generate SQL", chosen from the error text.
        """
        maybe_evidence = ""
        if self.use_evidence and question.get('evidence'):
            maybe_evidence = f"\n\nEvidence or hints to help generate the query: {question.get('evidence', 'None provided')}"

        # Retry context for a query that raised an error.
        maybe_correct_invalid = ""
        if previous_sql and error_msg:
            maybe_correct_invalid = (
                "\n\nYour previous SQL query:\n"
                "```sql\n"
                f"{previous_sql}\n"
                "```\n"
                "\n"
                "This query failed with the following error:\n"
                f"{error_msg}\n"
                "\n"
                "Please generate a corrected SQL query that fixes this error."
            )

        # Retry context for a query that ran but returned nothing.
        maybe_correct_empty = ""
        if previous_sql and empty_result and not error_msg:
            maybe_correct_empty = (
                "\n\nYour previous SQL query:\n"
                "```sql\n"
                f"{previous_sql}\n"
                "```\n"
                "\n"
                "This query produced no results, which is probably indicative of a problem. "
                "Re-examine the question, the previous sql, and \n"
                "the database profiling information above to determine if there is a better "
                "sql query that may answer the intent better."
            )

        question_prompt = (
            f"Question: {question['question']} {maybe_evidence}{maybe_correct_invalid}{maybe_correct_empty}\n"
            "\n"
            "Return ONLY the SQL query that answers the question above using the database information above, no explanations.\n"
            "\n"
            "SQL Query:"
        )
        try:
            # Use cached system prompt for efficiency
            response = self.client.messages.create(
                model=self.model_name,
                max_tokens=MAX_TOKENS,
                temperature=0,
                messages=[{
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": self.system_prompt,
                            "cache_control": {"type": "ephemeral"}
                        },
                        {
                            "type": "text",
                            "text": question_prompt
                        }
                    ]
                }]
            )

            # Track token usage
            if hasattr(response, 'usage'):
                usage = response.usage

                # `or 0`: a counter can be present on the usage object but set
                # to None (the cache counters are optional in the API). Adding
                # None would raise inside this try block and throw away an
                # otherwise valid response.
                input_tokens = getattr(usage, 'input_tokens', 0) or 0
                output_tokens = getattr(usage, 'output_tokens', 0) or 0
                cache_creation_tokens = getattr(usage, 'cache_creation_input_tokens', 0) or 0
                cache_read_tokens = getattr(usage, 'cache_read_input_tokens', 0) or 0

                self.total_input_tokens += input_tokens
                self.total_output_tokens += output_tokens
                self.cache_creation_tokens += cache_creation_tokens
                self.cache_read_tokens += cache_read_tokens

                # Calculate cost (pricing is per million tokens)
                input_cost = (input_tokens / 1000000) * self.pricing['input']
                output_cost = (output_tokens / 1000000) * self.pricing['output']
                cache_creation_cost = (cache_creation_tokens / 1000000) * self.pricing['cache_write']
                cache_read_cost = (cache_read_tokens / 1000000) * self.pricing['cache_read']

                question_cost = input_cost + output_cost + cache_creation_cost + cache_read_cost
                self.total_cost += question_cost

            # Extract and clean SQL
            sql = response.content[0].text.strip()
            sql = self._clean_sql(sql)

            return sql

        except Exception as e:
            error_str = str(e)
            lowered = error_str.lower()

            # Classify the failure from its text and hand back a placeholder
            # query that the batch loop recognizes.
            if "credit" in lowered:
                print(f"    ⚠️  API CREDIT EXHAUSTION: {error_str}")
                print(f"    ⚠️  Please add credits to your API key to continue")
                return "SELECT 'API_CREDIT_EXHAUSTION' as error;"
            elif "rate_limit" in lowered:
                print(f"    ⚠️  API RATE LIMIT: {error_str}")
                return "SELECT 'API_RATE_LIMIT' as error;"
            elif "authentication" in lowered or "api_key" in lowered:
                print(f"    ⚠️  API AUTHENTICATION ERROR: {error_str}")
                return "SELECT 'API_AUTH_ERROR' as error;"
            else:
                print(f"    Error generating SQL: {error_str}")
                return "SELECT 'ERROR: Failed to generate SQL' as error;"
    
    def _clean_sql(self, sql: str) -> str:
        """Clean SQL output from any markdown or extra formatting."""
        # Remove markdown code blocks
        sql = re.sub(r'^```sql\s*\n', '', sql)
        sql = re.sub(r'^```\s*\n', '', sql)
        sql = re.sub(r'\n```$', '', sql)
        sql = re.sub(r'```$', '', sql)
        
        # Remove any leading/trailing whitespace
        sql = sql.strip()
        
        # Ensure it ends with semicolon
        if not sql.endswith(';'):
            sql += ';'
        
        return sql


def _build_cli_parser():
    """Build the argument parser for the command-line entry point."""
    parser = argparse.ArgumentParser(description='Generate SQL from questions using pre-generated prompt')
    parser.add_argument('--prompt', required=True, help='System prompt file generated by subagent')
    parser.add_argument('--questions', required=True, help='Questions JSON file')
    parser.add_argument('--db_name', required=True, help='Database name')
    parser.add_argument('--output', required=True, help='Output predictions file')
    parser.add_argument('--model', choices=list(SUPPORTED_MODELS.keys()),
                        default=DEFAULT_MODEL,
                        help=f'Model to use (default: {DEFAULT_MODEL})')
    parser.add_argument('--api_key', help='Anthropic API key')
    parser.add_argument('--limit', type=int, help='Limit number of questions to process')
    parser.add_argument('--no-evidence', action='store_true',
                        help='Exclude evidence from prompts (for BIRD no-evidence evaluation)')
    parser.add_argument('--db_path', help='Path to database file for SQL validation (enables validation)')
    parser.add_argument('--sql_validation_timeout', type=int, default=30,
                        help='Timeout in seconds for SQL validation (default: 30)')
    parser.add_argument('--verification_retries', type=int, default=2,
                        help='Number of verification attempts (default: 2, 0 = current behavior)')
    parser.add_argument('--temperature_strategy', choices=['progressive', 'fixed', 'adaptive'],
                        default='progressive',
                        help='Temperature strategy for verification retries (default: progressive)')
    return parser


def _select_db_questions(all_questions, db_name, limit=None):
    """Return the questions whose ``db_id`` equals *db_name*, in input order.

    A question without a ``question_id`` key is tagged in place with
    ``_original_index``, its position in *all_questions*. A falsy *limit*
    (``None`` or 0) means "no limit".
    """
    selected = []
    for idx, question in enumerate(all_questions):
        if question.get('db_id') != db_name:
            continue
        # Record the position in the unfiltered list when no id is present.
        if 'question_id' not in question:
            question['_original_index'] = idx
        selected.append(question)
    return selected[:limit] if limit else selected


def _print_validation_summary(stats):
    """Print the SQL validation / verification part of the run summary.

    Args:
        stats: The ``validation_stats`` mapping from the results metadata.
            Only ``total_generated`` is required; every other counter is
            read with a default of 0 so a missing key cannot abort the
            summary.
    """
    total_q = stats['total_generated']

    def pct(count):
        # Guard the division: an empty run reports 0.0% instead of raising.
        return count / total_q * 100 if total_q > 0 else 0.0

    print("\n✓ SQL Validation Statistics:")
    print(f"  Total queries generated: {total_q}")

    # Enhanced verification statistics if enabled
    if stats.get('verification_attempted', 0) > 0:
        total_api = stats.get('total_api_calls', 0)
        pass_immediate = stats.get('verification_passed_immediately', 0)
        improved = stats.get('verification_improved', 0)
        failed = stats.get('verification_failed', 0)
        verify_calls = stats.get('verification_api_calls', 0)

        api_ratio = f"{total_api/total_q:.2f}x" if total_q > 0 else "N/A"

        print("  --- Universal Verification Enabled ---")
        print(f"  API Calls: {total_api} total ({api_ratio} overhead)")
        print(f"  • Generation: {stats.get('generation_api_calls', 0)} calls")
        print(f"  • Verification: {verify_calls} calls")
        print("  Verification Outcomes:")
        print(f"  • Passed immediately: {pass_immediate} ({pct(pass_immediate):.1f}%)")
        print(f"  • Improved via retry: {improved} ({pct(improved):.1f}%)")
        print(f"  • Failed all attempts: {failed} ({pct(failed):.1f}%)")

    # Legacy validation stats (shown whenever any validation or retry ran)
    validation_attempted = stats.get('validation_attempted', 0)
    retry_attempted = stats.get('retry_attempted', 0)
    if validation_attempted > 0 or retry_attempted > 0:
        print("  --- Legacy Syntax Validation ---")
        print(f"  Validation attempted: {validation_attempted}")
        print(f"  Validation passed: {stats.get('validation_passed', 0)}")
        print(f"  Validation failed: {stats.get('validation_failed', 0)}")
        if retry_attempted > 0:
            retry_succeeded = stats.get('retry_succeeded', 0)
            retry_success_rate = (retry_succeeded / retry_attempted) * 100
            print(f"  Retries attempted: {retry_attempted}")
            print(f"  Retries succeeded: {retry_succeeded}")
            print(f"  Retry success rate: {retry_success_rate:.1f}%")

    timeout_errors = stats.get('timeout_errors', 0)
    if timeout_errors > 0:
        print(f"  Timeout errors: {timeout_errors}")


def main():
    """Command-line entry point: generate SQL predictions for one database.

    Steps:
      1. Parse the command-line arguments.
      2. Load the questions JSON file and keep the entries whose ``db_id``
         matches ``--db_name`` (optionally truncated by ``--limit``).
      3. Build a ``PromptBasedSQLGenerator`` and call ``generate_sql``.
      4. Record the evidence mode in the result metadata and write the
         results as JSON to ``--output`` (parent directories are created).
      5. Print token, cost, cache and validation statistics.
    """
    args = _build_cli_parser().parse_args()

    # JSON is read/written as UTF-8 explicitly so the result does not depend
    # on the platform's locale encoding.
    with open(args.questions, 'r', encoding='utf-8') as f:
        all_questions = json.load(f)

    # Filter questions for this database, preserving original indices
    questions = _select_db_questions(all_questions, args.db_name, args.limit)

    print(f"\nProcessing {len(questions)} questions for database: {args.db_name}")
    print(f"Using system prompt: {args.prompt}")
    if args.db_path:
        print(f"SQL validation enabled with database: {args.db_path}")
        print(f"SQL validation timeout: {args.sql_validation_timeout} seconds")

    # Show verification settings
    if args.verification_retries > 0:
        print(f"Universal verification enabled: {args.verification_retries} retries with {args.temperature_strategy} temperature strategy")
    else:
        print("Verification disabled (using current validation/retry behavior)")

    # Generate SQL
    use_evidence = not args.no_evidence
    generator = PromptBasedSQLGenerator(
        args.prompt,
        args.model,
        args.api_key,
        use_evidence=use_evidence,
        db_path=args.db_path,
        sql_validation_timeout=args.sql_validation_timeout,
        verification_retries=args.verification_retries,
        temperature_strategy=args.temperature_strategy
    )
    results = generator.generate_sql(questions, args.db_name)
    metadata = results['metadata']

    # Add evidence mode to metadata
    metadata['use_evidence'] = use_evidence

    # Save predictions before printing anything derived from them, so a
    # problem in the summary cannot lose the generated output.
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Print summary
    print(f"\n✓ Generated {len(results['predictions'])} SQL predictions")
    print(f"  Evidence mode: {'WITH evidence' if use_evidence else 'WITHOUT evidence'}")
    print(f"  Model: {generator.model_key} ({generator.model_name})")
    print(f"  Total input tokens: {generator.total_input_tokens:,}")
    print(f"  Total output tokens: {generator.total_output_tokens:,}")
    print(f"  Cache creation tokens: {generator.cache_creation_tokens:,}")
    print(f"  Cache read tokens: {generator.cache_read_tokens:,}")
    print(f"  Total cost: ${generator.total_cost:.4f}")

    # The denominator is positive whenever cache_read_tokens is positive.
    if generator.cache_read_tokens > 0:
        cache_hit_ratio = (generator.cache_read_tokens /
                           (generator.cache_creation_tokens + generator.cache_read_tokens)) * 100
        print(f"  Cache hit ratio: {cache_hit_ratio:.1f}%")

    cost_savings = metadata['token_usage']['cost_savings']
    if cost_savings['percentage'] > 0:
        print(f"  Cost savings from caching: ${cost_savings['amount']:.4f} ({cost_savings['percentage']:.1f}%)")

    # Print SQL validation statistics if validation was enabled
    if args.db_path and 'validation_stats' in metadata:
        _print_validation_summary(metadata['validation_stats'])

    print(f"  Output saved to: {args.output}")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()