import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, get_args
from dataclasses import dataclass
import datetime



from agent.ttg_agent.instruction_translator import InstructionTranslator
from evaluators.main_evaluator import MainEvaluator
from utils.llms import AbstractLLM
from utils.poi_analyzer import POIAnalyzer
from .milp_solver import MILPSolver
from .data_model import Itinerary, DayInfo, ScheduleDetail, DetailItem, OptimizationObjective, TravelRequest

class TTGAgent:
    """Main TTG system combining all components"""

    def __init__(self, **kwargs):
        """
        Args:
            kwargs: 
                - model: AbstractLLM
                - 
        """
        self.poi_analyzer: POIAnalyzer = kwargs.get("poi_analyzer", POIAnalyzer())
        self.main_evaluator: MainEvaluator = kwargs.get("main_evaluator", MainEvaluator(poi_analyzer=self.poi_analyzer))
        kwargs["poi_analyzer"] = self.poi_analyzer
        kwargs["main_evaluator"] = self.main_evaluator


        self.translator = InstructionTranslator(**kwargs)
        self.solver = MILPSolver(**kwargs)

    def run(self, query: Dict[str, Any], load_cache: bool = False, **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """
        Main entry point for TTG agent, compatible with experiment framework
        
        Args:
            query: Query dictionary containing travel request
            load_cache: Whether to load from cache (not used in TTG)
            **kwargs: Additional arguments
            
        Returns:
            Tuple of (success, plan_dict)
        """
        try:
            # Extract user query from the query dictionary
            user_query = query.get("userQuery", "")
            if not user_query:
                return False, {"error": "No user query provided"}
            self.poi_analyzer.load_pool_from_dict(query)
            self.solver.poi_analyzer = self.poi_analyzer
            self.solver.locale = query.get("locale", "zh-CN")
            self.solver.departure = query.get("departure", "")
            self.solver.arrive = query.get("arrive", "")
            # Process the natural language request
            plan = self.process_request(user_query)

            final_plan = {'llm_response': json.dumps(plan,ensure_ascii=False),
                          'agent_type': 'TTG',
                          'optimization_objective': 'user_preference'
                          }
            
            # Ensure plan is in JSON format
            if isinstance(plan, dict):
                return True, final_plan
            else:
                # Convert to dict if it's not already
                return True, final_plan
                
        except Exception as e:
            return False, {"error": f"TTG processing failed: {str(e)}"}

    def process_request(self, natural_language_request: str) -> Dict[str, Any]:
        """
        Main pipeline:
        1. Translate NL request to symbolic form
        2. Get available flights/hotels
        3. Solve for multiple optimization objectives
        4. Return JSON format itinerary
        """
        try:
            symbolic_request = self.translator.natural_to_symbolic(natural_language_request)
            # according to the transportation_pool order, modify destination_cities_ordered
            target_cities = [transport["key"].split("->")[-1] for transport in self.poi_analyzer.transportation_info]
            symbolic_request.destination_cities_ordered = self.optimize_city_route(self.solver.departure, target_cities)

            itineraries = self.process_symbolic_request(symbolic_request)
            
            # Get the user preference itinerary and convert to JSON format
            if 'user_preference' in itineraries and itineraries['user_preference']:
                itinerary = itineraries['user_preference']
                # Convert to dictionary format
                if hasattr(itinerary, 'to_dict'):
                    plan_dict = itinerary.to_dict(itinerary)
                else:
                    # Fallback: try to convert using asdict if it's a dataclass
                    from dataclasses import asdict
                    try:
                        plan_dict = asdict(itinerary)
                    except TypeError:
                        # If not a dataclass, try basic dict conversion
                        plan_dict = dict(itinerary) if hasattr(itinerary, '__dict__') else str(itinerary)

                return plan_dict
            else:
                return {"error": "No valid itinerary generated"}
                
        except Exception as e:
            return {"error": f"Failed to process request: {str(e)}"}
    
    def process_symbolic_request(self, symbolic_request: TravelRequest) -> Dict[OptimizationObjective, Itinerary]:
        """
        Process a symbolic travel request directly without translation
        
        Args:
            symbolic_request: TravelRequest object with all necessary information
            
        Returns:
            Dictionary of itineraries optimized for different objectives
        """
        itineraries = {}
        
        number_of_days = (symbolic_request.end_date - symbolic_request.start_date).days + 1
        destination_cities = symbolic_request.destination_cities_ordered
        itinerary_title = f"{number_of_days} day travel to {', '.join(destination_cities)}"
        
        # Update solver with the request
        self.solver.request = symbolic_request
        
        for objective in get_args(OptimizationObjective):
            try:
                itinerary = self.solver.solve(objective)
                itinerary.itineraryName = f"{itinerary_title} optimized for {objective}"
                itineraries[objective] = itinerary
            except Exception as e:
                print(f"Failed to solve for objective {objective}: {e}")
                continue
                
        return itineraries
    
    def evaluate_solution_quality(self, itinerary: Itinerary, poi_dict: Dict[str, Any]) -> Tuple[float, Dict[str, Any]]:
        """Calculate solution quality ratio"""
        itinerary_dict = Itinerary.to_dict(itinerary) if isinstance(itinerary, Itinerary) else itinerary
        itinerary_json = json.dumps(itinerary_dict, ensure_ascii=False)
        
        # Create a solution string that mimics a model response
        # The MainEvaluator expects a string with Assistant: prefix and <answer> tags
        solution_str = f"Assistant:\n<answer>\n{itinerary_json}\n</answer>"
        
        return self.main_evaluator.compute_score(solution_str, poi_dict)
    
    def optimize_city_route(self, source_city: str, target_cities: List[str]) -> List[str]:
        """
        optimize the city visit route, considering the geographical location and transportation convenience
        """
        if len(target_cities) <= 1:
            return target_cities

        optimized_route = []
        remaining_cities = target_cities.copy()
        current_city = source_city
        if source_city in remaining_cities:
            remaining_cities.remove(source_city)
            optimized_route.append(source_city)
        
        max_iter = 100
        while remaining_cities:
            # find the nearest next city to the current city
            best_city = None
            best_score = 0.7
            
            for city in remaining_cities:
                # check if there is a direct transportation connection
                transport_info = self.collect_intercity_transport(current_city, city, "train")
                
                # calculate the comprehensive score (transport convenience + geographical location)
                transport_score = 1.0 if transport_info else 0.5
                
                if transport_score > best_score:
                    best_score = transport_score
                    best_city = city
                if best_city:
                    optimized_route.append(best_city)
                    remaining_cities.remove(best_city)
                    current_city = best_city
            max_iter -= 1
            if max_iter <= 0:
                break
        if not optimized_route or max_iter <0:
            return remaining_cities
        
        return optimized_route
    
    def collect_intercity_transport(self, source_city:str, target_city:str, trans_type:str)->List[Dict]:
        """
        return the corresponding transportation information based on the source city, target city and transportation type
        """
        trans_type_mapping = {  
            "bus": "B",
            "train": "T",
            "flight": "F",
            "drive": "SC",
            "driving": "D",
            "ship": "S",
        }
        trans_info = []
        trans_type = trans_type
        # iterate over self.memory["transports"], find the transportation_id corresponding to the source city and target city
        for transport in self.poi_analyzer.transportation_info:
            from_city = transport["key"].split("->")[0]
            to_city = transport["key"].split("->")[1]
            if from_city == source_city and to_city == target_city:
                trans_info.append(transport)

        return trans_info