"""Enhanced parallel evaluation framework with async support."""

import asyncio
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Optional, Callable
from pathlib import Path
import logging
from datetime import datetime
import json
import traceback

from ..config import EvaluationTask, EvaluationResult


class ParallelEvaluator:
    """Parallel evaluator with per-provider rate limiting and resumable progress tracking.

    Tasks run either on an asyncio event loop (the blocking ``evaluator_func``
    is offloaded to the loop's default executor) or on a thread/process pool.
    Ids of tasks whose evaluator returned without raising are remembered in
    ``completed_tasks`` and skipped when this evaluator sees them again, so a
    caller can re-submit a batch to retry only the failures.
    """

    def __init__(
        self,
        max_workers: int = 4,
        use_async: bool = True,
        rate_limit: Optional[Dict[str, float]] = None,
        logger: Optional[logging.Logger] = None
    ):
        """Initialize parallel evaluator.

        Args:
            max_workers: Maximum number of tasks evaluated concurrently.
            use_async: Whether ``evaluate_batch`` uses the asyncio path (True)
                or the executor-pool path (False).
            rate_limit: Maximum requests per second, keyed by model provider.
                Providers that are not listed are not limited. Only the async
                path applies these limits.
            logger: Logger instance; defaults to this module's logger.

        Raises:
            ValueError: If a rate limit is not greater than zero.
        """
        self.max_workers = max_workers
        self.use_async = use_async
        self.rate_limit = rate_limit or {}
        self.logger = logger or logging.getLogger(__name__)

        for provider, limit in self.rate_limit.items():
            # 0 would divide by zero; negative or NaN rates have no meaning.
            if not limit > 0:
                raise ValueError(
                    f"rate_limit for provider {provider!r} must be positive, got {limit!r}"
                )

        # Rate limiting: time.time() of the most recently *reserved* request slot
        # per provider (may lie slightly in the future while callers are waiting).
        self.last_request_time: Dict[str, float] = {}
        # No longer used by the rate limiter; kept so existing references still work.
        self.request_semaphores: Dict[str, asyncio.Semaphore] = {}

        # Progress tracking
        self.completed_tasks: set = set()
        self.failed_tasks: set = set()
        self.progress_callback: Optional[Callable] = None

    async def _rate_limit_request(self, provider: str):
        """Wait until another request to ``provider`` is allowed.

        Requests are spaced at least ``1 / rate_limit[provider]`` seconds apart.
        Each caller reserves its slot before awaiting. There is no ``await``
        between reading and updating ``last_request_time``, so concurrent
        coroutines on one loop each get a distinct slot instead of all seeing
        the same previous timestamp.
        """
        if provider not in self.rate_limit:
            return

        min_interval = 1.0 / self.rate_limit[provider]
        now = time.time()
        previous = self.last_request_time.get(provider)
        slot = now if previous is None else max(now, previous + min_interval)
        self.last_request_time[provider] = slot

        delay = slot - now
        if delay > 0:
            await asyncio.sleep(delay)

    @staticmethod
    def _skipped_result(task: EvaluationTask) -> EvaluationResult:
        """Build the placeholder result returned for an already-completed task."""
        return EvaluationResult(
            task_id=task.task_id,
            problem_index=task.problem_index,
            model_name=task.model_config.name,
            success=True,
            metrics={'skipped': True}
        )

    @staticmethod
    def _failure_result(task: EvaluationTask, error: str, **extra: Any) -> EvaluationResult:
        """Build an unsuccessful result; ``extra`` passes optional fields such as duration/attempt."""
        return EvaluationResult(
            task_id=task.task_id,
            problem_index=task.problem_index,
            model_name=task.model_config.name,
            success=False,
            error=error,
            **extra
        )

    @staticmethod
    def _run_evaluator(task: EvaluationTask, evaluator_func: Callable) -> tuple:
        """Run ``evaluator_func(task)`` and time it, without touching evaluator state.

        This is a staticmethod so a process pool pickles only the function,
        ``task`` and ``evaluator_func``, not the evaluator instance (its
        callback or logger may be unpicklable, and a child's copy of the
        tracking sets would be discarded anyway).

        Returns:
            ``(result, None)`` on success, or ``(failure_result, traceback_text)``
            if the evaluator raised.
        """
        start_time = time.time()
        try:
            result = evaluator_func(task)
            result.duration = time.time() - start_time
        except Exception as e:
            failure = ParallelEvaluator._failure_result(
                task, str(e), duration=time.time() - start_time, attempt=task.attempt
            )
            return failure, traceback.format_exc()
        return result, None

    def _record_outcome(
        self,
        task: EvaluationTask,
        result: EvaluationResult,
        error_trace: Optional[str]
    ) -> None:
        """Update the completed/failed sets, logging the traceback on failure.

        A task counts as completed when its evaluator returned without raising
        (``error_trace is None``), regardless of the returned result's fields.
        """
        if error_trace is None:
            self.completed_tasks.add(task.task_id)
            self.failed_tasks.discard(task.task_id)
        else:
            self.logger.error(f"Error evaluating task {task.task_id}: {result.error}\n{error_trace}")
            self.failed_tasks.add(task.task_id)

    def _notify_progress(self) -> None:
        """Report (completed_count, failed_count) to the progress callback, if one is set."""
        if self.progress_callback:
            self.progress_callback(len(self.completed_tasks), len(self.failed_tasks))

    async def _evaluate_task_async(self, task: EvaluationTask, evaluator_func: Callable) -> EvaluationResult:
        """Evaluate one task on the running event loop.

        The blocking ``evaluator_func`` runs in the loop's default executor so
        other tasks keep progressing. Exceptions are logged and returned as an
        unsuccessful result rather than raised.
        """
        if task.task_id in self.completed_tasks:
            self.logger.debug(f"Skipping completed task: {task.task_id}")
            return self._skipped_result(task)

        start_time = time.time()
        try:
            await self._rate_limit_request(task.model_config.provider)
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(None, evaluator_func, task)
            # Duration deliberately includes time spent waiting on the rate limiter.
            result.duration = time.time() - start_time
        except Exception as e:
            result = self._failure_result(
                task, str(e), duration=time.time() - start_time, attempt=task.attempt
            )
            self._record_outcome(task, result, traceback.format_exc())
        else:
            self._record_outcome(task, result, None)

        self._notify_progress()
        return result

    def _evaluate_task_sync(self, task: EvaluationTask, evaluator_func: Callable) -> EvaluationResult:
        """Evaluate one task in the calling thread, updating this evaluator's tracking sets."""
        if task.task_id in self.completed_tasks:
            self.logger.debug(f"Skipping completed task: {task.task_id}")
            return self._skipped_result(task)

        result, error_trace = self._run_evaluator(task, evaluator_func)
        self._record_outcome(task, result, error_trace)
        return result

    async def evaluate_batch_async(
        self,
        tasks: List[EvaluationTask],
        evaluator_func: Callable
    ) -> List[EvaluationResult]:
        """Evaluate ``tasks`` concurrently on the running event loop.

        At most ``max_workers`` tasks are in flight at once, and a new task
        starts as soon as any running one finishes. This is a sliding window,
        so one slow task does not stall a whole fixed-size batch.

        Returns:
            One result per task, in the same order as ``tasks``.
        """
        self.logger.info(f"Starting async evaluation of {len(tasks)} tasks")

        window = asyncio.Semaphore(max(1, self.max_workers))

        async def run_limited(task: EvaluationTask) -> EvaluationResult:
            async with window:
                return await self._evaluate_task_async(task, evaluator_func)

        outcomes = await asyncio.gather(
            *(run_limited(task) for task in tasks),
            return_exceptions=True
        )

        results = []
        for task, outcome in zip(tasks, outcomes):
            # _evaluate_task_async handles Exception itself; anything that still
            # arrives here is a BaseException such as asyncio.CancelledError.
            if isinstance(outcome, BaseException):
                self.logger.error(f"Task failed with exception: {outcome}")
                results.append(self._failure_result(task, str(outcome)))
            else:
                results.append(outcome)
        return results

    def evaluate_batch_sync(
        self,
        tasks: List[EvaluationTask],
        evaluator_func: Callable,
        use_processes: bool = True
    ) -> List[EvaluationResult]:
        """Evaluate ``tasks`` on a process or thread pool.

        Skip checks, outcome tracking and progress reporting all happen in the
        calling process, so ``completed_tasks``/``failed_tasks`` stay accurate
        even when workers are separate processes. With ``use_processes=True``,
        ``evaluator_func``, the tasks and their results must be picklable.

        Returns:
            One result per task, in the same order as ``tasks``.
        """
        self.logger.info(f"Starting sync evaluation of {len(tasks)} tasks with {self.max_workers} workers")

        results: List[Optional[EvaluationResult]] = [None] * len(tasks)
        executor_class = ProcessPoolExecutor if use_processes else ThreadPoolExecutor

        with executor_class(max_workers=self.max_workers) as executor:
            future_to_index = {}
            for index, task in enumerate(tasks):
                if task.task_id in self.completed_tasks:
                    self.logger.debug(f"Skipping completed task: {task.task_id}")
                    results[index] = self._skipped_result(task)
                    continue
                future = executor.submit(self._run_evaluator, task, evaluator_func)
                future_to_index[future] = index

            for future in as_completed(future_to_index):
                index = future_to_index[future]
                task = tasks[index]
                try:
                    result, error_trace = future.result()
                except Exception as e:
                    # The worker machinery failed (e.g. pickling error, crashed process).
                    self.logger.error(f"Task {task.task_id} failed: {e}")
                    self.failed_tasks.add(task.task_id)
                    result = self._failure_result(task, str(e))
                else:
                    self._record_outcome(task, result, error_trace)
                results[index] = result
                self._notify_progress()

        return results

    def evaluate_batch(
        self,
        tasks: List[EvaluationTask],
        evaluator_func: Callable
    ) -> List[EvaluationResult]:
        """Evaluate a batch of tasks using the configured method.

        With ``use_async`` the batch runs on a fresh event loop via
        ``asyncio.run``, which also closes that loop and leaves no closed loop
        installed as the thread's current loop. Otherwise it delegates to
        ``evaluate_batch_sync`` with its default pool type.
        """
        if not self.use_async:
            return self.evaluate_batch_sync(tasks, evaluator_func)
        return asyncio.run(self.evaluate_batch_async(tasks, evaluator_func))

    def set_progress_callback(self, callback: Callable[[int, int], None]):
        """Set callback for progress updates.

        Args:
            callback: Function that receives (completed_count, failed_count)
        """
        self.progress_callback = callback

    def get_statistics(self) -> Dict[str, Any]:
        """Return completed/failed/total counts and the completed fraction (0 when nothing ran)."""
        completed = len(self.completed_tasks)
        failed = len(self.failed_tasks)
        attempted = completed + failed
        return {
            'completed': completed,
            'failed': failed,
            'total': attempted,
            'success_rate': completed / max(1, attempted)
        }
