#!/usr/bin/env python3
"""
Stage 2: OptiMUS Processing with Enhanced Problem Descriptions
Reads enhanced problem descriptions and solves them using OptiMUS pipeline
"""

import os
import sys
import json
import time
import argparse
import shutil
from pathlib import Path
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
import threading

# Add OptiMUS modules to path (assuming OptiMUS is in parent directory or accessible)
sys.path.append(".")
sys.path.append("..")

# Import OptiMUS modules
from parameters import get_params
from constraint import get_constraints
from constraint_model import get_constraint_formulations
from target_code import get_codes
from generate_code import generate_code
from utils import load_state, save_state, Logger, get_labels
from objective import get_objective
from objective_model import get_objective_formulation
from execute_code import execute_and_debug
from rag.rag_utils import RAGMode

class ProgressTracker:
    """Count finished tasks and print a running tally after each one.

    Counter updates and the status line are done while holding a
    ``threading.Lock``, so concurrent callers of :meth:`update` do not
    interleave their increments or output.
    """

    def __init__(self, total_tasks):
        self.total_tasks = total_tasks
        self.completed = self.successful = self.failed = 0
        self.lock = threading.Lock()

    def update(self, success=True):
        """Record one finished task and print the current progress line.

        Args:
            success: Truthy to count the task as successful, falsy to count
                it as failed.
        """
        with self.lock:
            self.completed += 1
            if not success:
                self.failed += 1
            else:
                self.successful += 1
            summary = (
                f"[Progress] {self.completed}/{self.total_tasks} completed "
                f"(Success: {self.successful}, Failed: {self.failed})"
            )
            print(summary)

def find_enhanced_problem_files(enhanced_problems_dir):
    """Collect Stage 1 problem directories that contain an enhanced description.

    Each immediate subdirectory of ``enhanced_problems_dir`` is treated as one
    problem and is included only if it contains
    ``enhanced_problem_description.md``. Entries that are not directories are
    ignored.

    Subdirectories are visited in sorted name order. ``Path.iterdir`` order
    is filesystem-dependent, and the caller truncates this list for
    ``--max_problems``, so sorting keeps the selected subset reproducible.

    Args:
        enhanced_problems_dir: Root directory (str or path-like) of the
            Stage 1 output.

    Returns:
        A list of dicts with keys ``database_name`` (the subdirectory name),
        ``enhanced_file_path`` and ``source_dir`` (both ``str``).
    """
    problem_files = []

    for problem_dir in sorted(Path(enhanced_problems_dir).iterdir(), key=lambda p: p.name):
        if not problem_dir.is_dir():
            continue
        enhanced_file = problem_dir / "enhanced_problem_description.md"
        if enhanced_file.exists():
            problem_files.append({
                'database_name': problem_dir.name,
                'enhanced_file_path': str(enhanced_file),
                'source_dir': str(problem_dir)
            })
            print(f"Found: {problem_dir.name}")
        else:
            print(f"Skip: {problem_dir.name} (no enhanced_problem_description.md)")

    return problem_files

def create_state_from_enhanced_problem_description(enhanced_problem_text, run_dir):
    """Build the initial OptiMUS state for one problem and seed ``run_dir``.

    Parameters are extracted with ``get_params(..., check=False)``, i.e.
    without the interactive confirmation step. An empty ``data.json`` is then
    written into ``run_dir``, which the pipeline expects to find there.

    Args:
        enhanced_problem_text: Full text of the enhanced problem description.
        run_dir: Existing directory (str) that receives ``data.json``.

    Returns:
        A dict with ``"description"`` (the input text) and ``"parameters"``
        (whatever ``get_params`` returned).
    """
    extracted_params = get_params(enhanced_problem_text, check=False)

    data_path = os.path.join(run_dir, "data.json")
    with open(data_path, "w") as data_file:
        json.dump({}, data_file, indent=4)

    return {"description": enhanced_problem_text, "parameters": extracted_params}

# Text written at the start of code_output.txt when a run fails (see the
# except branch of process_single_problem_optimus). A later run uses it to
# tell this failure marker apart from genuine solver output when deciding
# whether a problem can be skipped.
_FAILURE_MARKER_PREFIX = "ERROR: OptiMUS processing failed"


def _is_failure_marker(code_output_file):
    """Return True if ``code_output_file`` starts with this script's failure marker."""
    with open(code_output_file, "r", encoding="utf-8", errors="replace") as f:
        return f.read(len(_FAILURE_MARKER_PREFIX)) == _FAILURE_MARKER_PREFIX


def _run_optimus_pipeline(enhanced_problem_text, output_dir, database_name,
                          model_name, error_correction, logger):
    """Run OptiMUS steps 1-7 for one problem, checkpointing each stage in ``output_dir``; return the final state."""
    # Configuration - matching original OptiMUS setup
    rag_mode = None
    default_labels = {"types": ["Mathematical Optimization"], "domains": ["Operations Management"]}

    # Step 1: Create initial state from enhanced problem description
    state = create_state_from_enhanced_problem_description(enhanced_problem_text, str(output_dir))
    save_state(state, str(output_dir / "state_1_params.json"))

    # Step 2: Extract objective
    print(f"    Extracting objective for {database_name}...")
    # Continue from the checkpoint just written rather than the in-memory dict.
    state = load_state(str(output_dir / "state_1_params.json"))
    state["objective"] = get_objective(
        state["description"],
        state["parameters"],
        check=error_correction,
        logger=logger,
        model=model_name,
        rag_mode=rag_mode,
        labels=default_labels,
    )
    save_state(state, str(output_dir / "state_2_objective.json"))

    # Step 3: Extract constraints
    print(f"    Extracting constraints for {database_name}...")
    state["constraints"] = get_constraints(
        state["description"],
        state["parameters"],
        check=error_correction,
        logger=logger,
        model=model_name,
        rag_mode=rag_mode,
        labels=default_labels,
    )
    save_state(state, str(output_dir / "state_3_constraints.json"))

    # Step 4: Formulate constraints (also yields the decision variables)
    print(f"    Formulating constraints for {database_name}...")
    constraints, variables = get_constraint_formulations(
        state["description"],
        state["parameters"],
        state["constraints"],
        check=error_correction,
        logger=logger,
        model=model_name,
        rag_mode=rag_mode,
        labels=default_labels,
    )
    state["constraints"] = constraints
    state["variables"] = variables
    save_state(state, str(output_dir / "state_4_constraints_modeled.json"))

    # Step 5: Formulate objective
    print(f"    Formulating objective for {database_name}...")
    state["objective"] = get_objective_formulation(
        state["description"],
        state["parameters"],
        state["variables"],
        state["objective"],
        model=model_name,
        check=error_correction,
        rag_mode=rag_mode,
        labels=default_labels,
    )
    save_state(state, str(output_dir / "state_5_objective_modeled.json"))

    # Step 6: Generate code
    print(f"    Generating code for {database_name}...")
    constraints, objective = get_codes(
        state["description"],
        state["parameters"],
        state["variables"],
        state["constraints"],
        state["objective"],
        model=model_name,
        check=error_correction,
    )
    state["constraints"] = constraints
    state["objective"] = objective
    save_state(state, str(output_dir / "state_6_code.json"))

    # Step 7: Execute code
    print(f"    Executing code for {database_name}...")
    generate_code(state, str(output_dir))
    execute_and_debug(state, model=model_name, dir=str(output_dir), logger=logger)

    return state


def process_single_problem_optimus(problem_info_and_config):
    """Solve one enhanced problem with the OptiMUS pipeline (process-pool worker).

    Args:
        problem_info_and_config: Tuple ``(problem_info, output_base_dir,
            model_name, error_correction)``. ``problem_info`` is one dict
            produced by ``find_enhanced_problem_files``.

    Returns:
        A summary dict whose ``"status"`` is ``"success"``, ``"failed"`` or
        ``"skipped"``. Success means ``code_output.txt`` exists after
        execution. Any exception raised while processing is caught and
        reported as ``"failed"``. In that case ``stage2_optimus_error.json``
        and a failure-marker ``code_output.txt`` are written to the output
        directory.

    A problem whose ``code_output.txt`` already exists is skipped. The
    exception is a file holding the failure marker from an earlier run. Then
    the stale markers are removed and the problem is retried, so an old
    failure is neither reported as "skipped" nor mistaken for fresh output by
    the success check. The log file is only reset when the problem actually
    runs, so a skipped problem keeps its previous log.
    """
    problem_info, output_base_dir, model_name, error_correction = problem_info_and_config

    database_name = problem_info['database_name']
    enhanced_file_path = problem_info['enhanced_file_path']

    print(f"\n=== OptiMUS Processing: {database_name} ===")

    # Create output directory
    output_dir = Path(output_base_dir) / database_name
    output_dir.mkdir(parents=True, exist_ok=True)
    code_output_file = output_dir / "code_output.txt"
    error_file = output_dir / "stage2_optimus_error.json"

    try:
        if code_output_file.exists():
            if not _is_failure_marker(code_output_file):
                print(f"  Result already exists for {database_name}, skipping...")
                return {
                    "database_name": database_name,
                    "status": "skipped",
                    "reason": "code_output.txt already exists",
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
                }
            # Leftover marker from a failed run: clear it so the success check
            # below only sees output produced by this attempt.
            print(f"  Previous run failed for {database_name}, retrying...")
            code_output_file.unlink()
            if error_file.exists():
                error_file.unlink()

        # Setup logger (only now, so skipped problems keep their old log)
        logger = Logger(str(output_dir / "optimus_log.txt"))
        logger.reset()

        # Read enhanced problem description
        with open(enhanced_file_path, 'r', encoding='utf-8') as f:
            enhanced_problem_text = f.read()

        print(f"  Loaded enhanced problem description for {database_name} ({len(enhanced_problem_text)} chars)")

        # === OptiMUS Pipeline ===
        start_time = time.time()
        print(f"  Starting OptiMUS pipeline for {database_name}...")
        state = _run_optimus_pipeline(
            enhanced_problem_text, output_dir, database_name,
            model_name, error_correction, logger,
        )
        processing_time = time.time() - start_time

        # Copy enhanced problem description to output for reference
        shutil.copy2(enhanced_file_path, str(output_dir / "enhanced_problem_description.md"))

        # Execution is considered successful if it produced code_output.txt
        execution_success = code_output_file.exists()

        # Save results summary
        stage2_results = {
            "database_name": database_name,
            "status": "success" if execution_success else "failed",
            "execution_success": execution_success,
            "processing_time": processing_time,
            "model_used": model_name,
            "error_correction": error_correction,
            "enhanced_file_source": enhanced_file_path,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "final_state": {
                "parameters_count": len(state.get("parameters", {})),
                "variables_count": len(state.get("variables", {})),
                "constraints_count": len(state.get("constraints", [])),
                "has_objective": bool(state.get("objective", {}).get("formulation"))
            }
        }

        with open(output_dir / "stage2_optimus_results.json", "w", encoding="utf-8") as f:
            json.dump(stage2_results, f, indent=2)

        print(f"  {'SUCCESS' if execution_success else 'FAILED'}: {database_name}")

        return stage2_results

    except Exception as e:
        print(f"  ERROR: {database_name} failed: {e}")

        error_summary = {
            "database_name": database_name,
            "status": "failed",
            "error": str(e),
            "model_used": model_name,
            "error_correction": error_correction,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        with open(error_file, "w", encoding="utf-8") as f:
            json.dump(error_summary, f, indent=2)

        # Write a failure marker to code_output.txt so evaluation scripts see
        # this problem as failed; _is_failure_marker recognises it on rerun.
        with open(code_output_file, "w", encoding="utf-8") as f:
            f.write(f"{_FAILURE_MARKER_PREFIX} - {e}\n")

        return error_summary

def main():
    parser = argparse.ArgumentParser(description="Stage 2: OptiMUS Processing with Enhanced Problems")
    parser.add_argument("--enhanced_problems_dir", type=str, required=True,
                       help="Directory containing enhanced problem descriptions from Stage 1")
    parser.add_argument("--output_dir", type=str, required=True,
                       help="Output directory for OptiMUS results")
    parser.add_argument("--model", type=str, default="deepseek-ai/DeepSeek-V3",
                       help="Model name to use for LLM queries")
    parser.add_argument("--error_correction", action='store_true',
                       help="Enable error correction in OptiMUS pipeline")
    parser.add_argument("--max_problems", type=int, default=None,
                       help="Maximum number of problems to process")
    parser.add_argument("--max_workers", type=int, default=None,
                       help="Maximum number of parallel workers")
    
    args = parser.parse_args()
    
    # Set default number of workers
    if args.max_workers is None:
        args.max_workers = min(mp.cpu_count(), 8)  # Conservative for OptiMUS
    
    print("Stage 2: OptiMUS Processing")
    print("=" * 50)
    print(f"Enhanced problems directory: {args.enhanced_problems_dir}")
    print(f"Output directory: {args.output_dir}")
    print(f"Model: {args.model}")
    print(f"Error correction: {args.error_correction}")
    print(f"Max parallel workers: {args.max_workers}")
    
    # Find enhanced problem files
    problem_files = find_enhanced_problem_files(args.enhanced_problems_dir)
    
    if not problem_files:
        print("ERROR: No enhanced problem descriptions found!")
        sys.exit(1)
    
    print(f"\nFound {len(problem_files)} enhanced problem descriptions")
    
    # Limit for testing
    if args.max_problems:
        problem_files = problem_files[:args.max_problems]
        print(f"Limited to first {args.max_problems} problems for testing")
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Prepare arguments for parallel processing
    worker_args = [
        (problem_info, args.output_dir, args.model, args.error_correction) 
        for problem_info in problem_files
    ]
    
    # Initialize progress tracker
    progress = ProgressTracker(len(problem_files))
    
    # Process problems in parallel
    results = []
    start_time = time.time()
    
    print(f"\nStarting parallel OptiMUS processing with {args.max_workers} workers...")
    
    with ProcessPoolExecutor(max_workers=args.max_workers) as executor:
        # Submit all tasks
        future_to_problem = {
            executor.submit(process_single_problem_optimus, worker_arg): worker_arg[0]['database_name'] 
            for worker_arg in worker_args
        }
        
        # Collect results as they complete
        for future in as_completed(future_to_problem):
            database_name = future_to_problem[future]
            try:
                result = future.result()
                results.append(result)
                progress.update(success=(result["status"] in ["success", "skipped"]))
                
            except Exception as exc:
                print(f'\nERROR: {database_name} generated an exception: {exc}')
                error_result = {
                    "database_name": database_name,
                    "status": "failed",
                    "error": str(exc),
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
                }
                results.append(error_result)
                progress.update(success=False)
    
    # Generate overall summary
    total_time = time.time() - start_time
    successful = [r for r in results if r["status"] == "success"]
    skipped = [r for r in results if r["status"] == "skipped"]
    failed = [r for r in results if r["status"] == "failed"]
    
    # Calculate execution success rate among successful runs
    successful_executions = [r for r in successful if r.get("execution_success", False)]
    
    overall_summary = {
        "run_info": {
            "stage": "stage2_optimus",
            "total_problems": len(problem_files),
            "successful_runs": len(successful),
            "successful_executions": len(successful_executions),
            "skipped": len(skipped),
            "failed": len(failed),
            "execution_success_rate": f"{len(successful_executions)/len(problem_files)*100:.1f}%",
            "total_time": f"{total_time:.1f} seconds",
            "average_time_per_problem": f"{total_time/len(problem_files):.1f} seconds",
            "parallel_workers": args.max_workers,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        },
        "configuration": {
            "model": args.model,
            "error_correction": args.error_correction,
            "enhanced_problems_source": args.enhanced_problems_dir,
            "output_dir": args.output_dir,
            "max_workers": args.max_workers
        },
        "results": results
    }
    
    with open(os.path.join(args.output_dir, "stage2_optimus_summary.json"), "w") as f:
        json.dump(overall_summary, f, indent=2)
    
    # Print final summary
    print("\n" + "=" * 50)
    print("Stage 2 OptiMUS Complete!")
    print(f"Total problems processed: {len(problem_files)}")
    print(f"Successful runs: {len(successful)} ({len(successful)/len(problem_files)*100:.1f}%)")
    print(f"Successful executions: {len(successful_executions)} ({len(successful_executions)/len(problem_files)*100:.1f}%)")
    print(f"Skipped (already exist): {len(skipped)} ({len(skipped)/len(problem_files)*100:.1f}%)")
    print(f"Failed: {len(failed)} ({len(failed)/len(problem_files)*100:.1f}%)")
    print(f"Total time: {total_time:.1f} seconds")
    print(f"Results saved to: {args.output_dir}")
    
    if successful_executions:
        print(f"\nSuccessful executions examples:")
        for result in successful_executions[:5]:  # Show first 5
            params_count = result.get('final_state', {}).get('parameters_count', 0)
            vars_count = result.get('final_state', {}).get('variables_count', 0)
            print(f"  - {result['database_name']}: {params_count} params, {vars_count} variables")
    
    if failed:
        print(f"\nFailed examples:")
        for result in failed[:3]:  # Show first 3
            print(f"  - {result['database_name']}: {result.get('error', 'Unknown error')}")

if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0
    # unless main() itself calls sys.exit().
    sys.exit(main())