#!/usr/bin/env python3
"""
Stage 2: Simple Zero-Shot Gurobi Code Generation
Reads enhanced problem descriptions and generates Gurobi code via zero-shot prompting
"""

import os
import sys
import json
import time
import argparse
from pathlib import Path
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
import threading

# Import the core solver
from simple_zero_shot_solver import SimpleZeroShotSolver

class ProgressTracker:
    """Lock-guarded tally of finished tasks that prints a status line per update.

    Attributes:
        total_tasks: number of tasks the run expects (only used for display).
        completed: number of tasks reported so far (successful + failed).
        successful: number of tasks reported with a truthy ``success``.
        failed: number of tasks reported with a falsy ``success``.
        lock: ``threading.Lock`` that serializes counter updates and the print.
    """

    def __init__(self, total_tasks):
        self.total_tasks = total_tasks
        # Chained assignment binds left to right: completed, successful, failed.
        self.completed = self.successful = self.failed = 0
        self.lock = threading.Lock()

    def update(self, success=True):
        """Record one finished task and print the running totals.

        Args:
            success: truthy if the task should be counted as successful.
        """
        with self.lock:
            self.completed += 1
            if not success:
                self.failed += 1
            else:
                self.successful += 1

            status_line = (f"[Progress] {self.completed}/{self.total_tasks} completed "
                           f"(Success: {self.successful}, Failed: {self.failed})")
            print(status_line)

def find_enhanced_problem_files(enhanced_problems_dir):
    """Find all enhanced problem description files produced by Stage 1.

    Each immediate subdirectory of *enhanced_problems_dir* is treated as one
    problem and is included when it contains a regular file named
    ``enhanced_problem_description.md``. Subdirectories are visited in sorted
    order because ``Path.iterdir`` yields entries in arbitrary order; sorting
    makes the returned list deterministic. It also makes the subset chosen by a
    later ``[:max_problems]`` slice deterministic.

    Args:
        enhanced_problems_dir: path (str or Path) of the Stage 1 output directory.

    Returns:
        list of dicts with keys ``'database_name'`` (the subdirectory name) and
        ``'enhanced_file_path'`` (str path to the description file).

    Raises:
        FileNotFoundError / NotADirectoryError: if *enhanced_problems_dir* is
            missing or is not a directory (raised by ``Path.iterdir``).
    """
    problem_files = []

    for problem_dir in sorted(Path(enhanced_problems_dir).iterdir()):
        if not problem_dir.is_dir():
            continue
        enhanced_file = problem_dir / "enhanced_problem_description.md"
        # is_file() rather than exists(): a directory with this name cannot be
        # read by the worker and would only fail later.
        if enhanced_file.is_file():
            problem_files.append({
                'database_name': problem_dir.name,
                'enhanced_file_path': str(enhanced_file)
            })
            print(f"Found: {problem_dir.name}")
        else:
            print(f"Skip: {problem_dir.name} (no enhanced_problem_description.md)")

    return problem_files

def process_single_problem_zero_shot(problem_info_and_config):
    """Worker function for processing a single problem with Simple Zero-Shot.

    Designed to be submitted to a ProcessPoolExecutor, so it takes a single
    tuple argument. Exceptions raised while processing are caught and reported
    through the return value; only the output-directory creation, which runs
    before the ``try`` block, and writing the error files themselves can still
    raise.

    Args:
        problem_info_and_config: tuple ``(problem_info, output_base_dir,
            model_name, temperature)``. ``problem_info`` is a dict with keys
            ``'database_name'`` and ``'enhanced_file_path'``, as built by
            ``find_enhanced_problem_files``.

    Files written under ``<output_base_dir>/<database_name>/``:
        * ``generated_gurobi_code.py`` / ``gurobi_execution_output.txt`` when
          the solver result contains non-empty values for them;
        * ``stage2_zero_shot_results.json`` after a completed solver call;
        * ``stage2_zero_shot_error.json`` when an exception is caught;
        * ``code_output.txt`` in every non-skipped case. Because an existing
          ``code_output.txt`` causes the problem to be skipped, failed problems
          are also skipped on a rerun.
        ``zero_shot_log.txt`` is passed to the solver as its log file.

    Returns:
        dict whose ``"status"`` is ``"success"``, ``"failed"`` or ``"skipped"``.
    """
    problem_info, output_base_dir, model_name, temperature = problem_info_and_config

    database_name = problem_info['database_name']
    enhanced_file_path = problem_info['enhanced_file_path']

    print(f"\n=== Simple Zero-Shot Processing: {database_name} ===")

    # Create output directory
    output_dir = Path(output_base_dir) / database_name
    output_dir.mkdir(parents=True, exist_ok=True)

    log_file = output_dir / "zero_shot_log.txt"
    # Defined outside the try so the error handler writes to the same path.
    code_output_file = output_dir / "code_output.txt"

    try:
        # Resume support: an existing code_output.txt means this problem was already handled.
        if code_output_file.exists():
            print(f"  Result already exists for {database_name}, skipping...")
            return {
                "database_name": database_name,
                "status": "skipped",
                "reason": "code_output.txt already exists",
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
            }

        # Read enhanced problem description
        with open(enhanced_file_path, 'r', encoding='utf-8') as f:
            enhanced_problem_text = f.read()

        print(f"  Loaded enhanced problem description for {database_name} ({len(enhanced_problem_text)} chars)")

        solver = SimpleZeroShotSolver(
            model_name=model_name,
            temperature=temperature,
            log_file=str(log_file)
        )

        start_time = time.time()
        result = solver.solve_problem(enhanced_problem_text)
        processing_time = time.time() - start_time

        # A result without a 'success' key is treated as a failed solve rather
        # than aborting the save step with a KeyError.
        solve_success = bool(result.get('success'))
        optimal_value = result.get('optimal_value')

        if result.get('generated_code'):
            with open(output_dir / "generated_gurobi_code.py", "w", encoding="utf-8") as f:
                f.write(result['generated_code'])

        if result.get('execution_output'):
            with open(output_dir / "gurobi_execution_output.txt", "w", encoding="utf-8") as f:
                f.write(result['execution_output'])

        stage2_results = {
            "database_name": database_name,
            "status": "success" if solve_success else "failed",
            "solve_success": solve_success,
            "optimal_value": optimal_value,
            "processing_time": processing_time,
            "model_used": model_name,
            "temperature": temperature,
            "enhanced_file_source": enhanced_file_path,
            "error_message": result.get('error_message'),
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        with open(output_dir / "stage2_zero_shot_results.json", "w", encoding="utf-8") as f:
            json.dump(stage2_results, f, indent=2)

        # Save code output in the expected format for accuracy evaluation
        with open(code_output_file, "w", encoding="utf-8") as f:
            if solve_success and optimal_value is not None:
                f.write(f"Optimal Objective Value: {optimal_value}\n")
                if result.get('execution_output'):
                    f.write("\n" + "="*50 + "\n")
                    f.write("Full Gurobi Output:\n")
                    f.write("="*50 + "\n")
                    f.write(result['execution_output'])
            else:
                f.write("ERROR: Zero-shot optimization failed.\n")
                if result.get('error_message'):
                    f.write(f"Error: {result['error_message']}\n")
                f.write(f"Model: {model_name}\n")
                f.write(f"Processing time: {processing_time:.2f} seconds\n")

        # Show 'N/A' when the key exists but holds None (dict.get's default would not).
        shown_value = optimal_value if optimal_value is not None else 'N/A'
        print(f"  {'SUCCESS' if solve_success else 'FAILED'}: {database_name} - Result: {shown_value}")

        return stage2_results

    except Exception as e:
        print(f"  ERROR: {database_name} failed: {e}")

        error_summary = {
            "database_name": database_name,
            "status": "failed",
            "error": str(e),
            "model_used": model_name,
            "temperature": temperature,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        with open(output_dir / "stage2_zero_shot_error.json", "w", encoding="utf-8") as f:
            json.dump(error_summary, f, indent=2)

        # Error marker in the evaluation file (this also makes reruns skip the problem).
        with open(code_output_file, "w", encoding="utf-8") as f:
            f.write(f"ERROR: {e}\n")

        return error_summary

def main():
    """Command-line entry point for Stage 2 Simple Zero-Shot code generation.

    Discovers enhanced problem descriptions and processes them in parallel with
    ``process_single_problem_zero_shot`` in a ProcessPoolExecutor. It then
    writes ``stage2_simple_zero_shot_summary.json`` into ``--output_dir`` and
    prints a summary.

    Exit behavior:
        * status 1 if no problem descriptions are found;
        * argparse usage error (status 2) for ``--max_workers`` < 1 or a
          negative ``--max_problems``.
    """
    parser = argparse.ArgumentParser(description="Stage 2: Simple Zero-Shot Gurobi Code Generation")
    parser.add_argument("--enhanced_problems_dir", type=str, required=True,
                        help="Directory containing enhanced problem descriptions from Stage 1")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output directory for Simple Zero-Shot results")
    parser.add_argument("--model", type=str, default="DeepSeek-V3",
                        help="Model name to use for LLM queries")
    parser.add_argument("--temperature", type=float, default=0.1,
                        help="Temperature for LLM generation")
    parser.add_argument("--max_problems", type=int, default=None,
                        help="Maximum number of problems to process")
    parser.add_argument("--max_workers", type=int, default=None,
                        help="Maximum number of parallel workers")

    args = parser.parse_args()

    # Fail fast with a usage message instead of a traceback later on:
    # ProcessPoolExecutor raises ValueError for max_workers <= 0. A negative
    # max_problems would slice off the *end* of the list, possibly emptying it,
    # which leads to division by zero in the summary.
    if args.max_workers is not None and args.max_workers < 1:
        parser.error("--max_workers must be at least 1")
    if args.max_problems is not None and args.max_problems < 0:
        parser.error("--max_problems must be non-negative")

    # Default: one worker per CPU, capped at 8.
    if args.max_workers is None:
        args.max_workers = min(mp.cpu_count(), 8)

    print("Stage 2: Simple Zero-Shot Gurobi Code Generation")
    print("=" * 60)
    print(f"Enhanced problems directory: {args.enhanced_problems_dir}")
    print(f"Output directory: {args.output_dir}")
    print(f"Model: {args.model}")
    print(f"Temperature: {args.temperature}")
    print(f"Max parallel workers: {args.max_workers}")

    problem_files = find_enhanced_problem_files(args.enhanced_problems_dir)

    if not problem_files:
        print("ERROR: No enhanced problem descriptions found!")
        sys.exit(1)

    print(f"\nFound {len(problem_files)} enhanced problem descriptions")

    # Optional cap for test runs; 0 behaves like omitting the flag (no limit).
    if args.max_problems:
        problem_files = problem_files[:args.max_problems]
        print(f"Limited to first {args.max_problems} problems for testing")

    os.makedirs(args.output_dir, exist_ok=True)

    # Each worker receives one picklable tuple (see process_single_problem_zero_shot).
    worker_args = [
        (problem_info, args.output_dir, args.model, args.temperature)
        for problem_info in problem_files
    ]

    total = len(problem_files)  # > 0: the empty case exited above
    progress = ProgressTracker(total)

    results = []
    start_time = time.time()

    print(f"\nStarting parallel Simple Zero-Shot processing with {args.max_workers} workers...")

    with ProcessPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_problem = {
            executor.submit(process_single_problem_zero_shot, worker_arg): worker_arg[0]['database_name']
            for worker_arg in worker_args
        }

        # Results arrive in completion order, not submission order.
        for future in as_completed(future_to_problem):
            database_name = future_to_problem[future]
            try:
                result = future.result()
                results.append(result)
                # Already-processed (skipped) problems count as successes in the progress line.
                progress.update(success=(result["status"] in ["success", "skipped"]))

            except Exception as exc:
                # Exceptions the worker did not handle itself, e.g. from its
                # output-directory creation, or BrokenProcessPool if a worker dies.
                print(f'\nERROR: {database_name} generated an exception: {exc}')
                error_result = {
                    "database_name": database_name,
                    "status": "failed",
                    "error": str(exc),
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
                }
                results.append(error_result)
                progress.update(success=False)

    total_time = time.time() - start_time
    successful = [r for r in results if r["status"] == "success"]
    skipped = [r for r in results if r["status"] == "skipped"]
    failed = [r for r in results if r["status"] == "failed"]

    # Successful runs whose solver reported success. The rate below divides by
    # all problems, so skipped problems count against it.
    successful_solves = [r for r in successful if r.get("solve_success", False)]

    overall_summary = {
        "run_info": {
            "stage": "stage2_simple_zero_shot",
            "total_problems": total,
            "successful_runs": len(successful),
            "successful_solves": len(successful_solves),
            "skipped": len(skipped),
            "failed": len(failed),
            "solve_success_rate": f"{len(successful_solves)/total*100:.1f}%",
            "total_time": f"{total_time:.1f} seconds",
            "average_time_per_problem": f"{total_time/total:.1f} seconds",
            "parallel_workers": args.max_workers,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        },
        "configuration": {
            "model": args.model,
            "temperature": args.temperature,
            "enhanced_problems_source": args.enhanced_problems_dir,
            "output_dir": args.output_dir,
            "max_workers": args.max_workers
        },
        "results": results
    }

    with open(os.path.join(args.output_dir, "stage2_simple_zero_shot_summary.json"), "w", encoding="utf-8") as f:
        json.dump(overall_summary, f, indent=2)

    print("\n" + "=" * 60)
    print("Stage 2 Simple Zero-Shot Complete!")
    print(f"Total problems processed: {total}")
    print(f"Successful runs: {len(successful)} ({len(successful)/total*100:.1f}%)")
    print(f"Successful solves: {len(successful_solves)} ({len(successful_solves)/total*100:.1f}%)")
    print(f"Skipped (already exist): {len(skipped)} ({len(skipped)/total*100:.1f}%)")
    print(f"Failed: {len(failed)} ({len(failed)/total*100:.1f}%)")
    print(f"Total time: {total_time:.1f} seconds")
    print(f"Results saved to: {args.output_dir}")

    if successful_solves:
        print("\nSuccessful solves examples:")
        for result in successful_solves[:5]:  # Show first 5
            print(f"  - {result['database_name']}: {result.get('optimal_value', 'N/A')}")

    if failed:
        print("\nFailed examples:")
        for result in failed[:3]:  # Show first 3
            # Exception results carry 'error'; failed solver results carry
            # 'error_message' (possibly None).
            reason = result.get('error') or result.get('error_message') or 'Unknown error'
            print(f"  - {result['database_name']}: {reason}")

if __name__ == "__main__":
    # Run main() only when executed as a script. ProcessPoolExecutor workers may
    # re-import this module (e.g. under the "spawn" start method), and they must
    # not re-run main().
    main()