"""
Itinerary Parser - Used to parse and verify itinerary data
"""

import re
import json
from typing import Dict, Any, Tuple, Optional, List
from datetime import datetime
from dataclasses import dataclass
from enum import Enum
from datetime import datetime


from .entity_utils import EntityUtils
from .transportation_parser import TransportationParser
from .poi_analyzer import POIAnalyzer

class ItineraryParser:
    """Itinerary parser (merged version).

    Contains the original transportation parsing together with the newer
    itinerary parsing methods used for API evaluation, plus validation of
    model-generated itineraries against the POI / hotel / transportation
    pools exposed by a ``POIAnalyzer``.
    """

    # Constant Definition
    DEFAULT_LOCALE = "zh-cn"
    DATETIME_FORMAT = '%Y-%m-%d %H:%M'
    # Markdown-style entity links inside descriptions, e.g. "[Name](poi_123)".
    _ENTITY_LINK_PATTERN = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
    # Reference-text phrases indicating that no transportation needs to be booked.
    _SELF_TRANSPORT_KEYWORDS = ("on their own, so", "一天自行", "No transportation arrangements needed")

    def __init__(self):
        """Initialize the parser with its own ``TransportationParser``."""
        self.transportation_parser = TransportationParser()

    @staticmethod
    def __parse_transportation(trans: Dict[str, Any], poi_dict) -> Dict[str, Any]:
        """
        Parse a raw transportation plan in place - original method.

        Args:
            trans: Transportation dictionary. It must contain 'tripId4V1Hash'
                and a non-empty 'segments' list (IndexError otherwise). It is
                mutated and returned.
            poi_dict: Query dictionary. Its 'departure' entry is compared
                case-insensitively with the plan's 'from'/'to' names.
        Returns:
            Dict[str, Any]: ``trans`` with the following added:
            'planid', 'flight_time' (latest segment arrival),
            'depature_time' (earliest segment departure), the parsed
            'segments', and 't_type' ("to", "back" or "").
        """
        trans['planid'] = trans['tripId4V1Hash']
        segments = [TransportationParser.parse_segment(segment) for segment in trans['segments']]

        # Earliest departure / latest arrival over all segments. Values are
        # compared as returned by parse_segment (presumably
        # '%Y-%m-%d %H:%M' strings, which sort chronologically) — confirm.
        dep_time = segments[0]["departureTime"]
        flight_time = segments[-1]["arrivalTime"]
        for segment in segments:
            dep_time = min(dep_time, segment['departureTime'])
            flight_time = max(flight_time, segment['arrivalTime'])

        # Direction relative to the trip's departure city:
        # leaving it is "to", returning to it is "back".
        t_type = ""
        if trans.get("from"):
            departure = poi_dict['departure'].lower()
            if trans['from']['name'].lower() == departure:
                t_type = "to"
            elif trans['to']['name'].lower() == departure:
                t_type = "back"

        trans['flight_time'] = flight_time
        # NOTE: the misspelled key 'depature_time' is kept because callers may read it.
        trans['depature_time'] = dep_time
        trans["segments"] = segments
        trans['t_type'] = t_type
        return trans

    def parse_itinerary(self, full_json: Dict, transportation_dict: Dict) -> List[Dict]:
        """
        Parse itinerary data - new method, used for API evaluation.

        Args:
            full_json: Full itinerary data with a 'dayInfos' list. Each day has
                'day', 'scheduleTitle' and 'scheduleDetail' (a list of
                {'period', 'detailList'}).
            transportation_dict: Transportation plans keyed by str(plan id).

        Returns:
            One {'day': str, 'itemList': [...]} entry per day. Items are
            produced by ``_parse_item``. Null and unparseable items are dropped.
        """
        schedules = []

        for day_item in full_json['dayInfos']:
            item_list = []
            # POI ids already emitted for this day; repeats within a day are dropped.
            used_poiids = set()
            day_title = day_item['scheduleTitle']

            for schedule in day_item["scheduleDetail"]:
                period = schedule["period"]
                for item in schedule["detailList"]:
                    if item is None:
                        continue

                    parsed_item = self._parse_item(item, period, day_title, transportation_dict, used_poiids)
                    if parsed_item:
                        item_list.append(parsed_item)
                        if item['type'] == 'poi':
                            used_poiids.add(item['id'])

            schedules.append({'day': str(day_item['day']), 'itemList': item_list})

        return schedules

    def _parse_item(self, item: Dict, period: str, day_title: str,
                    transportation_dict: Dict, used_poiids: set) -> Optional[Dict]:
        """
        Parse a single itinerary item.

        Args:
            item: Itinerary item data. 'type' is one of 'poi', 'lodgingArea',
                'hotel' or 'transportation'.
            period: Period of the schedule the item belongs to.
            day_title: Title of the day (forwarded for transportation items).
            transportation_dict: Transportation plans keyed by str(plan id).
            used_poiids: POI ids already emitted for the current day.

        Returns:
            The parsed item. Returns None for an unknown type, a POI with an
            empty or repeated id, a lodging area without hotels, or an
            unresolvable transportation item.
        """
        item_type = item['type']

        if item_type == 'poi':
            if item['id'] in used_poiids or item['id'] == "":
                return None
            return {
                'type': 'attraction',
                'id': item['id'],
                'name': item['name'],
                'period': period
            }

        if item_type == 'lodgingArea':
            # A lodging area is represented by its first candidate hotel. Without
            # any candidate there is nothing to represent, so the item is skipped.
            hotels = item.get('hotelList') or []
            if not hotels:
                return None
            return {
                'type': 'hotel',
                'id': hotels[0]['id'],
                'name': hotels[0]['name'],
                'period': period
            }

        if item_type == 'hotel':
            return {
                'type': 'hotel',
                'id': item['id'],
                'name': item['name'],
                'period': period
            }

        if item_type == 'transportation':
            return self._parse_transportation_item(item, period, day_title, transportation_dict)

        return None

    def _parse_transportation_item(self, item: Dict, period: str, day_title: str,
                                   transportation_dict: Dict) -> Optional[Dict]:
        """
        Parse a transportation item by looking up its plan.

        Args:
            item: Transportation item data ('id', 'name').
            period: Period of the enclosing schedule.
            day_title: Title of the day.
            transportation_dict: Transportation plans keyed by str(plan id).
                Each plan has 'segments' and 'minutes'.

        Returns:
            The parsed item. 'start_time' is the departure of the first
            segment and 'end_time' is the arrival of the last segment, both
            normalized to DATETIME_FORMAT. Returns None when the plan is
            unknown or has no segments.
        """
        transport_info = transportation_dict.get(str(item['id']))
        if transport_info is None:
            return None

        segments = transport_info.get('segments') or []
        if not segments:
            return None

        first_seg = TransportationParser.parse_segment(segments[0])
        last_seg = TransportationParser.parse_segment(segments[-1])
        minutes = transport_info['minutes']

        # strptime/strftime round-trip both validates the timestamp and
        # normalizes it (e.g. zero-pads single-digit fields).
        fmt = self.DATETIME_FORMAT
        return {
            'type': 'transportation',
            'id': item['id'],
            'name': item['name'],
            'start_time': datetime.strptime(first_seg['departureTime'], fmt).strftime(fmt),
            'end_time': datetime.strptime(last_seg['arrivalTime'], fmt).strftime(fmt),
            'period': period,
            'minutes': minutes,
            'day_title': day_title,
            'locale': self.DEFAULT_LOCALE
        }

    @staticmethod
    def extract_solution(solution_str: str) -> Tuple[str, Optional[str], str]:
        """Extract the final answer from the model's response string.

        Handles multi-turn conversations by splitting on the last assistant
        marker.

        Args:
            solution_str: Raw response string from the language model.

        Returns:
            Tuple of (prompt_str, final_answer, processed_string).
            - If no assistant marker is found, returns
              (solution_str, None, solution_str).
            - If no <answer> tags are found, final_answer is the whole
              processed response.
            - Otherwise final_answer is the last <answer> body, stripped and
              with <think> blocks removed.
        """
        # Split off the final assistant output. Markers are checked from most
        # specific to most generic; the order is important.
        for marker in ("<|im_start|>assistant", "<｜Assistant｜>", "Assistant:"):
            if marker in solution_str:
                prompt_str, processed_str = solution_str.rsplit(marker, 1)
                break
        else:
            print("[Error] Failed to locate model response header")
            return solution_str, None, solution_str

        processed_str = processed_str.replace("<|im_end|>", "")

        # Extract the final answer using XML-style tags; the last one wins.
        answers = re.findall(r'<answer>(.*?)</answer>', processed_str, re.DOTALL)
        if not answers:
            return prompt_str, processed_str, processed_str

        final_answer = answers[-1].strip()
        final_answer = re.sub(r'<think>.*?</think>', '', final_answer, flags=re.DOTALL)
        return prompt_str, final_answer, processed_str

    @staticmethod
    def validate_response_structure(processed_str: str, poi_dict: dict, poi_analyzer: POIAnalyzer) -> Tuple[bool, Dict, Dict]:
        """
        Validate response structure.

        Args:
            processed_str: Processed response string.
            poi_dict: Query/POI dictionary (expected days, locale, ...).
            poi_analyzer: POI analyzer providing the id/name pools.

        Returns:
            Tuple[bool, Dict, Dict]: (validation passed, validation result,
            parsed itinerary JSON — empty on failure).
        """
        # Remove think tags before looking for the JSON payload.
        processed_str = re.sub(r'<think>.*?</think>', '', processed_str, flags=re.DOTALL)

        validation_result, itinerary = ItineraryParser.extract_itinerary(processed_str, poi_dict, poi_analyzer)
        validation_passed = "fail" not in validation_result
        return validation_passed, validation_result, itinerary

    @staticmethod
    def extract_itinerary(answer: str, poi_dict: Dict[str, Any], poi_analyzer: POIAnalyzer) -> Tuple[Dict, Dict]:
        """
        Extract and validate itinerary information from an answer.

        The JSON is taken from the first ```json fenced block. If there is no
        such block, the answer with all fences removed is used.

        Args:
            answer: Answer string.
            poi_dict: POI dictionary. 'day' is the expected number of days
                (0 or missing disables the check).
            poi_analyzer: POI analyzer.

        Returns:
            Tuple[Dict, Dict]: (validation result, itinerary JSON).
            - On success: ({}, parsed JSON), with every dayInfos[i]['day']
              converted to int.
            - On failure: ({"fail": {...}}, {}).
        """
        json_match = re.search(r'```json\s*(.*?)\s*```', answer, re.DOTALL)
        if json_match:
            answer_processed = json_match.group(1).strip()
        else:
            # No ```json``` block: fall back to stripping all fences.
            answer_processed = answer.replace("```json", "").replace("```", "").strip()

        try:
            content_json = json.loads(answer_processed)

            # A top-level JSON string or array would otherwise pass the `in`
            # test (substring/element) and fail later with a misleading error.
            if not isinstance(content_json, dict) or "dayInfos" not in content_json:
                return {"fail": {"Response Format": "no dayInfos in json"}}, {}

            # Validate the number of days
            day_infos = content_json["dayInfos"]
            expected_days = poi_dict.get("day", 0)
            expected_day_count = int(expected_days)
            if expected_day_count != 0 and len(day_infos) != expected_day_count:
                return {"fail": {"Response Format": f"day number wrong {len(day_infos)} | {expected_days}"}}, {}

            try:
                for dayinfo in day_infos:
                    dayinfo["day"] = int(dayinfo["day"])
            except Exception as e:
                return {"fail": {"Response Format": f"day in dayInfos cant be load in int format {e}"}}, {}

            # Validate itinerary content
            validation_result = ItineraryParser._validate_itinerary_content(day_infos, poi_dict, poi_analyzer)
            if validation_result:
                return validation_result, {}

            return {}, content_json

        except json.JSONDecodeError as e:
            print("answer_processed", answer_processed)
            return {"fail": {"Response Format": f"json load fail {e}"}}, {}
        except Exception as e:
            return {"fail": {"Response Format": f"unexpected error {e}"}}, {}

    @staticmethod
    def _validate_itinerary_content(day_infos: List[Dict[str, Any]], poi_dict: Dict[str, Any],
                                    poi_analyzer: POIAnalyzer) -> Optional[Dict[str, Any]]:
        """Validate itinerary content.

        Days must be numbered 1..N in order. Each day needs a non-empty
        'scheduleDetail', and each schedule must pass ``_validate_schedule``.

        Returns:
            The first {"fail": {...}} found, else None. Unexpected exceptions
            are reported as a "Response Format" failure.
        """
        try:
            # NOTE(review): this default differs in case from DEFAULT_LOCALE
            # ("zh-cn"); presumably the analyzer normalizes locales — confirm.
            locale = poi_dict.get("locale", "zh-CN")

            for day_index, day_info in enumerate(day_infos, 1):
                expected_day = day_info.get("day")
                if str(day_index) != str(expected_day):
                    return {"fail": {"Response Format": f"day_index number wrong {expected_day} | {day_index}"}}

                schedule_detail = day_info.get("scheduleDetail", [])
                if not schedule_detail:
                    return {"fail": {"Response Format": "not provide scheduleDetail"}}

                for schedule in schedule_detail:
                    validation_result = ItineraryParser._validate_schedule(schedule, poi_analyzer, locale)
                    if validation_result:
                        return validation_result

            return None

        except Exception as e:
            return {"fail": {"Response Format": str(e)}}

    @staticmethod
    def _distinct_destinations(arrive: str) -> List[str]:
        """Split a comma-separated destination string into distinct names.

        Entries are stripped and blank entries are ignored. An entry is a
        duplicate when it is a case-insensitive substring of an already-kept
        entry, or vice versa (e.g. "Paris" and "Paris City"). The first
        occurrence is kept.
        """
        distinct: List[str] = []
        for raw in arrive.split(","):
            name = raw.strip()
            # Blank entries must be skipped: "" is a substring of every name
            # and would otherwise swallow all following destinations.
            if not name:
                continue
            lowered = name.lower()
            if any(lowered in kept.lower() or kept.lower() in lowered for kept in distinct):
                continue
            distinct.append(name)
        return distinct

    @staticmethod
    def _transportation_exists(poi_analyzer: POIAnalyzer, plan_id: Any) -> bool:
        """Return True if any entry of ``poi_analyzer.transportation_info`` has 'planid' == plan_id."""
        return any(transport.get("planid") == plan_id for transport in poi_analyzer.transportation_info)

    @staticmethod
    def _validate_transportation_completeness(day_infos: List[Dict[str, Any]], poi_dict: Dict[str, Any],
                                              poi_analyzer: POIAnalyzer) -> Optional[Dict[str, Any]]:
        """Validate transportation completeness.

        Fails when all of the following hold:
        - transportation was requested (poi_dict['transportation'] == "Yes");
        - the analyzer offers at least one plan;
        - fewer distinct transportation ids were used than both
          ``2 + (distinct destinations - 1)`` and the number of offered plans;
        - the reference text does not indicate self-arranged transport.

        The check is skipped (returns None) when poi_dict['arrive'] is not a
        string. Exceptions are reported as an "Information Completeness Check
        Error" failure.
        """
        try:
            traffic_set = set()
            for day_info in day_infos:
                for schedule in day_info.get("scheduleDetail", []):
                    for detail in schedule.get("detailList", []):
                        if detail and detail.get("type") == "transportation":
                            detail_id = detail.get("id")
                            if detail_id:
                                traffic_set.add(detail_id)

            arrive = poi_dict.get("arrive", "")
            if not isinstance(arrive, str):
                return None

            arrive_num = max(len(ItineraryParser._distinct_destinations(arrive)) - 1, 0)

            transportation_pool = poi_analyzer.transportation_info
            reference_text = str(poi_dict.get("reference", ""))
            is_self_transport = any(
                keyword in reference_text for keyword in ItineraryParser._SELF_TRANSPORT_KEYWORDS
            )

            # Presumably: outbound + return trip, plus one hop per additional
            # destination — confirm against the task spec.
            required_traffic_num = 2 + arrive_num

            if (poi_dict.get("transportation") == "Yes" and
                    len(transportation_pool) > 0 and
                    len(traffic_set) < required_traffic_num and
                    len(traffic_set) < len(transportation_pool) and
                    not is_self_transport):

                fail_message = (
                    f"Insufficient transportation arrangements. Found {len(traffic_set)} transportation items, "
                    f"but {required_traffic_num} are required. Available transportation options: {len(transportation_pool)}. "
                    f"Transportation items used: {traffic_set}. Reference context: {reference_text[:900]}"
                )
                return {"fail": {"Information Completeness": fail_message}}

            return None
        except Exception as e:
            return {"fail": {"Information Completeness Check Error": str(e)}}

    @staticmethod
    def _validate_hotel_completeness(day_infos: List[Dict[str, Any]], poi_dict: Dict[str, Any], poi_analyzer: POIAnalyzer) -> Optional[Dict[str, Any]]:
        """Validate hotel completeness, using transportation days to decide which days need a hotel.

        Steps:
        - Collect the known hotel ids used (present in poi_analyzer.hotel_pool).
        - Collect the days carrying a known transportation plan.
        - An intermediate day (2 .. N-1) with exactly one distinct known plan
          counts as needing a hotel.

        Fails when the trip has more than one day and the number of distinct
        hotels is below both (needed days + 1) and
        (distinct destinations - 1). Best effort: any exception is printed
        and None is returned.
        """
        try:
            hotel_set = set()
            real_traffic_day = []  # (day_index, plan_id) for every known transportation item
            actual_day = len(day_infos)

            # Collect hotel and transportation information
            for day_index, day_info in enumerate(day_infos, 1):
                for schedule in day_info.get("scheduleDetail", []):
                    for detail in schedule.get("detailList", []):
                        if not detail:
                            continue

                        detail_type = detail.get("type", "")
                        detail_id = detail.get("id", "")
                        if not detail_id:
                            continue

                        if detail_type == "hotel":
                            if detail_id in poi_analyzer.hotel_pool:
                                hotel_set.add(detail_id)
                        elif detail_type == "transportation":
                            if ItineraryParser._transportation_exists(poi_analyzer, detail_id):
                                real_traffic_day.append((day_index, detail_id))

            # Intermediate days (2 .. actual_day-1) with exactly one distinct plan need a hotel.
            need_hotel_day = []
            for day in range(2, actual_day):
                plans_that_day = {plan_id for day_idx, plan_id in real_traffic_day if day_idx == day}
                if len(plans_that_day) == 1:
                    need_hotel_day.append(day)

            # Calculate the number of destinations
            arrive = poi_dict.get("arrive", "")
            if not isinstance(arrive, str):
                arrive = ""
            final_arrives = ItineraryParser._distinct_destinations(arrive)
            arrive_num = max(len(final_arrives) - 1, 0)

            # Validate hotel completeness
            if (actual_day > 1 and
                    len(hotel_set) < len(need_hotel_day) + 1 and
                    len(hotel_set) < arrive_num):

                fail_message = (
                    f"Insufficient hotel accommodations. Days requiring hotels: {need_hotel_day}, "
                    f"but only {len(hotel_set)} hotel(s) found. Number of destinations: {arrive_num}. "
                    f"Destinations: {final_arrives}. Hotels provided: {hotel_set}"
                )
                return {"fail": {"Information Completeness": fail_message}}

            return None

        except Exception as e:
            print("Information Completeness Check Error", str(e))
            return None

    @staticmethod
    def _validate_schedule(schedule: Dict[str, Any], poi_analyzer: POIAnalyzer, locale: str) -> Optional[Dict[str, Any]]:
        """Validate a single schedule (one period of a day).

        Every detail needs a 'type' and an 'id'.
        - Non-empty POI/hotel ids must exist in the analyzer pools and match
          the API name.
        - Transportation ids must exist and match the schedule period.
        - An empty id paired with a known pool name is rejected.
        - Finally, entity links in 'description' must refer to listed items.

        Returns:
            The first {"fail": {...}} found, else None. Unexpected exceptions
            (e.g. a non-numeric POI id suffix) are reported as a
            "Response Format" failure.
        """
        try:
            period = schedule.get("period", "")
            detail_list = schedule.get("detailList", None)

            if detail_list is None:
                return {"fail": {"Response Format": "not provide detailList"}}

            poi_ids = set()
            hotel_ids = set()

            for detail in detail_list:
                if not detail or "type" not in detail:
                    return {"fail": {"Response Format": "not provide type"}}
                if "id" not in detail:
                    return {"fail": {"Response Format": "not provide id"}}

                detail_type = detail["type"]
                # NOTE(review): a JSON null id becomes the string "None" here and is
                # then validated like a real id — confirm this is intended.
                detail_id = str(detail.get("id", ""))
                detail_name = str(detail.get("name", ""))

                if detail_type == "poi":
                    if detail_id:
                        if detail_id not in poi_analyzer.poi_pool:
                            return {"fail": {"Information Verification": f"poi fake id: {detail_id}"}}
                        # The API lookup takes the numeric part after the last '_'.
                        poi_id_num = detail_id.split('_')[-1]
                        poi_api_info, _, _ = poi_analyzer.read_poi_api_info(int(poi_id_num), locale)

                        poi_ids.add(detail_id)
                        is_mismatch, api_name, itinerary_name = poi_analyzer.check_poi_id_name_match(poi_api_info, detail_name, locale)
                        if is_mismatch:
                            return {"fail": {"Information Accuracy": f"POI name mismatch: itinerary name is '{itinerary_name}', but API name for ID '{detail_id}' is '{api_name}'."}}
                    elif detail_name and detail_name in poi_analyzer.poi_name_pool:
                        # A known POI name without its id would bypass id verification.
                        return {"fail": {
                            "Information Accuracy": f"POI hacking: itinerary name is '{detail_name}', but API ID is empty."}}

                elif detail_type == "hotel":
                    if detail_id:
                        if detail_id not in poi_analyzer.hotel_pool:
                            return {"fail": {"Information Verification": f"hotelid fake id: {detail_id}"}}

                        hotel_ids.add(detail_id)
                        hotel_info = poi_analyzer.get_hotel_info(detail_id, locale)
                        is_mismatch, api_name, itinerary_name = poi_analyzer.check_hotel_id_name_match(hotel_info, detail_name, locale)
                        if is_mismatch:
                            return {"fail": {"Information Accuracy": f"Hotel name mismatch: itinerary name is '{itinerary_name}', but API name for ID '{detail_id}' is '{api_name}'."}}
                    elif detail_name and detail_name in poi_analyzer.hotel_name_pool:
                        # A known hotel name without its id would bypass id verification.
                        return {"fail": {
                            "Information Accuracy": f"Hotel hacking: itinerary name is '{detail_name}', but API ID is empty."}}

                elif detail_type == "transportation":
                    # NOTE(review): ids shorter than 5 characters or containing 'test'
                    # are not verified at all (presumably placeholders) — confirm.
                    if detail_id and len(detail_id) >= 5 and 'test' not in detail_id:
                        if not ItineraryParser._transportation_exists(poi_analyzer, detail_id):
                            return {"fail": {"Information Verification": f"planid fake id: {detail_id}"}}

                        is_mismatch, api_period, itinerary_period = poi_analyzer.check_transportation_id_period_match(detail_id, period, locale)
                        if is_mismatch:
                            return {"fail": {"Information Accuracy": f"Transportation time period mismatch: itinerary period is '{itinerary_period}', but API period for ID '{detail_id}' should be '{api_period}'"}}

                else:
                    return {"fail": {"Response Format": f"type not in enum: {detail_type}"}}

            # Validate description
            description = schedule.get("description", None)
            if description is None:
                return {"fail": {"Response Format": "not provide description"}}

            # Validate entities in description
            return ItineraryParser._validate_description_entities(description, poi_ids, hotel_ids) or None

        except Exception as e:
            return {"fail": {"Response Format": str(e)}}

    @staticmethod
    def _validate_description_entities(description: str, poi_ids, hotel_ids) -> Optional[Dict[str, Any]]:
        """Validate entity links in a description against the schedule's ids.

        Links have the form "[Name](poi_<id>)" or "[Name](hotel_<id>)". Other
        link targets are ignored. A link fails when the matching id set is
        non-empty and contains neither the bare id nor the prefixed id (the
        schedule may record either form). Best effort: any exception is
        printed and None is returned.
        """
        try:
            for _name, raw_id in ItineraryParser._ENTITY_LINK_PATTERN.findall(description):
                entity_id = raw_id.strip()
                entity_parts = entity_id.split("_")

                if len(entity_parts) != 2 or entity_parts[0] not in ["poi", "hotel"]:
                    continue

                entity_type, final_id = entity_parts
                known_ids = poi_ids if entity_type == "poi" else hotel_ids
                if not known_ids or final_id in known_ids or entity_id in known_ids:
                    continue

                if entity_type == "poi":
                    return {"fail": {"Information Relevance": f"description poi not align poi_ids description: {description} poi_ids: {poi_ids}"}}
                return {"fail": {"Information Relevance": f"description hotel not align hotel_ids description: {description} hotel_ids: {hotel_ids}"}}

            return None
        except Exception:
            import traceback
            traceback.print_exc()
            return None
