#!/usr/bin/env python3
"""
Stage 2: Chain-of-Experts Processing with Enhanced Problem Descriptions
Reads enhanced problem descriptions and solves them using Chain-of-Experts pipeline
"""

import os
import sys
import json
import time
import argparse
from pathlib import Path
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
import threading
import shutil

# Import Chain-of-Experts modules
from main import chain_of_experts
from utils import extract_code_from_string

class ProgressTracker:
    """Running tally of finished tasks; every update is serialized by a lock."""

    def __init__(self, total_tasks):
        # Expected number of tasks; only used for the "done/total" display.
        self.total_tasks = total_tasks
        self.completed = self.successful = self.failed = 0
        self.lock = threading.Lock()

    def update(self, success=True):
        """Record one finished task and print the refreshed counters.

        Args:
            success: Truthy when the task counts as successful, falsy otherwise.
        """
        succeeded = bool(success)
        with self.lock:
            self.completed += 1
            self.successful += int(succeeded)
            self.failed += int(not succeeded)

            # Printed while holding the lock so the counters in one line are consistent.
            status_line = (
                f"[Progress] {self.completed}/{self.total_tasks} completed "
                f"(Success: {self.successful}, Failed: {self.failed})"
            )
            print(status_line)

def extract_natural_language_problem(enhanced_problem_path):
    """
    Build the plain-language problem text from an enhanced problem file.

    Everything before the first line mentioning "### Retrieved Values" is kept, except the
    "### Database Schema" / "### Data Dictionary" sections. The contents of any ```csv
    fenced blocks found after that marker are then appended, each under a "**Data:**"
    header; anything else after the marker (such as SQL queries) is dropped.
    """
    retrieved_marker = '### Retrieved Values'
    technical_headings = ('### Database Schema', '### Data Dictionary')

    with open(enhanced_problem_path, 'r', encoding='utf-8') as handle:
        text = handle.read()

    # Pass 1: narrative part of the document, with technical sections filtered out.
    narrative = []
    inside_technical = False
    for row in text.split('\n'):
        if retrieved_marker in row:
            break
        if row.startswith(technical_headings):
            inside_technical = True
            continue
        if inside_technical and row.startswith('### '):
            # A new heading ends the skipped section unless it is itself
            # about the database or schema.
            lowered = row.lower()
            if not ('database' in lowered or 'schema' in lowered):
                inside_technical = False
        if inside_technical:
            continue
        narrative.append(row)

    description = '\n'.join(narrative).strip()

    # Pass 2: CSV payloads from the text that follows the first marker.
    pieces = text.split(retrieved_marker)
    if len(pieces) < 2:
        return description

    csv_rows = []
    within_csv = False
    for row in pieces[1].split('\n'):
        if '```csv' in row:
            within_csv = True
            csv_rows.append('\n**Data:**')
        elif within_csv:
            if '```' in row:
                # Closing fence of the CSV block.
                within_csv = False
                csv_rows.append('')
            else:
                csv_rows.append(row)

    if csv_rows:
        description += '\n\n' + '\n'.join(csv_rows)

    return description

def find_enhanced_problem_files(enhanced_problems_dir):
    """Find all enhanced problem description files.

    Scans the immediate subdirectories of ``enhanced_problems_dir`` for a file named
    ``enhanced_problem_description.md`` and prints a "Found"/"Skip" line per subdirectory.

    Args:
        enhanced_problems_dir: Directory whose subdirectories each represent one problem.

    Returns:
        A list of dicts with keys ``'database_name'`` (the subdirectory name) and
        ``'enhanced_file_path'`` (path to the markdown file, as a string), ordered by
        subdirectory path.
    """
    problem_files = []

    # Path.iterdir() yields entries in arbitrary order; sorting keeps the result
    # (and the "first N" selection done for --max_problems) reproducible across runs.
    for problem_dir in sorted(Path(enhanced_problems_dir).iterdir()):
        if not problem_dir.is_dir():
            continue

        enhanced_file = problem_dir / "enhanced_problem_description.md"
        if enhanced_file.exists():
            problem_files.append({
                'database_name': problem_dir.name,
                'enhanced_file_path': str(enhanced_file)
            })
            print(f"Found: {problem_dir.name}")
        else:
            print(f"Skip: {problem_dir.name} (no enhanced_problem_description.md)")

    return problem_files

def create_dummy_code_example():
    """Return the placeholder function stub used as the code example for Chain-of-Experts."""
    stub_lines = (
        "def solve_optimization_problem():",
        "    # Implementation to be generated by Chain-of-Experts",
        "    # This function should solve the optimization problem",
        "    # and return the optimal objective value",
        "    pass",
    )
    # No trailing newline: the stub ends right after "pass".
    return "\n".join(stub_lines)

# Regex fragment for one decimal number, including an optional sign and exponent
# (Python prints very large/small floats in scientific notation, e.g. "1e-05").
_NUMBER_PATTERN = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'

# Output labels searched for the objective value, in priority order.
_OPTIMAL_VALUE_LABELS = (
    'Optimal Objective Value',
    'objVal',
    'Optimal Value',
    'Objective Value',
    'Result',
)


def _build_execution_script(generated_code):
    """Return the source of the wrapper script that runs ``generated_code``.

    The script first runs the generated code at module level, then executes it again in a
    fresh namespace, looks for a callable to invoke and prints "Optimal Objective Value: ..."
    lines for the result. The second copy is embedded with ``repr`` so that code containing
    triple quotes or backslash escapes still forms a valid string literal.
    """
    return f"""
import sys
import os
import traceback

# Set up paths
sys.path.append('.')

# Generated code
{generated_code}

# Try to find and execute the main function
try:
    import re
    import types
    
    # Execute the code to define functions
    exec_globals = {{}}
    exec({generated_code!r}, exec_globals)
    
    # Find all function names in the executed code
    functions = [name for name, obj in exec_globals.items() 
                if callable(obj) and not name.startswith('_') and name not in ['GRB', 'quicksum', 'Model']]
    
    # Look for optimization-related functions
    target_functions = []
    for func_name in functions:
        if any(keyword in func_name.lower() for keyword in 
               ['solve', 'optim', 'coffee', 'allergy', 'dorm', 'assign', 'cinema', 'candidate']):
            target_functions.append(func_name)
    
    if target_functions:
        # Call the first matching function
        func_name = target_functions[0]
        result = exec_globals[func_name]()
        
        # Handle Gurobi model object
        if hasattr(result, 'objVal') and hasattr(result, 'optimize'):
            # This is a Gurobi model that needs optimization
            result.optimize()
            if result.status == 2:  # OPTIMAL
                print(f"Optimal Objective Value: {{result.objVal}}")
                print(f"Optimal Objective Value: {{result.objVal}}")  # Double print like baseline
            else:
                print(f"Model status: {{result.status}} (not optimal)")
        elif hasattr(result, 'objVal'):
            # Already optimized model
            print(f"Optimal Objective Value: {{result.objVal}}")
            print(f"Optimal Objective Value: {{result.objVal}}")  # Double print like baseline
        elif isinstance(result, (int, float)):
            print(f"Optimal Objective Value: {{result}}")
            print(f"Optimal Objective Value: {{result}}")  # Double print like baseline
        else:
            print(f"Unexpected result type: {{type(result)}}")
            
    elif functions:
        # Try calling any available function
        func_name = functions[0]
        result = exec_globals[func_name]()
        
        if hasattr(result, 'objVal') and hasattr(result, 'optimize'):
            result.optimize()
            if result.status == 2:  # OPTIMAL
                print(f"Optimal Objective Value: {{result.objVal}}")
                print(f"Optimal Objective Value: {{result.objVal}}")
            else:
                print(f"Model status: {{result.status}}")
        elif hasattr(result, 'objVal'):
            print(f"Optimal Objective Value: {{result.objVal}}")
            print(f"Optimal Objective Value: {{result.objVal}}")
        elif isinstance(result, (int, float)):
            print(f"Optimal Objective Value: {{result}}")
            print(f"Optimal Objective Value: {{result}}")
        else:
            print(f"Function result: {{result}}")
    else:
        print("No suitable functions found in the code")
        
except Exception as e:
    print(f"Execution error: {{str(e)}}")
    traceback.print_exc()
"""


def _parse_optimal_value(execution_output):
    """Return ``(value, pattern)`` for the first labelled number in the output, else ``(None, None)``."""
    import re

    for label in _OPTIMAL_VALUE_LABELS:
        pattern = label + r':\s*(' + _NUMBER_PATTERN + r')'
        match = re.search(pattern, execution_output)
        if match:
            return float(match.group(1)), pattern
    return None, None


def process_single_problem_coe(problem_info_and_config):
    """Worker function for processing a single problem with Chain-of-Experts.

    Extracts the problem text, runs ``chain_of_experts``, saves the answer and the extracted
    code, executes the code in a separate Python interpreter and records the outcome. All
    artifacts are written below ``<output_base_dir>/<database_name>/``; a detailed log goes
    to ``coe_detailed_log.txt`` there. If ``code_output.txt`` already exists in that
    directory, the problem is skipped.

    Args:
        problem_info_and_config: 6-tuple ``(problem_info, output_base_dir, model_name,
            enable_reflection, max_collaborate_nums, max_trials)``, where ``problem_info``
            is a dict with ``'database_name'`` and ``'enhanced_file_path'`` keys.

    Returns:
        A dict containing at least ``database_name``, ``status`` (``"success"``,
        ``"failed"`` or ``"skipped"``) and ``timestamp``. Successful and failed executions
        also carry ``execution_success``, ``optimal_value`` and ``final_state``; pipeline
        errors carry ``error`` instead.
    """
    # Function-scope imports, bound before any try block so that the
    # `except subprocess.TimeoutExpired` clause below can always be evaluated.
    import re
    import subprocess

    (problem_info, output_base_dir, model_name, enable_reflection,
     max_collaborate_nums, max_trials) = problem_info_and_config

    database_name = problem_info['database_name']
    enhanced_file_path = problem_info['enhanced_file_path']

    print(f"\n=== Chain-of-Experts Processing: {database_name} ===")

    # Create output directory
    output_dir = Path(output_base_dir) / database_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Initialize detailed log for this database
    log_file = output_dir / "coe_detailed_log.txt"

    def log_message(message, level="INFO"):
        """Log message to both console and file"""
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        log_entry = f"[{timestamp}] [{level}] {message}"
        print(f"  {message}")
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(log_entry + "\n")

    # Start logging
    log_message(f"Starting Chain-of-Experts processing for database: {database_name}")
    log_message(f"Enhanced file path: {enhanced_file_path}")
    log_message(f"Model: {model_name}")
    log_message(f"Reflection enabled: {enable_reflection}")
    log_message(f"Max collaborate nums: {max_collaborate_nums}")
    log_message(f"Max trials: {max_trials}")

    try:
        # Check if result already exists
        code_output_file = output_dir / "code_output.txt"
        if code_output_file.exists():
            log_message(f"Result already exists for {database_name}, skipping processing", "WARNING")
            return {
                "database_name": database_name,
                "status": "skipped",
                "reason": "code_output.txt already exists",
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
            }

        # Extract natural language problem description
        log_message("Extracting natural language problem description...")
        problem_description = extract_natural_language_problem(enhanced_file_path)

        # Save extracted problem for debugging
        with open(output_dir / "extracted_problem.txt", "w", encoding="utf-8") as f:
            f.write(problem_description)

        log_message(f"Problem description extracted: {len(problem_description)} characters")
        log_message(f"Problem preview: {problem_description[:200]}...")

        # Prepare problem data for Chain-of-Experts
        problem_data = {
            'description': problem_description,
            'code_example': create_dummy_code_example()
        }

        # Run Chain-of-Experts
        start_time = time.time()
        log_message("Starting Chain-of-Experts pipeline...")

        try:
            coe_answer = chain_of_experts(
                problem=problem_data,
                max_collaborate_nums=max_collaborate_nums,
                model_name=model_name,
                enable_reflection=enable_reflection,
                max_trials=max_trials
            )

            processing_time = time.time() - start_time
            log_message(f"Chain-of-Experts pipeline completed in {processing_time:.2f} seconds")
            log_message(f"Generated answer length: {len(coe_answer)} characters")

        except Exception as e:
            processing_time = time.time() - start_time
            log_message(f"Chain-of-Experts pipeline failed after {processing_time:.2f} seconds: {str(e)}", "ERROR")
            # Bare raise keeps the original traceback; the outer handler records the failure.
            raise

        # Extract generated code
        log_message("Extracting Python code from Chain-of-Experts answer...")
        # Normalize a None/empty result to "" so the writes and len() calls below cannot fail.
        generated_code = extract_code_from_string(coe_answer) or ""

        if generated_code:
            log_message(f"Successfully extracted code: {len(generated_code)} characters")
            log_message(f"Code preview: {generated_code[:150]}...")
        else:
            log_message("No valid Python code found in Chain-of-Experts answer", "WARNING")

        # Save Chain-of-Experts full answer
        with open(output_dir / "coe_full_answer.txt", "w", encoding="utf-8") as f:
            f.write(coe_answer)
        log_message("Saved full Chain-of-Experts answer to coe_full_answer.txt")

        # Save generated code as Python file
        with open(output_dir / "generated_code.py", "w", encoding="utf-8") as f:
            f.write(generated_code)
        log_message("Saved generated code to generated_code.py")

        # Copy enhanced problem description to output for reference
        shutil.copy2(enhanced_file_path, str(output_dir / "enhanced_problem_description.md"))
        log_message("Copied enhanced problem description for reference")

        # Execute generated code to get optimization result
        execution_success = False
        optimal_value = None
        execution_output = ""

        # Very short snippets are treated as "no usable code" and never executed.
        if generated_code and len(generated_code.strip()) > 50:
            log_message("Generated code appears valid, attempting execution...")

            try:
                # Use Gurobi environment for execution (site-specific interpreter path)
                gurobi_python = "/dccstor/nl2opt/miniforge3/envs/nl2opt_optim/bin/python"
                log_message(f"Using Gurobi Python environment: {gurobi_python}")

                # NOTE(security): the script below runs model-generated code with this
                # user's privileges; only the subprocess boundary and timeout contain it.
                exec_script = _build_execution_script(generated_code)

                # Write execution script (UTF-8, matching generated_code.py)
                exec_script_path = output_dir / "execute_code.py"
                with open(exec_script_path, "w", encoding="utf-8") as f:
                    f.write(exec_script)
                log_message("Created execution script: execute_code.py")

                # Execute with timeout
                log_message("Starting code execution with 120 second timeout...")
                exec_start_time = time.time()

                result = subprocess.run(
                    [gurobi_python, str(exec_script_path)],
                    capture_output=True,
                    text=True,
                    timeout=120,  # 2 minutes timeout
                    cwd=os.getcwd()  # run in the current working directory
                )

                exec_duration = time.time() - exec_start_time
                execution_output = result.stdout + result.stderr

                log_message(f"Code execution completed in {exec_duration:.2f} seconds")
                log_message(f"Return code: {result.returncode}")

                if result.stdout:
                    log_message(f"STDOUT: {result.stdout[:300]}...")
                if result.stderr:
                    log_message(f"STDERR: {result.stderr[:300]}...", "WARNING")

                # Parse optimal value from output
                if result.returncode == 0:
                    execution_success = True

                    optimal_value, matched_pattern = _parse_optimal_value(execution_output)
                    if optimal_value is not None:
                        log_message(
                            f"Successfully extracted optimal value using pattern "
                            f"'{matched_pattern}': {optimal_value}"
                        )
                    else:
                        log_message("Execution successful but no clear optimal value found", "WARNING")
                        log_message(f"Full execution output for debugging: {execution_output}")

                        # Try to find any numeric values as potential results
                        numeric_matches = re.findall(_NUMBER_PATTERN, execution_output)
                        if numeric_matches:
                            log_message(f"Found numeric values in output: {numeric_matches[:5]}")  # Show first 5
                else:
                    log_message(f"Execution failed with return code {result.returncode}", "ERROR")

            except subprocess.TimeoutExpired:
                execution_output = "Execution timeout (>120s)"
                log_message("Code execution timed out after 120 seconds", "ERROR")

            except Exception as e:
                execution_output = f"Execution setup error: {str(e)}"
                log_message(f"Execution setup error: {str(e)}", "ERROR")

        else:
            execution_output = "No valid code generated"
            log_message("No valid code generated - skipping execution", "WARNING")

        # Save results in the same format as other baselines
        log_message("Preparing results summary...")
        stage2_results = {
            "database_name": database_name,
            "status": "success" if execution_success else "failed",
            "execution_success": execution_success,
            "optimal_value": optimal_value,
            "processing_time": processing_time,
            "model_used": model_name,
            "enable_reflection": enable_reflection,
            "max_collaborate_nums": max_collaborate_nums,
            "max_trials": max_trials,
            "enhanced_file_source": enhanced_file_path,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "final_state": {
                "coe_answer_length": len(coe_answer),
                "generated_code_length": len(generated_code),
                "has_generated_code": bool(generated_code.strip()),
                "execution_output": execution_output[:500]  # First 500 chars
            }
        }

        with open(output_dir / "stage2_coe_results.json", "w", encoding="utf-8") as f:
            json.dump(stage2_results, f, indent=2)
        log_message("Saved detailed results to stage2_coe_results.json")

        # Save code output in the expected format for evaluation (matching other baselines)
        log_message("Creating standardized output file...")
        with open(code_output_file, "w", encoding="utf-8") as f:
            if execution_success:
                # Raw subprocess output (solver log and value lines), with or without
                # a parsed optimal value.
                f.write(execution_output)
            else:
                f.write("ERROR: Chain-of-Experts execution failed\n")
                f.write(f"Model: {model_name}\n")
                f.write(f"Processing time: {processing_time:.2f} seconds\n")
                f.write(f"Execution output: {execution_output}\n")

        # Save generated code to separate file (as requested)
        generated_code_detailed_file = output_dir / "generated_code_detailed.py"
        with open(generated_code_detailed_file, "w", encoding="utf-8") as f:
            f.write("# Chain-of-Experts Generated Code\n")
            f.write("# Processing completed successfully\n")
            f.write(f"# Model: {model_name}\n")
            f.write(f"# Processing time: {processing_time:.2f} seconds\n")
            f.write(f"# Database: {database_name}\n\n")
            f.write(generated_code)

        log_message("Saved generated code details to generated_code_detailed.py")

        # Final status log
        final_status = "SUCCESS" if execution_success else "FAILED"
        final_message = f"Processing completed: {final_status}"
        if optimal_value is not None:
            final_message += f" - Optimal Value: {optimal_value}"
        log_message(final_message)
        log_message(f"All output files saved to: {output_dir}")

        print(f"  {final_status}: {database_name} - Optimal Value: {optimal_value}")

        return stage2_results

    except Exception as e:
        log_message(f"Critical error during processing: {str(e)}", "ERROR")
        print(f"  ERROR: {database_name} failed: {e}")

        error_summary = {
            "database_name": database_name,
            "status": "failed",
            "error": str(e),
            "model_used": model_name,
            "enable_reflection": enable_reflection,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        with open(output_dir / "stage2_coe_error.json", "w", encoding="utf-8") as f:
            json.dump(error_summary, f, indent=2)

        # Write an error marker to code_output.txt so evaluation scripts see the failure
        with open(output_dir / "code_output.txt", "w", encoding="utf-8") as f:
            f.write(f"ERROR: Chain-of-Experts processing failed - {str(e)}\n")

        log_message("Error summary saved to stage2_coe_error.json", "ERROR")
        return error_summary

def main():
    """Command-line entry point.

    Parses the CLI options, discovers the enhanced problem files, processes them in a
    process pool with ``process_single_problem_coe`` and writes ``stage2_coe_summary.json``
    into the output directory before printing a final report. Exits with status 1 when no
    problem files are found.
    """
    cli = argparse.ArgumentParser(description="Stage 2: Chain-of-Experts Processing with Enhanced Problems")
    cli.add_argument("--enhanced_problems_dir", type=str, required=True,
                     help="Directory containing enhanced problem descriptions from Stage 1")
    cli.add_argument("--output_dir", type=str, required=True,
                     help="Output directory for Chain-of-Experts results")
    cli.add_argument("--model", type=str, default="deepseek-ai/DeepSeek-V3",
                     help="Model name to use for LLM queries")
    cli.add_argument("--enable_reflection", action='store_true',
                     help="Enable backward reflection mechanism in Chain-of-Experts")
    cli.add_argument("--max_collaborate_nums", type=int, default=3,
                     help="Maximum number of expert collaborations per trial")
    cli.add_argument("--max_trials", type=int, default=1,
                     help="Maximum number of forward-backward trials")
    cli.add_argument("--max_problems", type=int, default=None,
                     help="Maximum number of problems to process")
    cli.add_argument("--max_workers", type=int, default=None,
                     help="Maximum number of parallel workers")

    args = cli.parse_args()

    # Default pool size: at most 4 workers, fewer on smaller machines.
    worker_count = args.max_workers
    if worker_count is None:
        worker_count = min(mp.cpu_count(), 4)

    banner = (
        "Stage 2: Chain-of-Experts Processing",
        "=" * 50,
        f"Enhanced problems directory: {args.enhanced_problems_dir}",
        f"Output directory: {args.output_dir}",
        f"Model: {args.model}",
        f"Enable reflection: {args.enable_reflection}",
        f"Max collaborate nums: {args.max_collaborate_nums}",
        f"Max trials: {args.max_trials}",
        f"Max parallel workers: {worker_count}",
    )
    for banner_line in banner:
        print(banner_line)

    # Discover inputs
    problem_files = find_enhanced_problem_files(args.enhanced_problems_dir)
    if not problem_files:
        print("ERROR: No enhanced problem descriptions found!")
        sys.exit(1)

    print(f"\nFound {len(problem_files)} enhanced problem descriptions")

    # Optional cap, intended for test runs
    if args.max_problems:
        problem_files = problem_files[:args.max_problems]
        print(f"Limited to first {args.max_problems} problems for testing")

    os.makedirs(args.output_dir, exist_ok=True)

    # One argument tuple per problem, in the layout the worker unpacks.
    payloads = []
    for problem_info in problem_files:
        payloads.append((
            problem_info,
            args.output_dir,
            args.model,
            args.enable_reflection,
            args.max_collaborate_nums,
            args.max_trials,
        ))

    progress = ProgressTracker(len(problem_files))
    results = []
    started_at = time.time()

    print(f"\nStarting parallel Chain-of-Experts processing with {worker_count} workers...")

    with ProcessPoolExecutor(max_workers=worker_count) as pool:
        # Map each submitted future back to its database name for error reporting.
        pending = {}
        for payload in payloads:
            pending[pool.submit(process_single_problem_coe, payload)] = payload[0]['database_name']

        for finished in as_completed(pending):
            database_name = pending[finished]
            try:
                outcome = finished.result()
                results.append(outcome)
                progress.update(success=(outcome["status"] in ["success", "skipped"]))

            except Exception as exc:
                print(f'\nERROR: {database_name} generated an exception: {exc}')
                results.append({
                    "database_name": database_name,
                    "status": "failed",
                    "error": str(exc),
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
                })
                progress.update(success=False)

    # Overall summary (matching other baselines format)
    total_time = time.time() - started_at
    total = len(problem_files)

    by_status = {"success": [], "skipped": [], "failed": []}
    for entry in results:
        bucket = by_status.get(entry["status"])
        if bucket is not None:
            bucket.append(entry)
    successful = by_status["success"]
    skipped = by_status["skipped"]
    failed = by_status["failed"]

    # Successful runs whose generated code also executed successfully
    successful_executions = [entry for entry in successful if entry.get("execution_success", False)]

    def share(count):
        """Format ``count`` as a percentage of all processed problems."""
        return f"{count/total*100:.1f}%"

    run_info = {
        "stage": "stage2_chain_of_experts",
        "total_problems": total,
        "successful_runs": len(successful),
        "successful_executions": len(successful_executions),
        "skipped": len(skipped),
        "failed": len(failed),
        "execution_success_rate": share(len(successful_executions)),
        "total_time": f"{total_time:.1f} seconds",
        "average_time_per_problem": f"{total_time/total:.1f} seconds",
        "parallel_workers": worker_count,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    }
    configuration = {
        "model": args.model,
        "enable_reflection": args.enable_reflection,
        "max_collaborate_nums": args.max_collaborate_nums,
        "max_trials": args.max_trials,
        "enhanced_problems_source": args.enhanced_problems_dir,
        "output_dir": args.output_dir,
        "max_workers": worker_count
    }
    overall_summary = {
        "run_info": run_info,
        "configuration": configuration,
        "results": results
    }

    summary_path = os.path.join(args.output_dir, "stage2_coe_summary.json")
    with open(summary_path, "w") as f:
        json.dump(overall_summary, f, indent=2)

    # Final report (matching other baselines format)
    print("\n" + "=" * 50)
    print("Stage 2 Chain-of-Experts Complete!")
    print(f"Total problems processed: {total}")
    print(f"Successful runs: {len(successful)} ({share(len(successful))})")
    print(f"Successful executions: {len(successful_executions)} ({share(len(successful_executions))})")
    print(f"Skipped (already exist): {len(skipped)} ({share(len(skipped))})")
    print(f"Failed: {len(failed)} ({share(len(failed))})")
    print(f"Total time: {total_time:.1f} seconds")
    print(f"Results saved to: {args.output_dir}")

    if successful_executions:
        print(f"\nSuccessful executions examples:")
        for entry in successful_executions[:5]:  # Show first 5
            final_state = entry.get('final_state', {})
            code_length = final_state.get('generated_code_length', 0)
            answer_length = final_state.get('coe_answer_length', 0)
            print(f"  - {entry['database_name']}: {code_length} chars code, {answer_length} chars answer")

    if failed:
        print(f"\nFailed examples:")
        for entry in failed[:3]:  # Show first 3
            print(f"  - {entry['database_name']}: {entry.get('error', 'Unknown error')}")

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()