import json
import os
import traceback
from typing import Dict, Any, Tuple, List

from utils.llms import AbstractLLM
from evaluators.main_evaluator import MainEvaluator

from utils.itinerary_parser import ItineraryParser
from utils.poi_analyzer import POIAnalyzer

from agent.llm_modulo.logger import Logger
from agent.llm_modulo.build_prompt import _build_message_from_reference, _build_refinement_messages

class LLMModuloAgent:
    """LLM-modulo agent for trip planning with iterative refinement.

    The agent asks the model for an initial plan, scores it with the main
    evaluator and, while the score stays below ``success_threshold``, asks the
    model for refined plans (at most ``max_steps`` times), keeping the
    best-scoring plan seen.
    """

    # Score at or above which a plan is accepted without further refinement.
    # Can be overridden per instance through the ``success_threshold`` kwarg.
    success_threshold: float = 3.5

    # Query keys copied into the poi_dict, paired with the fallback value used
    # when the query does not provide the key.
    _POI_DICT_DEFAULTS = (
        ("userQuery", ""),
        ("day", -1),
        ("locale", "en-US"),
        ("departure", ""),
        ("arrive", ""),
        ("transportation", "Yes"),
        ("reference", ""),
        ("transport_pool", "{}"),
        ("preference", "{}"),
        ("poi_pool", ""),
        ("hotel_pool", ""),
    )

    def __init__(self, **kwargs):
        """Store the collaborators and settings passed as keyword arguments.

        Recognised keys: ``env``, ``cache_dir``, ``log_dir``, ``debug``
        (default ``False``), ``max_steps`` (default ``10``), ``model``,
        ``main_evaluator``, ``tokenizer``, ``poi_analyzer`` and
        ``success_threshold`` (default ``3.5``). Keys that are not supplied
        and have no default are stored as ``None``.
        """
        self.env: str = kwargs.get("env")
        self.cache_dir: str = kwargs.get("cache_dir")
        self.log_dir: str = kwargs.get("log_dir")
        self.debug: bool = kwargs.get("debug", False)
        self.max_steps: int = kwargs.get("max_steps", 10)
        self.success_threshold: float = kwargs.get(
            "success_threshold", type(self).success_threshold
        )
        self.model: AbstractLLM = kwargs.get("model")
        self.main_evaluator: MainEvaluator = kwargs.get("main_evaluator")
        self.tokenizer = kwargs.get("tokenizer")
        self.poi_analyzer = kwargs.get("poi_analyzer")
        self.logger = Logger(self.log_dir, self.debug)

    def solve(self, query: Dict[str, Any], prob_idx: str = None) -> Tuple[bool, Dict[str, Any]]:
        """
        Solve the trip planning problem using LLM-modulo approach

        While solving, the evaluator's ``soft_constraint_evaluator.enable_LLM``
        and ``enable_user_request_eval`` flags are switched off; their previous
        values are restored on every exit path (early success, normal
        completion and failure).

        Args:
            query: The query data containing travel requirements. Its
                ``messages`` entry may be a list or a JSON-encoded string; a
                string that is not valid JSON is treated as an empty list.
            prob_idx: Problem index for logging

        Returns:
            tuple: (success, plan). On completion ``plan`` is a result dict
            with the keys ``method``, ``llm_response``, ``messages``,
            ``best_score`` and ``status``. If any exception is raised while
            solving, ``(False, {"error": ...})`` is returned instead.
        """
        evaluator = self.main_evaluator
        saved_flags = None
        try:
            soft_evaluator = evaluator.soft_constraint_evaluator
            saved_flags = (soft_evaluator.enable_LLM, evaluator.enable_user_request_eval)
            soft_evaluator.enable_LLM = False
            evaluator.enable_user_request_eval = False

            self.logger.info(f"Starting LLM-modulo agent for problem {prob_idx}")
            messages = self._load_messages(query)
            poi_dict = self._build_poi_dict(query)

            current_plan = self._generate_initial_plan(messages, poi_dict)
            best_plan = current_plan
            self.logger.info(f"Initial plan: {current_plan}")
            score, results = self._evaluate_plan(messages, poi_dict, current_plan)
            self._log_evaluation(score, results)
            best_score = score

            if score >= self.success_threshold:
                self.logger.info("Valid plan found at initial generation")
                return True, self._build_result(best_plan, messages, best_score)

            for step in range(1, self.max_steps + 1):
                self.logger.info(f"Step {step}/{self.max_steps}")
                current_plan = self._refine_plan(messages, poi_dict, current_plan, results)
                self.logger.info(f"Plan {step}: {current_plan}")
                score, results = self._evaluate_plan(messages, poi_dict, current_plan)
                self._log_evaluation(score, results)

                # On a tie the most recent plan replaces the previous best.
                if score >= best_score:
                    best_score = score
                    best_plan = current_plan
                if score >= self.success_threshold:
                    self.logger.info(f"Valid plan found at step {step}")
                    break

            # Success only means that some plan was produced; the plan may
            # still score below the threshold.
            return best_plan is not None, self._build_result(best_plan, messages, best_score)

        except Exception as e:
            self.logger.error(f"LLMModuloAgent error: {e}")
            traceback.print_exc()
            return False, {"error": str(e) + " " + traceback.format_exc()}
        finally:
            # saved_flags stays None when the flags could not even be read, in
            # which case nothing was modified and nothing needs restoring.
            if saved_flags is not None:
                soft_evaluator.enable_LLM, evaluator.enable_user_request_eval = saved_flags

    @staticmethod
    def _load_messages(query: Dict[str, Any]):
        """Return the query's ``messages``, decoding a JSON string if needed."""
        messages = query.get("messages", [])
        if isinstance(messages, str):
            try:
                messages = json.loads(messages)
            except json.JSONDecodeError:
                messages = []
        return messages

    @staticmethod
    def _build_result(best_plan, messages, best_score) -> Dict[str, Any]:
        """Assemble the result dict returned by ``solve`` on completion."""
        return {
            "method": "llm-modulo",
            # A falsy plan (None or empty string) is reported as an error entry.
            "llm_response": best_plan or {"error": "Failed to generate plan"},
            "messages": messages,
            "best_score": best_score,
            "status": "completed",
        }

    def _log_evaluation(self, score, results: Dict[str, Any]) -> None:
        """Log the score and the per-category evaluation details.

        Missing detail keys are logged as ``None`` so that logging alone can
        never abort a run.
        """
        format_details = results.get("format_details")
        validation_results = (
            format_details.get("validation_results")
            if isinstance(format_details, dict)
            else None
        )
        self.logger.info(f"Score: {score}")
        self.logger.info("Evaluation results:")
        self.logger.info(f"format_details: {validation_results}")
        for key in (
            "commonsense_details",
            "soft_constraint_details",
            "preference_details",
            "user_request_details",
        ):
            self.logger.info(f"{key}: {results.get(key)}")

    def _build_poi_dict(self, query: Dict[str, Any]) -> Dict[str, Any]:
        """Build poi_dict from query data, filling absent keys with defaults"""
        return {key: query.get(key, default) for key, default in self._POI_DICT_DEFAULTS}

    def _generate_initial_plan(self, messages, poi_dict: Dict[str, Any]) -> str:
        """Generate initial trip plan by sending ``messages`` to the model.

        ``poi_dict`` is accepted for signature symmetry but is not used here.
        """
        return self.model._get_response(messages, one_line=False, json_mode=False)

    def _refine_plan(self, messages, poi_dict: Dict[str, Any], current_plan: str, evaluation_results: Dict[str, Any]) -> str:
        """Refine the current plan using the latest evaluation results"""
        refinement_messages = self._get_refinement_prompt(
            messages, poi_dict, current_plan, evaluation_results
        )
        return self.model._get_response(refinement_messages, one_line=False, json_mode=False)

    def _build_solution_str(self, messages, plan: str) -> str:
        """Return the chat-template rendering of ``messages`` followed by ``plan``."""
        prompt_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        )
        return prompt_text + plan

    def _evaluate_plan(self, messages, poi_dict: Dict[str, Any], plan: str) -> Tuple[float, Dict[str, Any]]:
        """Score the plan with the main evaluator.

        Returns:
            tuple: (score, results) exactly as returned by
            ``main_evaluator.compute_score``.
        """
        solution_str = self._build_solution_str(messages, plan)
        return self.main_evaluator.compute_score(solution_str, poi_dict)

    def _parse_itinerary_from_plan(self, messages, current_plan: str, poi_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Parse itinerary from LLM-generated plan string

        Parsing is best-effort: any failure is logged and an empty dict is
        returned so that refinement can still proceed.
        """
        try:
            solution_str = self._build_solution_str(messages, current_plan)
            _, answer_text, _ = ItineraryParser.extract_solution(solution_str)
            _, itinerary = ItineraryParser.extract_itinerary(answer_text, poi_dict, self.poi_analyzer)
            return itinerary or {}
        except Exception as exc:
            self.logger.info(f"Could not parse itinerary from plan: {exc}")
            return {}

    def _get_refinement_prompt(self, messages, poi_dict: Dict[str, Any], current_plan: str, evaluation_results: Dict[str, Any]) -> List[Dict[str, str]]:
        """Create refinement prompt from the current plan and its evaluation"""
        current_itinerary = self._parse_itinerary_from_plan(messages, current_plan, poi_dict)
        return _build_refinement_messages(poi_dict, current_plan, current_itinerary, evaluation_results)
