#!/usr/bin/env python3
"""
Stage 1: Universal SQL Data Retrieval for Optimization Problems
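This script is baseline-agnostic: for each problem it asks an LLM to write SQL queries, runs them against the problem's database, and writes an enhanced problem description containing the retrieved data.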
"""

import os
import sys
import json
import time
import argparse
import sqlite3
import pandas as pd
from pathlib import Path
import re
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
import threading

# RITS API setup
import openai
from openai import OpenAI

# OpenAI-compatible endpoint of the RITS inference service (used by
# setup_rits_api_client) and the model name sent with every chat request
# (used by get_response).
BASE_URL = "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/deepseek-v3-h200/v1"
MODEL_NAME = "deepseek-ai/DeepSeek-V3"

def setup_rits_api_client():
    """Create an OpenAI-compatible client for the RITS inference endpoint.

    The API key is read from the ``RITS_API_KEY`` environment variable and sent
    in the ``RITS_API_KEY`` request header. The OpenAI ``api_key`` argument is
    only a fixed placeholder.

    Returns:
        OpenAI: client bound to ``BASE_URL`` with a 300-second timeout.

    Raises:
        ValueError: if ``RITS_API_KEY`` is unset or blank.
    """
    # Only read the key from the environment. Assigning os.environ here (as an
    # empty string) would clobber a key the user exported and always fail.
    api_key = os.environ.get("RITS_API_KEY", "").strip()

    if not api_key:
        raise ValueError("Please set RITS_API_KEY environment variable")

    return OpenAI(
        api_key="dummy",
        base_url=BASE_URL,
        default_headers={"RITS_API_KEY": api_key},
        timeout=300,
    )

# Built at import time. This runs in every process that imports the module,
# including pool workers started with the "spawn" start method.
rits_client = setup_rits_api_client()

def get_response(prompt, max_retries=3):
    """Send ``prompt`` as a single user message and return the model's reply.

    Makes up to ``max_retries`` attempts. Between attempts it waits with
    exponential backoff (1 s, 2 s, 4 s, ...). The wait applies both when the
    API call raises and when it returns an empty completion.

    Args:
        prompt: user message content.
        max_retries: total number of attempts.

    Returns:
        str: the first non-empty completion. If every attempt returns an empty
        completion, returns the sentinel "No response generated after multiple
        attempts.".

    Raises:
        Exception: whatever the API client raised on the final attempt.
    """
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = rits_client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=4096,
                temperature=0.1,
                top_p=0.9,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                stream=False
            )
        except Exception as e:
            print(f"WARNING: API call failed on attempt {attempt + 1}: {e}")
            if is_last_attempt:
                raise
        else:
            if response.choices and response.choices[0].message.content:
                return response.choices[0].message.content
            print(f"WARNING: Empty response on attempt {attempt + 1}")

        # Back off before the next attempt. Empty replies previously retried
        # immediately, which hammers an overloaded endpoint.
        if not is_last_attempt:
            delay = 2 ** attempt
            print(f"Retrying in {delay} seconds...")
            time.sleep(delay)

    return "No response generated after multiple attempts."

class ProgressTracker:
    """Count finished tasks and print a running success/failure tally.

    Counter updates and the progress print happen while holding
    ``self.lock``, so ``update`` can be called from several threads.
    """

    def __init__(self, total_tasks):
        self.total_tasks = total_tasks
        self.completed = self.successful = self.failed = 0
        self.lock = threading.Lock()

    def update(self, success=True):
        """Record one finished task and print the current progress line.

        Args:
            success: truthy to count the task as successful, falsy as failed.
        """
        with self.lock:
            self.completed += 1
            if not success:
                self.failed += 1
            else:
                self.successful += 1

            summary = (
                f"[Progress] {self.completed}/{self.total_tasks} completed "
                f"(Success: {self.successful}, Failed: {self.failed})"
            )
            print(summary)

def find_problem_descriptions_with_schema(syn_data_dir):
    """Find schema directories that have a problem description and cached SQL.

    A subdirectory ``<syn_data_dir>/<name>`` qualifies when it contains
    ``problem_description.md``, ``schema_cache/latest/schema.sql`` and
    ``schema_cache/latest/data.sql``. Non-directory entries are ignored.
    Incomplete directories are reported (with the missing files) and skipped.

    Args:
        syn_data_dir: path to the dataset root.

    Returns:
        list[dict]: one dict per qualifying directory, ordered by directory
        name. Keys are ``schema_name``, ``problem_path``, ``schema_dir``,
        ``schema_sql_path`` and ``data_sql_path``; all values are strings.
    """
    problem_files = []

    # Path.iterdir() yields entries in arbitrary, filesystem-dependent order.
    # Sorting makes the result reproducible across runs and machines, and so
    # also the "first N" subset taken for --max_problems.
    for schema_dir in sorted(Path(syn_data_dir).iterdir(), key=lambda p: p.name):
        if not schema_dir.is_dir():
            continue

        problem_desc_path = schema_dir / "problem_description.md"
        schema_cache_dir = schema_dir / "schema_cache" / "latest"
        schema_sql_path = schema_cache_dir / "schema.sql"
        data_sql_path = schema_cache_dir / "data.sql"

        required = {
            "problem_description.md": problem_desc_path,
            "schema.sql": schema_sql_path,
            "data.sql": data_sql_path,
        }
        missing = [label for label, path in required.items() if not path.exists()]
        if missing:
            print(f"Skip: {schema_dir.name} (missing: {', '.join(missing)})")
            continue

        problem_files.append({
            'schema_name': schema_dir.name,
            'problem_path': str(problem_desc_path),
            'schema_dir': str(schema_dir),
            'schema_sql_path': str(schema_sql_path),
            'data_sql_path': str(data_sql_path)
        })
        print(f"Found: {schema_dir.name} (with schema data)")

    return problem_files

def extract_problem_context_without_stored_values(problem_description_path):
    """Return the markdown file's text with its "Current Stored Values" section removed.

    The section starts at any line whose lowercased text contains
    ``### current stored values``; that line is dropped too. The section ends
    at the next level 1-3 markdown heading (``# ``, ``## `` or ``### ``), which
    is kept. Deeper headings such as ``#### `` are treated as sub-headings and
    removed with the section.

    Args:
        problem_description_path: path to a UTF-8 markdown file.

    Returns:
        str: the remaining lines joined with ``\\n``.
    """
    with open(problem_description_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # A heading of level 1-3 closes the section. Accepting only "### " here
    # would let a following "## " section (and everything after it) be dropped.
    section_end = re.compile(r'#{1,3}[ \t]')

    kept_lines = []
    skipping = False
    for line in content.split('\n'):
        if '### current stored values' in line.lower():
            skipping = True
            continue

        if skipping and section_end.match(line):
            skipping = False

        if not skipping:
            kept_lines.append(line)

    return '\n'.join(kept_lines)

def create_database_from_schema(schema_sql_path, data_sql_path):
    """Build an in-memory SQLite database from a schema script and a data script.

    Both files are read as UTF-8 and executed in order: schema, then data.

    Args:
        schema_sql_path: path to the DDL script.
        data_sql_path: path to the INSERT/data script.

    Returns:
        sqlite3.Connection | None: the populated connection. Returns None if
        either script fails to execute; the error is printed and the
        connection is closed.

    Raises:
        OSError: if either file cannot be read. Both files are read before
        the database is created.
    """
    with open(schema_sql_path, 'r', encoding='utf-8') as f:
        schema_sql = f.read()

    with open(data_sql_path, 'r', encoding='utf-8') as f:
        data_sql = f.read()

    conn = sqlite3.connect(':memory:')

    try:
        # executescript() lets SQLite's own parser find statement boundaries.
        # Semicolons inside string literals (e.g. 'a;b') or comments therefore
        # no longer cut a statement in half, as naive str.split(';') did.
        conn.executescript(schema_sql)
        conn.executescript(data_sql)
        conn.commit()
    except Exception as e:
        print(f"    ERROR creating database: {e}")
        conn.close()
        return None

    return conn

def generate_sql_queries_with_llm(problem_context, schema_sql):
    """Ask the LLM to propose SELECT queries for an optimization problem.

    Args:
        problem_context: problem description text, with the stored-values
            section already removed.
        schema_sql: the database DDL, embedded verbatim in the prompt.

    Returns:
        str | None: the raw model reply from ``get_response``, or None if
        ``get_response`` raised. In that case the error is printed.
    """
    prompt = f"""You are an expert database analyst helping with optimization problem data retrieval. Based on the problem description and database schema, analyze what data would be most useful for solving this optimization problem and generate appropriate SQL SELECT queries.

**Problem Context (without stored values):**
{problem_context}

**Database Schema:**
{schema_sql}

**Your Task:** 
Carefully analyze the optimization problem and determine what data from the database would be most relevant. Then generate SQL SELECT queries to retrieve this data.

**Analysis Guidelines:**
- Identify what data is needed for decision variables (what needs to be optimized)
- Identify what data is needed for objective function coefficients (what to maximize/minimize)
- Identify what data is needed for constraint parameters (limitations and requirements)
- Consider what aggregated data, summary statistics, or lookup information might be helpful
- Think about relationships between tables and potential joins
- Consider filtering criteria that would make the data more relevant

**Query Requirements:**
- Generate as many queries as you think are necessary and useful (no fixed number)
- Each query should have a clear purpose for the optimization problem
- Include meaningful comments explaining what each query retrieves and why it's relevant
- Use proper SQL syntax with table names, column names, and joins as needed
- Consider both detailed data and summary/aggregated data where appropriate

**Output Format:**
```sql
-- Query Description: Explain what this retrieves and why it's important for optimization
SELECT ... FROM ... WHERE ... ;

-- Query Description: Explain what this retrieves and why it's important for optimization  
SELECT ... FROM ... WHERE ... ;

-- Continue with additional queries as needed
```

Analyze the problem and generate the most relevant SQL queries:"""

    try:
        llm_reply = get_response(prompt)
    except Exception as err:
        print(f"    ERROR generating SQL queries: {err}")
        return None
    return llm_reply

def extract_sql_queries_from_response(llm_response):
    """Extract SELECT queries, each with its description comment, from an LLM reply.

    All ```sql fenced blocks are scanned, not just the first; if there are none,
    the whole reply is scanned instead.

    - A ``--`` line outside a query becomes the description of the next query.
    - A query starts at a line beginning with ``SELECT``.
    - A query extends until ``sqlite3.complete_statement`` reports it complete,
      i.e. it ends with a real ``;`` token. Multi-line queries, sub-queries or
      UNION branches on lines that start with SELECT, and comment lines inside
      a query are therefore kept intact.
    - Lines are joined with newlines, so an inline ``-- note`` cannot comment
      out the rest of the query.

    A query that is never terminated is still kept, with ``;`` appended, when a
    new ``-- Query Description`` comment begins or the input ends.

    Args:
        llm_response: raw text returned by the model.

    Returns:
        list[dict]: ``{'comment': str, 'query': str}`` per query, in order.
    """
    queries = []

    # Collect every SQL code block; LLMs often put each query in its own block.
    sql_blocks = re.findall(r'```sql\s*(.*?)\s*```', llm_response, re.DOTALL | re.IGNORECASE)
    sql_content = '\n'.join(sql_blocks) if sql_blocks else llm_response

    current_query = []
    current_comment = ""

    def finish_query():
        """Append the pending query (terminated with ';' if needed) and reset it."""
        lines = list(current_query)
        # Trailing comment-only lines belong to no statement. Drop them so an
        # appended ';' is not swallowed by a comment.
        while len(lines) > 1 and lines[-1].startswith('--'):
            lines.pop()
        query_text = '\n'.join(lines)
        if not sqlite3.complete_statement(query_text):
            # Put ';' on its own line if the last line may end in a comment.
            query_text += '\n;' if '--' in lines[-1] else ';'
        queries.append({
            'comment': current_comment,
            'query': query_text
        })
        current_query.clear()

    for raw_line in sql_content.split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        if current_query:
            is_description = (
                line.startswith('--')
                and line[2:].strip().lower().startswith('query description')
            )
            if is_description:
                # The prompt's per-query header arrived while the previous
                # query was still open (missing ';'): close it first.
                finish_query()
                current_comment = line[2:].strip()
                continue
            current_query.append(line)
        elif line.startswith('--'):
            current_comment = line[2:].strip()
            continue
        elif line.upper().startswith('SELECT'):
            current_query.append(line)
        else:
            # Prose or non-SELECT statements outside a query are ignored.
            continue

        if sqlite3.complete_statement('\n'.join(current_query)):
            finish_query()

    # Keep a final query that was never terminated.
    if current_query:
        finish_query()

    print(f"    Extracted {len(queries)} queries from LLM response")
    return queries

def execute_sql_queries(conn, queries):
    """Run each query against ``conn`` and collect the results.

    Args:
        conn: an open connection accepted by ``pd.read_sql_query``.
        queries: iterable of dicts with ``'comment'`` and ``'query'`` keys.

    Returns:
        list[dict]: one entry per query, in input order. Each entry has
        ``comment``, ``query``, ``result_df`` (DataFrame) and ``result_csv``
        (CSV text without the index). When a query fails, the entry holds an
        empty DataFrame, ``result_csv == ""`` and an extra ``error`` key with
        the exception text, and processing continues with the next query.
    """
    collected = []

    for position, item in enumerate(queries, start=1):
        label = item['comment']
        sql = item['query']

        try:
            frame = pd.read_sql_query(sql, conn)
            collected.append({
                'comment': label,
                'query': sql,
                'result_df': frame,
                'result_csv': frame.to_csv(index=False),
            })
            print(f"    Query {position}: {label} -> {len(frame)} rows")
        except Exception as exc:
            print(f"    ERROR executing query {position}: {exc}")
            print(f"    Query: {sql}")
            # Record the failure with an empty result so indices stay aligned.
            collected.append({
                'comment': label,
                'query': sql,
                'result_df': pd.DataFrame(),
                'result_csv': "",
                'error': str(exc),
            })

    return collected

def create_enhanced_problem_description(original_problem_path, query_results):
    """Build the enhanced markdown problem description.

    The output is the original description with any "Current Stored Values"
    section removed, followed by a new "### Retrieved Values" section. That
    section lists, for each query, its description, its SQL, and either the
    CSV result or an error line.

    Args:
        original_problem_path: path to the original ``problem_description.md``.
        query_results: entries shaped like those from ``execute_sql_queries``
            (``comment``, ``query``, ``result_csv`` and optional ``error``).

    Returns:
        str: the complete markdown text.
    """
    # The helper reads the file itself, so no separate read is needed here.
    parts = [extract_problem_context_without_stored_values(original_problem_path)]
    parts.append("\n\n### Retrieved Values\n\n")

    for i, result in enumerate(query_results, start=1):
        comment = result['comment']
        query = result['query']
        csv_data = result['result_csv']

        parts.append(f"**Query {i}: {comment}**\n\n")
        parts.append(f"```sql\n{query}\n```\n\n")

        if csv_data and not result.get('error'):
            # to_csv output ends with a newline, so the closing fence starts its own line.
            parts.append(f"**Results (CSV format):**\n```csv\n{csv_data}```\n\n")
        else:
            error_msg = result.get('error', 'No data returned')
            parts.append(f"**Error:** {error_msg}\n\n")

    return "".join(parts)

def process_single_database(problem_info_and_output_dir):
    """Run Stage 1 end-to-end for one schema directory (process-pool worker).

    If ``enhanced_problem_description.md`` already exists in the output
    directory, the schema is skipped. Otherwise the function:
      1. asks the LLM for SQL queries;
      2. runs them against an in-memory copy of the schema's data;
      3. writes the enhanced description plus ``stage1_llm_sql_generation.txt``
         and ``stage1_results.json`` into ``<output_base_dir>/<schema_name>/``.

    Args:
        problem_info_and_output_dir: ``(problem_info, output_base_dir)`` tuple.
            ``problem_info`` must provide ``schema_name``, ``problem_path``,
            ``schema_sql_path`` and ``data_sql_path``.

    Returns:
        dict: a status record with ``schema_name``, ``status`` ("success",
        "skipped" or "failed"), ``timestamp`` and status-specific fields.
        Exceptions raised while processing become a "failed" record, which is
        also written to ``stage1_error.json``.
    """
    problem_info, output_base_dir = problem_info_and_output_dir

    schema_name = problem_info['schema_name']
    problem_path = problem_info['problem_path']
    schema_sql_path = problem_info['schema_sql_path']
    data_sql_path = problem_info['data_sql_path']

    print(f"\n=== Processing: {schema_name} ===")

    # Create output directory
    output_dir = Path(output_base_dir) / schema_name
    output_dir.mkdir(parents=True, exist_ok=True)
    error_file = output_dir / "stage1_error.json"

    try:
        # Finished outputs are never regenerated; delete the file to force a rerun.
        enhanced_file = output_dir / "enhanced_problem_description.md"
        if enhanced_file.exists():
            print(f"  Enhanced problem description already exists for {schema_name}, skipping...")
            return {
                "schema_name": schema_name,
                "status": "skipped",
                "reason": "enhanced_problem_description.md already exists",
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
            }

        # Give the LLM the problem without the values we are about to retrieve.
        problem_context = extract_problem_context_without_stored_values(problem_path)

        with open(schema_sql_path, 'r', encoding='utf-8') as f:
            schema_sql = f.read()

        print(f"    Generating SQL queries with LLM for {schema_name}...")
        llm_response = generate_sql_queries_with_llm(problem_context, schema_sql)

        if not llm_response:
            raise Exception("Failed to generate SQL queries with LLM")

        # Save LLM response for debugging
        with open(output_dir / "stage1_llm_sql_generation.txt", "w", encoding="utf-8") as f:
            f.write(f"=== Problem Context ===\n{problem_context}\n\n")
            f.write(f"=== Schema ===\n{schema_sql}\n\n")
            f.write(f"=== LLM Response ===\n{llm_response}\n")

        queries = extract_sql_queries_from_response(llm_response)
        print(f"    Generated {len(queries)} SQL queries for {schema_name}")

        if not queries:
            raise Exception("No valid SQL queries extracted from LLM response")

        print(f"    Creating in-memory database for {schema_name}...")
        conn = create_database_from_schema(schema_sql_path, data_sql_path)

        if not conn:
            raise Exception("Failed to create database")

        print(f"    Executing SQL queries for {schema_name}...")
        try:
            query_results = execute_sql_queries(conn, queries)
        finally:
            # Close the connection even if execution raises.
            conn.close()

        enhanced_problem_text = create_enhanced_problem_description(problem_path, query_results)

        with open(enhanced_file, "w", encoding="utf-8") as f:
            f.write(enhanced_problem_text)

        # Save query results for debugging
        stage1_results = {
            "schema_name": schema_name,
            "queries_generated": len(queries),
            "queries_executed": len(query_results),
            "query_results": [
                {
                    "comment": r['comment'],
                    "query": r['query'],
                    "rows_returned": len(r['result_df']) if 'result_df' in r else 0,
                    "error": r.get('error')
                }
                for r in query_results
            ]
        }

        with open(output_dir / "stage1_results.json", "w", encoding="utf-8") as f:
            json.dump(stage1_results, f, indent=2)

        # A previous failed run may have left an error report that no longer
        # describes this directory. Remove it on a best-effort basis.
        try:
            error_file.unlink()
        except OSError:
            pass

        print(f"  SUCCESS: {schema_name} - {len(query_results)} queries processed")

        return {
            "schema_name": schema_name,
            "status": "success",
            "queries_generated": len(queries),
            "queries_executed": len(query_results),
            "enhanced_file_path": str(enhanced_file),
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

    except Exception as e:
        print(f"  ERROR: {schema_name} failed: {e}")

        error_summary = {
            "schema_name": schema_name,
            "status": "failed",
            "error": str(e),
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        with open(error_file, "w", encoding="utf-8") as f:
            json.dump(error_summary, f, indent=2)

        return error_summary

def main():
    """Command-line entry point: run Stage 1 over every complete schema directory.

    The function:
      1. discovers problems under ``--syn_data_dir`` (optionally truncated by
         ``--max_problems``);
      2. processes them in parallel with ``process_single_database``,
         printing progress as each finishes;
      3. writes an aggregate ``stage1_summary.json`` into ``--output_dir``.

    Exits with status 1 when no directory has all required files.
    """
    parser = argparse.ArgumentParser(description="Stage 1: Universal SQL Data Retrieval for Optimization Problems")
    parser.add_argument("--syn_data_dir", type=str, required=True,
                        help="Path to syn_data_gen dataset directory")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output directory for enhanced problem descriptions")
    parser.add_argument("--max_problems", type=int, default=None,
                        help="Maximum number of problems to process")
    parser.add_argument("--max_workers", type=int, default=None,
                        help="Maximum number of parallel workers (default: min(CPU count, 8))")

    args = parser.parse_args()

    # Set default number of workers
    if args.max_workers is None:
        args.max_workers = min(mp.cpu_count(), 8)  # Cap at 8 to avoid API rate limits

    print("Stage 1: Universal SQL Data Retrieval")
    print("=" * 50)
    print(f"Input directory: {args.syn_data_dir}")
    print(f"Output directory: {args.output_dir}")
    print(f"Max parallel workers: {args.max_workers}")
    # Report the model actually configured rather than a hard-coded name.
    print(f"Using RITS API: {MODEL_NAME}")

    problem_files = find_problem_descriptions_with_schema(args.syn_data_dir)

    if not problem_files:
        print("ERROR: No problems with complete schema data found!")
        sys.exit(1)

    print(f"\nFound {len(problem_files)} problems with schema data")

    # Limit for testing
    if args.max_problems:
        problem_files = problem_files[:args.max_problems]
        print(f"Limited to first {args.max_problems} problems for testing")

    os.makedirs(args.output_dir, exist_ok=True)

    # Each worker receives a single (problem_info, output_dir) tuple.
    worker_args = [(problem_info, args.output_dir) for problem_info in problem_files]

    progress = ProgressTracker(len(problem_files))

    results = []
    start_time = time.time()

    print(f"\nStarting parallel processing with {args.max_workers} workers...")

    with ProcessPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_problem = {
            executor.submit(process_single_database, worker_arg): worker_arg[0]['schema_name']
            for worker_arg in worker_args
        }

        for future in as_completed(future_to_problem):
            schema_name = future_to_problem[future]
            try:
                result = future.result()
                results.append(result)
                # Skipped problems already have outputs, so they count as successes.
                progress.update(success=(result["status"] in ["success", "skipped"]))

            except Exception as exc:
                # Reached only if the worker itself crashed outside its own handler.
                print(f'\nERROR: {schema_name} generated an exception: {exc}')
                error_result = {
                    "schema_name": schema_name,
                    "status": "failed",
                    "error": str(exc),
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
                }
                results.append(error_result)
                progress.update(success=False)

    total_time = time.time() - start_time
    total = len(problem_files)
    successful = [r for r in results if r["status"] == "success"]
    skipped = [r for r in results if r["status"] == "skipped"]
    failed = [r for r in results if r["status"] == "failed"]

    overall_summary = {
        "run_info": {
            "stage": "stage1_sql_retrieval",
            "total_problems": total,
            "successful": len(successful),
            "skipped": len(skipped),
            "failed": len(failed),
            "success_rate": f"{(len(successful) + len(skipped))/total*100:.1f}%",
            "total_time": f"{total_time:.1f} seconds",
            "average_time_per_problem": f"{total_time/total:.1f} seconds",
            "parallel_workers": args.max_workers,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        },
        "configuration": {
            "model": f"{MODEL_NAME} (RITS API)",
            "database": "In-memory SQLite from schema.sql and data.sql",
            "max_workers": args.max_workers,
            "input_dir": args.syn_data_dir,
            "output_dir": args.output_dir
        },
        "results": results
    }

    with open(os.path.join(args.output_dir, "stage1_summary.json"), "w", encoding="utf-8") as f:
        json.dump(overall_summary, f, indent=2)

    print("\n" + "=" * 50)
    print("Stage 1 Complete!")
    print(f"Total problems processed: {total}")
    print(f"Successful: {len(successful)} ({len(successful)/total*100:.1f}%)")
    print(f"Skipped (already exist): {len(skipped)} ({len(skipped)/total*100:.1f}%)")
    print(f"Failed: {len(failed)} ({len(failed)/total*100:.1f}%)")
    print(f"Total time: {total_time:.1f} seconds")
    print(f"Enhanced problem descriptions saved to: {args.output_dir}")

if __name__ == "__main__":
    main()