# parallel_runner.py
import os
import sys
import json
import argparse
import multiprocessing as mp
import subprocess
import time
from pathlib import Path

def load_valid_databases(spider_dir="./spider"):
    """Return the ``db_id`` of every Spider database with at most 5 tables.

    ``tables.json`` is looked up directly under *spider_dir* first; if it is
    not there, the directory tree is walked top-down and the first
    ``tables.json`` encountered is used.

    Raises:
        FileNotFoundError: if no ``tables.json`` exists anywhere under
            *spider_dir*.
    """
    schema_file = Path(spider_dir) / "tables.json"

    if not schema_file.exists():
        # Fall back to searching nested directories (e.g. an extracted archive).
        located = next(
            (
                Path(root) / "tables.json"
                for root, _subdirs, filenames in os.walk(spider_dir)
                if "tables.json" in filenames
            ),
            None,
        )
        if located is None:
            raise FileNotFoundError(f"tables.json not found in {spider_dir}")
        schema_file = located

    with open(schema_file, "r", encoding="utf-8") as handle:
        schemas = json.load(handle)

    # A missing table list counts as zero tables.
    return [
        schema['db_id']
        for schema in schemas
        if len(schema.get('table_names_original', [])) <= 5
    ]

def _extract_error_message(stderr, stdout, limit=500):
    """Return the most informative error text from a failed child process.

    Prefers the last *limit* characters of stderr, because the final exception
    line of a Python traceback comes at the end. If stderr is empty, falls
    back to the last five stdout lines mentioning "ERROR"/"Error". Otherwise
    returns a placeholder.
    """
    if stderr:
        return stderr[-limit:]
    if stdout:
        # Sometimes errors are printed to stdout
        error_lines = [
            line for line in stdout.split('\n')
            if 'ERROR' in line or 'Error' in line
        ]
        return '\n'.join(error_lines[-5:]) if error_lines else "Check output logs"
    return "Unknown error"


def run_single_database(db_name, spider_dir="./spider",
                        output_dir="./text2opt_dataset_alternating_optimization",
                        timeout=3600):
    """Run optimization problem generation for a single database.

    Launches ``<python> -m main --database <db_name> --spider-dir <spider_dir>
    --output-dir <output_dir>`` with the current interpreter and waits for it.
    Timeouts and exceptions raised while launching the child are caught and
    reported in the returned dict. This lets one bad database not abort a
    ``Pool.map`` over many of them.

    Args:
        db_name: Spider ``db_id`` to process.
        spider_dir: Spider dataset directory forwarded to the child.
        output_dir: Output directory forwarded to the child.
        timeout: Seconds before the child is killed (default 3600, one hour).

    Returns:
        A dict with ``database``, ``status`` (``"success"``, ``"failed"``,
        ``"timeout"`` or ``"error"``) and ``execution_time`` in seconds. It
        also holds ``stdout_lines`` on success, ``error`` and ``return_code``
        on a non-zero exit, and ``error`` when the launch raised.
    """
    cmd = [
        sys.executable, "-m", "main",  # Run as module
        "--database", db_name,
        "--spider-dir", spider_dir,
        "--output-dir", output_dir,
    ]

    start_time = time.time()
    print(f"[START] Processing {db_name}...")
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            # Undecodable bytes in the child's output must not turn a finished
            # run into an "error" result via UnicodeDecodeError.
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        print(f"[TIMEOUT] {db_name} timed out after {timeout}s")
        return {
            "database": db_name,
            "status": "timeout",
            "execution_time": time.time() - start_time,
        }
    except Exception as e:
        # Worker boundary: report launch failures instead of killing the pool.
        print(f"[ERROR] {db_name} encountered error: {str(e)}")
        return {
            "database": db_name,
            "status": "error",
            "execution_time": time.time() - start_time,
            "error": str(e),
        }

    execution_time = time.time() - start_time

    if result.returncode == 0:
        print(f"[SUCCESS] {db_name} completed in {execution_time:.1f}s")
        return {
            "database": db_name,
            "status": "success",
            "execution_time": execution_time,
            "stdout_lines": len(result.stdout.split('\n')) if result.stdout else 0,
        }

    print(f"[FAILED] {db_name} failed after {execution_time:.1f}s")
    return {
        "database": db_name,
        "status": "failed",
        "execution_time": execution_time,
        "error": _extract_error_message(result.stderr, result.stdout),
        "return_code": result.returncode,
    }

def save_results_summary(results, output_dir):
    """Aggregate per-database results, write them to JSON and return the summary.

    The summary is written to ``<output_dir>/parallel_execution_summary.json``.
    It holds per-status counts, the summed ``execution_time`` (entries without
    one count as 0) and the original ``results`` list itself.
    """
    def tally(status):
        return sum(1 for entry in results if entry['status'] == status)

    summary = {
        "total_databases": len(results),
        "successful": tally('success'),
        "failed": tally('failed'),
        "timeout": tally('timeout'),
        "error": tally('error'),
        "total_execution_time": sum(entry.get('execution_time', 0) for entry in results),
        "results": results,
    }

    summary_path = os.path.join(output_dir, "parallel_execution_summary.json")
    with open(summary_path, "w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)

    print(f"\nResults summary saved to {summary_path}")
    return summary

def check_completed_databases(output_dir):
    """Return the names of database output folders that look finished.

    A subdirectory of *output_dir* (other than ``logs``) counts as completed
    when it contains both ``problem_solution_description.md`` and
    ``or_analysis.json``. A missing *output_dir* yields an empty set.
    """
    if not os.path.exists(output_dir):
        return set()

    required_outputs = ("problem_solution_description.md", "or_analysis.json")
    return {
        entry
        for entry in os.listdir(output_dir)
        if os.path.isdir(os.path.join(output_dir, entry))
        and entry != "logs"
        and all(
            os.path.exists(os.path.join(output_dir, entry, name))
            for name in required_outputs
        )
    }

def _build_arg_parser():
    """Create the command-line parser for the parallel runner."""
    parser = argparse.ArgumentParser(
        description="Parallel runner for schema2optsgd optimization problem generation"
    )
    parser.add_argument(
        "--database",
        help="Process only this specific database"
    )
    parser.add_argument(
        "--parallel",
        type=int,
        default=4,
        help="Number of parallel processes (default: %(default)s)"
    )
    parser.add_argument(
        "--spider-dir",
        default="../spider",
        help="Path to Spider dataset directory (default: %(default)s)"
    )
    parser.add_argument(
        "--output-dir",
        default="./text2opt_dataset_alternating_optimization",
        help="Output directory for parallel processing (default: %(default)s)"
    )
    parser.add_argument(
        "--list-databases",
        action="store_true",
        help="List all valid databases and exit"
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume processing, skip already completed databases"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be processed without actually running"
    )
    return parser


def _report_parallel_results(results, summary, total_time, output_dir):
    """Print the parallel-run summary; write ``failed_databases.json`` if any run failed."""
    print(f"\n{'='*60}")
    print("PARALLEL EXECUTION SUMMARY")
    print(f"{'='*60}")
    print(f"Total databases processed: {summary['total_databases']}")
    print(f"Successful: {summary['successful']}")
    print(f"Failed: {summary['failed']}")
    print(f"Timeout: {summary['timeout']}")
    print(f"Error: {summary['error']}")
    if summary['total_databases'] > 0:
        print(f"Success rate: {summary['successful']/summary['total_databases']*100:.1f}%")
    print(f"Total execution time: {total_time:.1f}s")
    if summary['total_databases'] > 0:
        # Wall-clock time divided by the database count (effective throughput
        # with parallelism), not the mean of per-database execution times.
        print(f"Average time per database: {total_time/summary['total_databases']:.1f}s")
    print(f"Output directory: {output_dir}")

    failed_results = [r for r in results if r['status'] != 'success']
    if not failed_results:
        return

    print(f"\nFailed databases ({len(failed_results)}):")
    for r in failed_results:
        print(f"  - {r['database']}: {r['status']}")
        if 'error' in r:
            error_text = r['error']
            ellipsis = "..." if len(error_text) > 100 else ""
            print(f"    Error: {error_text[:100]}{ellipsis}")

    # Save detailed failure log
    failure_log = os.path.join(output_dir, "failed_databases.json")
    with open(failure_log, "w", encoding="utf-8") as f:
        json.dump(failed_results, f, indent=2)
    print(f"\nDetailed failure log saved to: {failure_log}")


def main():
    """Command-line entry point for batch generation over Spider databases.

    Steps:
    1. Load the databases with at most 5 tables.
    2. Optionally narrow them to ``--database``.
    3. Optionally skip databases that already have complete output (``--resume``).
    4. Either list them (``--list-databases`` / ``--dry-run``), run the single
       requested database, or fan the rest out over a ``multiprocessing.Pool``.

    A JSON summary is written to ``--output-dir``, plus a failure log when
    parallel runs fail.
    """
    from functools import partial

    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.parallel < 1:
        parser.error("--parallel must be a positive integer")

    # Fail fast if the generator module cannot be imported. The private alias
    # keeps the import from shadowing this function's own name.
    # NOTE(review): this resolves `main` from this process's sys.path, while
    # workers run `python -m main` from the current working directory; the two
    # can differ when the script is launched from another directory.
    try:
        import main as _generator_module  # noqa: F401
    except ImportError:
        print("Error: schema2optsgd module not found. Make sure it's in your Python path.")
        print("You may need to run: export PYTHONPATH=$PYTHONPATH:$(pwd)")
        sys.exit(1)

    try:
        valid_dbs = load_valid_databases(args.spider_dir)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        sys.exit(1)
    print(f"Found {len(valid_dbs)} valid databases (<=5 tables)")

    if args.list_databases:
        print("\nValid databases:")
        for i, db in enumerate(sorted(valid_dbs), 1):
            print(f"  {i:3d}. {db}")
        return

    # Validate --database up front, so --dry-run and --resume operate on the
    # same selection that would actually be processed.
    if args.database:
        if args.database not in valid_dbs:
            print(f"Error: Database '{args.database}' not found in valid databases")
            print("Use --list-databases to see all valid databases")
            sys.exit(1)
        candidates = [args.database]
    else:
        candidates = valid_dbs

    # Check for already completed databases if resume flag is set
    to_process = candidates
    if args.resume:
        completed = check_completed_databases(args.output_dir)
        to_process = [db for db in candidates if db not in completed]
        skipped = len(candidates) - len(to_process)
        if skipped:
            print(f"\nResuming: Found {skipped} already completed databases")
            print(f"Remaining: {len(to_process)} databases to process")

    if args.dry_run:
        print(f"\nDry run - would process {len(to_process)} databases:")
        for db in sorted(to_process):
            print(f"  - {db}")
        return

    if not to_process:
        print("\nNo databases to process. All databases have been completed.")
        return

    os.makedirs(args.output_dir, exist_ok=True)

    if args.database:
        print(f"\nProcessing single database: {args.database}")
        result = run_single_database(args.database, args.spider_dir, args.output_dir)

        save_results_summary([result], args.output_dir)

        print("\nSingle database result:")
        print(f"  Database: {result['database']}")
        print(f"  Status: {result['status']}")
        if 'execution_time' in result:
            print(f"  Time: {result['execution_time']:.1f}s")
        if result['status'] != 'success' and 'error' in result:
            print(f"  Error: {result['error']}")
        return

    # Never start more worker processes than there are databases.
    num_workers = min(args.parallel, len(to_process))
    print(f"\nStarting parallel processing with {num_workers} processes...")
    print(f"Output directory: {args.output_dir}")
    print(f"Databases to process: {len(to_process)}")

    start_time = time.time()
    worker_func = partial(run_single_database,
                          spider_dir=args.spider_dir,
                          output_dir=args.output_dir)
    with mp.Pool(num_workers) as pool:
        # chunksize=1: each job is a long-running subprocess of uneven length.
        # Handing databases out one at a time keeps all workers busy instead
        # of pre-batching several onto one worker.
        results = pool.map(worker_func, to_process, chunksize=1)
    total_time = time.time() - start_time

    summary = save_results_summary(results, args.output_dir)
    _report_parallel_results(results, summary, total_time, args.output_dir)

if __name__ == "__main__":
    main()