"""
Main Evaluator - Coordinates all evaluators for comprehensive assessment
"""

from typing import Any, Dict, List, Optional, Tuple, Union

from .base_evaluator import BaseEvaluator
from .format_evaluator import FormatEvaluator
from .commonsense_evaluator import CommonsenseEvaluator
from .soft_constraint_evaluator import SoftConstraintEvaluator
from .preference_evaluator import PreferenceEvaluator
from .user_request_evaluator import UserRequestEvaluator
from utils.itinerary_parser import ItineraryParser
from utils.poi_analyzer import POIAnalyzer
from utils.llms import Gemini, EmptyLLM, AbstractLLM


class MainEvaluator:
    """Coordinate the individual evaluators into one itinerary score.

    Scoring in ``compute_score`` is gated stage by stage:

    1. The POI dictionary must parse; otherwise ``POI_PARSE_FAILURE_SCORE``
       is returned.
    2. The format evaluator must report ``format_correct``; otherwise the
       total is ``format_score + FORMAT_FAILURE_ANSWER_PENALTY`` and every
       later stage is skipped.
    3. The commonsense score must reach ``COMMONSENSE_PASS_THRESHOLD``;
       otherwise the total is 0 and the later stages are skipped.
    4. Otherwise the soft-constraint evaluator runs, followed by either the
       user-request evaluator (when enabled) or the preference evaluator,
       and the total is a weighted sum of all stage scores.
    """

    # Score returned when the POI dictionary cannot be parsed.
    POI_PARSE_FAILURE_SCORE = -2
    # Answer-score penalty added to the format score when the format is invalid.
    FORMAT_FAILURE_ANSWER_PENALTY = -2
    # Minimum commonsense score required before the later evaluators run.
    COMMONSENSE_PASS_THRESHOLD = 0.9
    # Weights applied to each stage score in the final weighted sum.
    SOFT_CONSTRAINT_WEIGHT = 1.0
    PREFERENCE_WEIGHT = 0.1
    USER_REQUEST_WEIGHT = 1.4

    def __init__(
        self,
        format_reward: int = 1,
        answer_reward: float = 1.0,
        enable_user_request_eval: bool = False,
        enable_LLM: bool = False,
        llm: Optional[AbstractLLM] = None,
        poi_analyzer: Optional[POIAnalyzer] = None,
    ):
        """
        Initialize main evaluator

        Args:
            format_reward: Format reward score, forwarded to FormatEvaluator
            answer_reward: Answer reward score (stored on the instance; not
                read by compute_score)
            enable_user_request_eval: Whether to build the user request
                evaluator and run it instead of the preference evaluator
            enable_LLM: Forwarded to SoftConstraintEvaluator
            llm: LLM instance forwarded to SoftConstraintEvaluator and, when
                enable_user_request_eval is set, to UserRequestEvaluator
            poi_analyzer: Shared POI analyzer instance; a new
                ``POIAnalyzer(use_api=False)`` is created if None
        """
        self.format_reward = format_reward
        self.answer_reward = answer_reward
        self.enable_user_request_eval = enable_user_request_eval

        # One POI analyzer shared by every evaluator, so a single
        # load_pool_from_dict call in compute_score feeds all of them.
        self.poi_analyzer = poi_analyzer if poi_analyzer is not None else POIAnalyzer(use_api=False)

        self.format_evaluator = FormatEvaluator(format_reward, self.poi_analyzer)
        self.commonsense_evaluator = CommonsenseEvaluator(self.poi_analyzer)
        self.soft_constraint_evaluator = SoftConstraintEvaluator(self.poi_analyzer, enable_LLM=enable_LLM, llm=llm)
        self.preference_evaluator = PreferenceEvaluator(poi_analyzer=self.poi_analyzer)

        # Optional user request evaluator
        self.user_request_evaluator = None
        if enable_user_request_eval:
            self.user_request_evaluator = UserRequestEvaluator(weight=0.2, llm=llm)

        self.evaluators = [
            self.format_evaluator,
            self.commonsense_evaluator,
            self.soft_constraint_evaluator,
            self.preference_evaluator,
        ]
        if self.user_request_evaluator is not None:
            self.evaluators.append(self.user_request_evaluator)

    @staticmethod
    def _skipped(reason: str) -> Tuple[int, Dict[str, Any]]:
        """Return a ``(0, {"error": reason})`` placeholder for a stage that did not run.

        A fresh dict is built on every call so the per-stage detail dicts
        never alias each other.
        """
        return 0, {"error": reason}

    def compute_score(self, solution_str: str, poi_dict: Union[str, Dict[str, Any]]) -> Tuple[float, Dict[str, Any]]:
        """
        Calculate comprehensive score

        Args:
            solution_str: Raw model response string
            poi_dict: POI dictionary, either as a JSON string or an already
                parsed dict

        Returns:
            Tuple[float, Dict]: (Total score, detailed evaluation results).
            When the POI dictionary cannot be parsed, returns
            ``(POI_PARSE_FAILURE_SCORE, {"error": ...})`` instead.
        """
        poi_dict_json = self._parse_poi_dict(poi_dict)
        if poi_dict_json is None:
            return self.POI_PARSE_FAILURE_SCORE, {"error": "POI dictionary parsing failed"}

        # Load the POI pool once into the analyzer shared by all evaluators.
        self.poi_analyzer.load_pool_from_dict(poi_dict_json)

        evaluation_data = {
            "solution_str": solution_str,
            "poi_dict": poi_dict_json,
        }

        # The format evaluator also extracts the prompt and answer text,
        # which the later evaluators read from evaluation_data.
        format_score, format_details, prompt_str, answer_text = self.format_evaluator.evaluate(evaluation_data)
        evaluation_data["prompt_str"] = prompt_str
        evaluation_data["answer_text"] = answer_text

        if not format_details.get("format_correct", True):
            # Invalid format: penalize the answer and skip every later stage
            # (e.g. -1 + -2 = -3 if the format evaluator reports -1).
            reason = "Skipped due to format errors"
            total_score = format_score + self.FORMAT_FAILURE_ANSWER_PENALTY
            commonsense_score, commonsense_details = self._skipped(reason)
            soft_constraint_score, soft_constraint_details = self._skipped(reason)
            preference_score, preference_details = self._skipped(reason)
            user_request_score, user_request_details = self._skipped(reason)
        else:
            evaluation_data["itinerary"] = format_details.get("itinerary", {})

            commonsense_score, commonsense_details = self.commonsense_evaluator.evaluate(evaluation_data)

            # Kept as "not (score >= threshold)" so a NaN score counts as a failure.
            commonsense_passed = commonsense_score >= self.COMMONSENSE_PASS_THRESHOLD
            if not commonsense_passed:
                reason = "Skipped due to commonsense violations"
                total_score = 0
                soft_constraint_score, soft_constraint_details = self._skipped(reason)
                preference_score, preference_details = self._skipped(reason)
                user_request_score, user_request_details = self._skipped(reason)
            else:
                soft_constraint_score, soft_constraint_details = self.soft_constraint_evaluator.evaluate(evaluation_data)

                # Exactly one of user-request / preference evaluation runs;
                # the other keeps its "Not evaluated" placeholder (score 0).
                preference_score, preference_details = self._skipped("Not evaluated")
                user_request_score, user_request_details = self._skipped("Not evaluated")
                if self.enable_user_request_eval:
                    user_request_score, user_request_details = self.user_request_evaluator.evaluate(evaluation_data)
                else:
                    preference_score, preference_details = self.preference_evaluator.evaluate(evaluation_data)

                total_score = (
                    format_score
                    + commonsense_score
                    + self.SOFT_CONSTRAINT_WEIGHT * soft_constraint_score
                    + self.PREFERENCE_WEIGHT * preference_score
                    + self.USER_REQUEST_WEIGHT * user_request_score
                )

        evaluation_results = {
            "format_score": format_score,
            "commonsense_score": commonsense_score,
            "soft_constraint_score": soft_constraint_score,
            "preference_score": preference_score,
            "user_request_score": user_request_score,
            "total_score": total_score,
            "format_details": format_details,
            "commonsense_details": commonsense_details,
            "soft_constraint_details": soft_constraint_details,
            "preference_details": preference_details,
            "user_request_details": user_request_details,
        }

        return total_score, evaluation_results

    def _parse_poi_dict(self, poi_dict: Union[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """
        Parse POI dictionary

        Args:
            poi_dict: POI dictionary, either a dict (returned unchanged) or a
                JSON string

        Returns:
            The parsed POI dictionary, or None if the input is not valid JSON
            (or not a str/bytes/dict at all, e.g. None)
        """
        import json

        if isinstance(poi_dict, dict):
            return poi_dict

        try:
            return json.loads(poi_dict)
        except (TypeError, ValueError) as e:
            # ValueError covers json.JSONDecodeError; TypeError covers
            # inputs json.loads cannot accept (e.g. None).
            print(f"POI dictionary parsing failed: {e}")
            return None

    def add_evaluator(self, evaluator: BaseEvaluator):
        """
        Add new evaluator

        Note: compute_score invokes its evaluators explicitly and does not
        iterate self.evaluators, so an evaluator added here is reported by
        get_evaluator_names but does not contribute to the score.

        Args:
            evaluator: Evaluator instance
        """
        self.evaluators.append(evaluator)

    def get_evaluator_names(self) -> List[str]:
        """
        Get all evaluator names

        Returns:
            Evaluator name list, in registration order
        """
        return [evaluator.get_name() for evaluator in self.evaluators]