import asyncio
import logging
from typing import Dict, List, Any, Optional
from datetime import datetime
from tqdm.asyncio import tqdm

from models.async_base import AsyncModelInterface
from utils.math_evaluator import MathEvaluator
from utils.data_loader import get_image_paths
from utils.image_processor import prepare_images_for_model
from utils.retry_decorator import async_retry_on_failure, EvaluationRetryConfig
from prompts.math_prompts import get_text_only_math_prompt, get_image_math_prompt, get_scene_math_prompt

logger = logging.getLogger(__name__)


class AsyncEvaluationEngine:
    """Engine for running async parallel evaluations.

    The engine asks a model to answer math questions (text-only, scene, or
    image based), extracts and scores the answers with the supplied
    ``MathEvaluator``, and collects per-question result dictionaries keyed
    by evaluation mode.

    Evaluations that fail (after whatever retrying the retry decorator
    performs) are flagged with ``"evaluation_failed": True`` and carry no
    ``"math_correct"`` key, so that infrastructure failures can be kept out
    of accuracy calculations.
    """

    # Intermediate snapshots are written only once at least this many results
    # have accumulated for the current mode / prompt-mode combination.
    _INTERMEDIATE_SAVE_THRESHOLD = 200

    # Pause (seconds) inserted before each request to a Qwen model, taken
    # while holding the concurrency slot, to avoid overwhelming the API.
    _QWEN_REQUEST_DELAY_SECONDS = 0.5

    def __init__(
        self,
        math_evaluator: MathEvaluator,
        retry_config: Optional[EvaluationRetryConfig] = None
    ):
        """
        Initialize the async evaluation engine.

        Args:
            math_evaluator: Math evaluator instance used to extract and
                score answers.
            retry_config: Configuration for retry logic. When omitted, a
                config with max_retries=3, base_delay=2.0 and max_delay=30.0
                is created.
        """
        self.math_evaluator = math_evaluator
        self.retry_config = retry_config or EvaluationRetryConfig(
            max_retries=3,
            base_delay=2.0,
            max_delay=30.0
        )

    @staticmethod
    def _detect_model_type(model_name: str) -> str:
        """Map a model name to the image-preparation backend identifier.

        Matching is case-insensitive and substring based. Any name that does
        not match a known API family (including local Qwen models, whose
        names contain "local") is treated as a vLLM / local model.
        """
        lowered = model_name.lower()
        if "gpt" in lowered:
            return "openai"
        if "claude" in lowered:
            return "claude"
        if "gemini" in lowered:
            return "gemini"
        if "glm" in lowered or "zhipuai" in lowered:
            # ZhipuAI models (GLM series)
            return "zhipuai"
        if ("qwen" in lowered or "qvq" in lowered) and "local" not in lowered:
            # API Qwen models (including QVQ-Max)
            return "qwen"
        # vLLM-served and other local models share the same preparation path.
        return "vllm"

    @async_retry_on_failure()
    async def _evaluate_math_with_retry(
        self,
        model: AsyncModelInterface,
        question_data: Dict[str, Any],
        mode: str,
        image_dir: str,
        prompt_mode: str
    ) -> Dict[str, Any]:
        """
        Core math evaluation logic, wrapped by the retry decorator.

        Builds the prompt (and images, for image modes) for one question and
        returns the raw model response.

        NOTE(review): the decorator is applied without arguments here, so
        whether it honours ``self.retry_config`` depends on its
        implementation - confirm.

        Args:
            model: Async model interface to query.
            question_data: Question record; must contain "question_id" and
                "math_ground_truth".
            mode: "text_only", "scene", or an image mode name.
            image_dir: Directory used to resolve image paths (image modes).
            prompt_mode: "explicit" sends the scene-related question with the
                images; any other value sends an empty question.

        Returns:
            Dict with "model_response", "question_id", "mode", "prompt_mode"
            and "ground_truth".

        Raises:
            KeyError: If a required key is missing from ``question_data``.
            ValueError: If no images are found for an image mode, or the
                model returns an empty / whitespace-only response.
        """
        question_id = question_data["question_id"]
        original_question = question_data.get("original_question", "")
        ground_truth = question_data["math_ground_truth"]

        # Prepare inputs based on mode
        if mode == "text_only":
            prompt = get_text_only_math_prompt(original_question)
            images = None
        elif mode == "scene":
            # Scene mode uses the scene prompt and sends no images.
            prompt = get_scene_math_prompt(original_question)
            images = None
        else:
            # Image modes: only the explicit prompt mode carries question text.
            if prompt_mode == "explicit":
                question_to_use = question_data.get("modify_scene_related_question", "")
            else:  # implicit mode
                question_to_use = ""

            prompt = get_image_math_prompt(question_to_use, prompt_mode)
            image_paths = get_image_paths(question_data, image_dir, mode)

            if not image_paths:
                raise ValueError(f"No images found for question {question_id} in {mode} mode")

            model_type = self._detect_model_type(model.name)
            logger.info(f"Using model_type '{model_type}' for model '{model.name}'")
            images = prepare_images_for_model(image_paths, model_type)

        # Get model response asynchronously
        if 'qwen' in model.name.lower():
            # Qwen models accept a use_thinking flag; it is enabled only for
            # the model named exactly "qwen-vl-plus".
            use_thinking = (model.name == "qwen-vl-plus")
            model_response = await model.generate_text_async(prompt, images, use_thinking=use_thinking)
        else:
            # For other models, use standard interface
            model_response = await model.generate_text_async(prompt, images)

        if not model_response or not model_response.strip():
            raise ValueError(f"Empty response from model {model.name}")

        return {
            "model_response": model_response,
            "question_id": question_id,
            "mode": mode,
            "prompt_mode": prompt_mode,
            "ground_truth": ground_truth
        }

    async def evaluate_math_problem_async(
        self,
        model: AsyncModelInterface,
        question_data: Dict[str, Any],
        mode: str,
        image_dir: str,
        prompt_mode: str = "implicit"
    ) -> Dict[str, Any]:
        """
        Evaluate a model on a math problem with comprehensive error handling.

        Failures raised by the underlying evaluation are caught and recorded
        on the result instead of being scored as wrong answers, so that
        network errors and transient failures do not contaminate accuracy
        calculations.

        Args:
            model: Async model interface to query.
            question_data: Question record; must contain "question_id" and
                "math_ground_truth".
            mode: Evaluation mode.
            image_dir: Directory with images.
            prompt_mode: Prompt mode forwarded to the evaluation.

        Returns:
            Result dict with "question_id", "mode", "prompt_mode", "model",
            "math_ground_truth" and "timestamp". On success it also has
            "model_response", "extracted_answer" and "math_correct"; on
            failure it has "error" and "evaluation_failed" instead.

        Raises:
            KeyError: If "question_id" or "math_ground_truth" is missing
                (these are read before the guarded section).
        """
        question_id = question_data["question_id"]
        ground_truth = question_data["math_ground_truth"]

        logger.info(f"Evaluating {model.name} on question {question_id} ({mode}, {prompt_mode})")

        result = {
            "question_id": question_id,
            "mode": mode,
            "prompt_mode": prompt_mode,
            "model": model.name,
            "math_ground_truth": ground_truth,
            "timestamp": datetime.now().isoformat()
        }

        try:
            # Attempt evaluation with retry logic
            eval_result = await self._evaluate_math_with_retry(
                model, question_data, mode, image_dir, prompt_mode
            )

            model_response = eval_result["model_response"]
            result["model_response"] = model_response

            # Extract answer (first stage only, no post-processing)
            extracted_answer = self.math_evaluator.extract_answer(model_response)
            result["extracted_answer"] = extracted_answer

            # Check correctness
            is_correct = self.math_evaluator.is_correct(extracted_answer, ground_truth)
            result["math_correct"] = is_correct

            logger.info(
                f"Math evaluation completed: {model.name}, {question_id}, "
                f"{mode}, {prompt_mode}, correct={is_correct}"
            )

        except Exception as e:
            # Log the error but mark this evaluation as failed
            logger.error(
                f"Failed to evaluate question {question_id} with {model.name} "
                f"after {self.retry_config.max_retries} retries: {e}"
            )

            result["error"] = str(e)
            result["evaluation_failed"] = True
            # Do NOT set math_correct to False - this will be excluded from accuracy calculation

        return result

    @staticmethod
    def _build_failure_record(
        question_data: Any,
        mode: str,
        prompt_mode: str,
        model_name: str,
        error: BaseException
    ) -> Dict[str, Any]:
        """Build the result entry for an evaluation that raised unexpectedly.

        The record mirrors the failure shape produced by
        ``evaluate_math_problem_async`` ("evaluation_failed" set, no
        "math_correct"). The question id is looked up defensively because a
        missing id is itself one way an evaluation can raise.
        """
        question_id = question_data.get("question_id") if isinstance(question_data, dict) else None
        return {
            "question_id": question_id,
            "mode": mode,
            "prompt_mode": prompt_mode,
            "model": model_name,
            "error": str(error),
            "evaluation_failed": True
        }

    async def _save_intermediate_results(
        self,
        model: AsyncModelInterface,
        key: str,
        results: List[Dict[str, Any]],
        noun: str
    ) -> None:
        """Best-effort save of the results accumulated so far for ``key``.

        Any failure while constructing the writer or saving is logged and
        swallowed so that a failed snapshot never aborts the evaluation run.
        """
        logger.info(f"Saving intermediate results after processing {len(results)} {noun}")

        # Imported at call time, as in the original code (presumably to avoid
        # an import cycle - confirm before hoisting to module level).
        from utils.async_result_writer import AsyncResultWriter

        snapshot = {key: {"math_results": results}}
        # NOTE(review): this name is only used in the log line below; the
        # writer is called with the model name and decides the actual path.
        intermediate_filename = f"intermediate_{model.name}_{key}"

        try:
            intermediate_writer = AsyncResultWriter("intermediate_results")
            await intermediate_writer.save_model_results_async(model.name, snapshot)
            logger.info(f"Saved intermediate results to {intermediate_filename}")
        except Exception as save_error:
            logger.error(f"Failed to save intermediate results: {save_error}")

    async def _run_evaluation_batches(
        self,
        model: AsyncModelInterface,
        data: List[Dict[str, Any]],
        mode: str,
        prompt_mode: str,
        image_dir: str,
        semaphore: asyncio.Semaphore,
        batch_limit: int,
        key: str,
        unit: str,
        eval_label: str
    ) -> List[Dict[str, Any]]:
        """Evaluate every question in ``data`` for one mode / prompt mode.

        Questions are gathered in batches of at most ``batch_limit``; within
        a batch, ``semaphore`` bounds the number of in-flight requests. The
        returned list has one entry per question, in input order.
        """
        total = len(data)
        # Clamp to >= 1: a zero step (empty data) makes range() raise and a
        # negative one would silently skip every question.
        batch_size = max(1, min(batch_limit, total))
        is_qwen = 'qwen' in model.name.lower()

        progress_bar = tqdm(
            total=total,
            desc=f"{model.name} {key}",
            unit=unit,
            position=1,
            leave=False
        )

        async def evaluate_one(question_data: Dict[str, Any]) -> Dict[str, Any]:
            """Run one evaluation under the semaphore and tick the progress bar."""
            try:
                async with semaphore:
                    if is_qwen:
                        await asyncio.sleep(self._QWEN_REQUEST_DELAY_SECONDS)
                    return await self.evaluate_math_problem_async(
                        model, question_data, mode, image_dir, prompt_mode
                    )
            finally:
                progress_bar.update(1)

        results: List[Dict[str, Any]] = []
        try:
            for start in range(0, total, batch_size):
                batch = data[start:start + batch_size]
                logger.info(f"Processing batch {start // batch_size + 1} with {len(batch)} tasks")

                # Coroutines are created per batch so that an aborted run
                # leaves no never-awaited coroutines behind.
                outcomes = await asyncio.gather(
                    *(evaluate_one(question_data) for question_data in batch),
                    return_exceptions=True
                )

                for question_data, outcome in zip(batch, outcomes):
                    # BaseException also covers a cancelled child, which
                    # gather() returns as a CancelledError instance.
                    if isinstance(outcome, BaseException):
                        logger.error(f"Exception in {eval_label} evaluation for {model.name}: {outcome}")
                        results.append(
                            self._build_failure_record(
                                question_data, mode, prompt_mode, model.name, outcome
                            )
                        )
                    else:
                        results.append(outcome)

                if len(results) >= self._INTERMEDIATE_SAVE_THRESHOLD:
                    await self._save_intermediate_results(model, key, results, f"{unit}s")
        finally:
            # Always release the progress bar, even if a batch aborts.
            progress_bar.close()

        return results

    async def evaluate_single_model_async(
        self,
        model: AsyncModelInterface,
        data: List[Dict[str, Any]],
        modes: List[str],
        prompt_modes: List[str],
        image_dir: str,
        concurrency: int = 10,
        save_intermediate_every: int = 50
    ) -> Dict[str, Any]:
        """
        Evaluate a single model across all specified modes asynchronously with concurrency control.

        "text_only" and "scene" modes are evaluated once each (result key is
        the mode name). Every other mode is treated as an image mode and is
        evaluated once per entry in ``prompt_modes`` (result key is
        ``"{mode}_{prompt_mode}"``).

        Args:
            model: Async model interface
            data: List of questions to evaluate
            modes: List of evaluation modes
            prompt_modes: List of prompt modes for image evaluations
            image_dir: Directory with images
            concurrency: Maximum number of concurrent API calls for this model
            save_intermediate_every: Maximum number of questions gathered per
                batch; values below 1 are treated as 1. After each batch an
                intermediate snapshot is saved once at least 200 results
                have accumulated for the current key.

        Returns:
            Dictionary mapping each result key to
            ``{"math_results": [...]}`` with one entry per question. Entries
            for evaluations that raised are flagged with
            ``"evaluation_failed": True`` and carry no ``"math_correct"``.
        """
        logger.info(f"Starting async evaluation for model: {model.name} with concurrency={concurrency}")
        model_results: Dict[str, Any] = {}

        # One semaphore shared by all modes bounds in-flight requests.
        semaphore = asyncio.Semaphore(concurrency)

        for mode in modes:
            # Each run is (result key, prompt mode, progress unit,
            # noun used in the start-up log line, label used in error logs).
            if mode == "text_only":
                runs = [(f"{mode}", "text_only", "question", "questions", "math")]
            elif mode == "scene":
                # Scene mode uses implicit prompting
                runs = [(f"{mode}", "implicit", "scene", "scenes", "scene")]
            else:
                # Image modes - evaluate once per requested prompt mode
                runs = [
                    (f"{mode}_{prompt_mode}", prompt_mode, "question", "tasks", "math")
                    for prompt_mode in prompt_modes
                ]

            for key, prompt_mode, unit, task_noun, eval_label in runs:
                logger.info(
                    f"Processing {len(data)} {key} {task_noun} for {model.name} "
                    f"with max {concurrency} concurrent requests"
                )
                results = await self._run_evaluation_batches(
                    model,
                    data,
                    mode,
                    prompt_mode,
                    image_dir,
                    semaphore,
                    save_intermediate_every,
                    key,
                    unit,
                    eval_label
                )
                model_results[key] = {
                    "math_results": results
                }

        logger.info(f"Completed async evaluation for model: {model.name}")
        return model_results