"""
POI Analyzer - Responsible for handling POI-related logic
"""

import sys
import os
# Add project root directory to system path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from math import atan2, cos, radians, sin, sqrt
import math
import requests
from typing import Dict, List, Any, Tuple
from utils.time_utils import parse_poi_spendtime, parse_poi_opentime, TimeUtils
import json
from datetime import datetime, timedelta
import pandas as pd


class POIAnalyzer:
    """POI Analyzer, responsible for handling POI-related logic"""
    
    def __init__(self, use_api=False):
        """
        Initialize the POI analyzer and load the local POI / hotel data files

        Args:
            use_api: When True, POI and hotel lookups are sent to the HTTP
                endpoints configured below instead of the loaded JSON data
        """
        # API URLs (empty by default)
        self.poi_detail_url = ""
        self.hotel_info_url = ''
        self.use_api = use_api

        # Candidate pools, filled later by load_pool_from_dict /
        # load_transportation_from_dict
        self.transportation_pool = []
        self.poi_pool = []
        self.transportation_info = []
        self.hotel_pool = []
        # The name pools are created here as well so that reading them before
        # load_pool_from_dict has run (or after it failed part-way) yields an
        # empty list instead of raising AttributeError.
        self.hotel_name_pool = []
        self.poi_name_pool = []

        # Data files live in <project root>/data_base, where the project root
        # is the parent of the directory containing this file.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

        poi_file = os.path.join(project_root, 'data_base', 'poi_data_cleaned.json')
        hotel_file = os.path.join(project_root, 'data_base', 'hotel_data_cleaned.json')

        with open(poi_file, 'r', encoding='utf-8') as f:
            self.all_poi_info = json.load(f)
        with open(hotel_file, 'r', encoding='utf-8') as f:
            self.all_hotel_info = json.load(f)

    def read_poi_api_info(self, poiid: int, locale: str) -> tuple[dict[str, str | list | Any], Any, Any]:
        """
        Read POI API information

        Args:
            poiid: POI ID
            locale: Language setting

        Returns:
            (POI information, minimum visit time)
        """
        if self.use_api:
            request_body = {
                "head": {
                    "clientVersion": 862,
                    "platform": "32",
                    "source": "im.travel.assistant",
                    "locale": locale
                },
                "poiIdList": [poiid]
            }

            response = requests.post(self.poi_detail_url, json=request_body,
                                   headers={'SOA20-Client-AppId': '100021198'})
            poi_info = response.json()
        else:
            poi_info = self.all_poi_info.get(f"{str(poiid)}_{locale}", {})

        node_dict = {
            0: "otherPoiDetail", 2: "restaurantDetail", 3: "sightDetail", 5: "shopDetail",
            7: "airportDetail", 8: "portDetail", 9: "trainStationDetail", 10: "bigBusStationDetail",
            66: "sightPlayDetail", 70: "terminalDetail", 99: "activitiesDetail"
        }

        nodename_dict = {
            0: "PoiName", 2: "restaurantName", 3: "sightName", 5: "shopName",
            7: "airportName", 8: "portName", 9: "trainStationName", 10: "bigBusStationName",
            66: "sightPlayName", 70: "terminalName", 99: "activitiesName"
        }

        nodeEname_dict = {
            0: "PoiEName", 2: "restaurantEName", 3: "sightEName", 5: "shopEName",
            7: "airportEName", 8: "portEName", 9: "trainStationEName", 10: "bigBusStationEName",
            66: "sightPlayEName", 70: "terminalEName", 99: "activitiesEName"
        }
        if 'poiDetailList' not in poi_info or not poi_info['poiDetailList']:
            print("Failed to get POI information:", poiid)
            return {}, 0, {}
        poi_type = poi_info['poiDetailList'][0]['poiType']
        node = node_dict[poi_type]
        node_name = nodename_dict[poi_type]
        node_ename = nodeEname_dict[poi_type]

        try:
            cname = poi_info['poiDetailList'][0][node]['basicInfo'][node_name]
            ename = poi_info['poiDetailList'][0][node]['basicInfo'][node_ename]

        except:
            cname, ename = "", ""

        # POI tag
        try:
            min_spend_time, max_spend_time = parse_poi_spendtime(
                poi_info['poiDetailList'][0][node]['textInfo']['playSpendTime']
            )
            tag_list = [taginfo['tagName'] for taginfo in poi_info['poiDetailList'][0][node].get('tagInfoList', [])]

            poi_api_info = {
                'id': 'poi_{}'.format(str(poiid)),
                'lat': poi_info['poiDetailList'][0][node].get('basicInfo',{}).get('coordinate',{}).get('latitude', ""),
                'lon': poi_info['poiDetailList'][0][node].get('basicInfo',{}).get('coordinate',{}).get('longitude', ""),
                'cname': cname,
                'ename': ename,
                'districtName': poi_info['poiDetailList'][0][node].get('basicInfo',{}).get('districtName', ""),
                'districtEName': poi_info['poiDetailList'][0][node].get('basicInfo',{}).get('districtEName', ""),
                'playSpendTime': poi_info['poiDetailList'][0][node].get('textInfo',{}).get('playSpendTime', ""),
                "districtCoordinate": poi_info['poiDetailList'][0][node].get('basicInfo',{}).get('districtCoordinate', ""),
                "tag": tag_list,
                "commentScore": poi_info['poiDetailList'][0][node].get('basicInfo',{}).get('commentScore', ""),
                "address": poi_info['poiDetailList'][0][node].get('basicInfo',{}).get('address', ""),
                "introduction": poi_info['poiDetailList'][0][node].get('textInfo',{}).get('introductionText',""),
                "appointmentInfo": poi_info['poiDetailList'][0][node].get('appointmentInfo', []),
                "recommendReason": poi_info['poiDetailList'][0][node].get("rankingInfo", {}).get("recommendReason", ""),
                "openTimeInfo": poi_info['poiDetailList'][0][node].get("openTimeInfo", {}).get("tripOpenTimeRuleInfoList", []),
                "distanceInfo": poi_info['poiDetailList'][0][node].get("distanceInfo", {}.get("distanceInfo", "")),
                "price": poi_info['poiDetailList'][0][node].get("priceInfo", {}).get("price", 0)
            }
        except:
            tag_list = []
            poi_api_info = {'id': 'poi_{}'.format(str(poiid))}
            min_spend_time = pd.Timedelta(hours=1)

        return poi_api_info, min_spend_time, poi_info

    def check_poi_id_name_match(self, poi_api_info: Dict, item_name: str, locale: str) -> Tuple[bool, str, str]:
        """
        Check if POI ID matches the name

        Args:
            poi_api_info: POI API information
            item_name: Item name
            locale: Language setting

        Returns:
            (whether it doesn't match, name in API, name in itinerary)
        """
        try:
            if locale.split('-')[0] == 'zh':
                api_name = str(poi_api_info['cname'])
            else:
                api_name = str(poi_api_info['ename'])

            itinerary_name = str(item_name)
            return api_name != itinerary_name, api_name, itinerary_name
        except Exception as e:
            print("error in check_poi_id_name_match", str(e))
            return False, "", str(item_name)

    def get_hotel_info(self, hotel_id: str, locale: str='zh-CN') -> Dict:
        """
        Get hotel information

        Args:
            hotel_id: Hotel ID

        Returns:
            Hotel information
        """

        if self.use_api:
            search_nodes = {
                "mapHead": {"mapMakerType": "Mapbox"},
                "searchNodes": [{"type": "hotel", "id": str(hotel_id)}]
            }
            headers = {'Content-Type': 'application/json', 'user-agent': 'route-cal-service'}
            response = requests.post(self.hotel_info_url, headers=headers, json=search_nodes)
            hotel_info = response.json()
        else:
            hotel_info = self.all_hotel_info.get(f"{str(hotel_id)}", {})
        if not hotel_info or not hotel_info.get('hotels', []):
            print("hotel_info not found or empty", hotel_id)
            return {}
        hotel_api_info = {
            "id": hotel_id,
            "cname": hotel_info['hotels'][0]['summaryInfo']['hotelBasicInfo']['hotelName'],
            "ename": hotel_info['hotels'][0]['summaryInfo']['hotelBasicInfo']['hotelEnName'],
            "star": hotel_info['hotels'][0]['summaryInfo']['hotelStarInfo']['star'],
            "cityId": hotel_info['hotels'][0]['summaryInfo']['hotelPositionInfo']['cityId'],
            "cityName": hotel_info['hotels'][0]['summaryInfo']['hotelPositionInfo']['cityName'] if locale == 'zh-CN' else hotel_info['hotels'][0]['summaryInfo']['hotelPositionInfo']['cityNameEn'],
            "price": hotel_info['hotels'][0].get('priceInfo',{}).get('minPriceRoomInfo',{}).get('displayPrice',200),
            "currency": hotel_info['hotels'][0].get('priceInfo',{}).get('minPriceRoomInfo',{}).get('currency','CNY'),
            "commentInfo": hotel_info['hotels'][0]['summaryInfo']['commentInfo'],
            "tagList": hotel_info['hotels'][0].get('main_tags',[]),
            "lat": hotel_info['hotels'][0]['summaryInfo']['hotelPositionInfo']['hotelCoordinateInfo'][-1]['latitude'],
            "lon": hotel_info['hotels'][0]['summaryInfo']['hotelPositionInfo']['hotelCoordinateInfo'][-1]['longitude'],
            'zoneInfo': hotel_info['hotels'][0]['summaryInfo']['hotelPositionInfo']['zoneInfos'],
            'positionShowText': hotel_info['hotels'][0]['summaryInfo'].get('trafficInfo',{}).get('positionShowText', ''),
        }

        return hotel_api_info

    def check_hotel_id_name_match(self, hotel_info: Dict, item_name: str, locale: str) -> Tuple[bool, str, str]:
        """
        Check if hotel ID matches the name

        Args:
            hotel_info: Hotel information
            item_name: Item name
            locale: Language setting

        Returns:
            (whether it doesn't match, name in API, name in itinerary)
        """
        try:
            cname = hotel_info['name']
            ename = hotel_info['ename']
        except:
            return False, "", str(item_name)

        if locale.split('-')[0] == 'zh':
            api_name = str(cname)
        else:
            api_name = str(ename)

        itinerary_name = str(item_name)
        return api_name != itinerary_name, api_name, itinerary_name

    def check_transportation_id_period_match(self, transportation_id: str, period:str, locale: str) -> Tuple[bool, str, str]:
        """
        Check if Transportation ID matches the time period

        Args:
            transportation_id: Transportation ID
            period: Time in itinerary
            locale: Language setting

        Returns:
            (whether it doesn't match, time in API, time in itinerary)
        """
        try:
            for transportaion in self.transportation_info:
                if transportaion["planid"] == transportation_id:
                    depature_time = transportaion["depature_time"]
                    api_periods = ["Morning", "Afternoon", "Evening", "上午", "下午", "晚上"]
                    if depature_time != "":
                        _, api_periods = TimeUtils.get_time_of_day(depature_time)

                    if period not in api_periods:
                        return True, ", ".join(api_periods), period
            return False, "", period
        except Exception as e:
            print("match transportation  fail", str(e))
            return False, "", period

    def parse_poi_time_window(self, poi_info: Dict) -> str:
        """
        Parse POI opening time window

        Args:
            poi_info: POI information

        Returns:
            Opening time window string
        """
        node_dict = {
            0: "otherPoiDetail", 2: "restaurantDetail", 3: "sightDetail", 5: "shopDetail",
            7: "airportDetail", 8: "portDetail", 9: "trainStationDetail", 10: "bigBusStationDetail",
            66: "sightPlayDetail", 70: "terminalDetail", 99: "activitiesDetail"
        }

        poi_type = poi_info['poiDetailList'][0]['poiType']
        node = node_dict[poi_type]
        sight_detail = poi_info['poiDetailList'][0][node]

        if sight_detail.get('openTimeInfo', {}).get('tripOpenTimeRuleInfoList', None) is not None:
            poi_time_window = parse_poi_opentime(sight_detail['openTimeInfo']['tripOpenTimeRuleInfoList'])
        else:
            poi_time_window = "00:00-23:59"

        return poi_time_window
        
    def check_period_poiwindow(self, period, poi_time_window):
        incorrect_flag = True
        for checktime_range in poi_time_window.split(','):
            checktime_range = [cr for cr in checktime_range.split('-')]
            statime = datetime.strptime(checktime_range[0], "%H:%M")
            endtime = datetime.strptime(checktime_range[1], "%H:%M")
            if endtime<statime:
                endtime = endtime + timedelta(days=1)
            if period.lower() == 'morning' or period == "上午":
                if statime < datetime(statime.year, statime.month, statime.day, 12, 00):
                    incorrect_flag = False
            elif period.lower() == 'afternoon' or period == "下午":
                if endtime > datetime(statime.year, statime.month, statime.day, 12, 00) and statime < datetime(statime.year, statime.month, statime.day, 18, 00):
                    incorrect_flag = False
            elif period.lower() == 'evening' or period == "晚上":
                if endtime > datetime(statime.year, statime.month, statime.day, 18, 00):
                    incorrect_flag = False
        return incorrect_flag


    def calculate_poi_hours(self, poi_api_info: Dict) -> Tuple[float, float]:
        """
        Calculate POI playing time

        Args:
            poi_api_info: POI API information

        Returns:
            (Minimum playing time, Maximum playing time)    
        """
        try:
            min_spend_time, max_spend_time = parse_poi_spendtime(poi_api_info['playSpendTime'])

            if min_spend_time == "error":
                return 0.0, 0.0

            min_hours = float(min_spend_time.total_seconds() / 3600)
            max_hours = float(max_spend_time.total_seconds() / 3600)

            return min_hours, max_hours
        except Exception as e:
            print("error in calculate_poi_hours", str(e))
            return 1.0, 1.0

    def load_pool_from_dict(self, poi_dict):
        try:
            if isinstance(poi_dict["hotel_pool"], str):
                cur_hotel = json.loads(poi_dict["hotel_pool"])
                self.hotel_pool = [str(a["id"]) for area in cur_hotel["hotel_pool"] for a in area["hotels"]]
                hotel_name = []
                hotel_ename = []
                poi_name = []
                poi_ename = []
                for hotel_id in self.hotel_pool:
                    hotel_info = self.get_hotel_info(hotel_id, poi_dict["locale"])
                    hotel_name.append(hotel_info['cname'] if 'cname' in hotel_info else "未知酒店")
                    hotel_ename.append(hotel_info['ename'] if 'ename' in hotel_info else "未知酒店")
                self.hotel_name_pool = hotel_name + hotel_ename

                cur_poi = json.loads(poi_dict["poi_pool"])
                self.poi_pool = [str(a["id"]) for a in cur_poi["poi_pool"]]
                for poi_id in self.poi_pool:
                    try:
                        poi_id = int(poi_id)
                        poi_info = self.read_poi_api_info(poi_id, poi_dict["locale"])
                        poi_name.append(poi_info[0]['cname'] if 'cname' in poi_info[0] else "未知poi")
                        poi_ename.append(poi_info[0]['ename'] if 'ename' in poi_info[0] else "未知poi")
                    except Exception as e:
                        print("load poi pool info error 1", str(e))
                self.poi_name_pool = poi_name + poi_ename
            elif isinstance(poi_dict["hotel_pool"], list):
                hotel_name = []
                hotel_ename = []
                poi_name = []
                poi_ename = []
                self.hotel_pool = poi_dict["hotel_pool"]
                for hotel_id in self.hotel_pool:
                    hotel_info = self.get_hotel_info(hotel_id, poi_dict["locale"])
                    hotel_name.append(hotel_info['cname'] if 'cname' in hotel_info else "未知酒店")
                    hotel_ename.append(hotel_info['ename'] if 'ename' in hotel_info else "未知酒店")
                self.hotel_name_pool = hotel_name + hotel_ename
                self.poi_pool = poi_dict["poi_pool"]
                for poi_id in self.poi_pool:
                    try:
                        poi_id = int(poi_id)
                        poi_info = self.read_poi_api_info(poi_id, poi_dict["locale"])
                        poi_name.append(poi_info[0]['cname'] if 'cname' in poi_info[0] else "未知poi")
                        poi_ename.append(poi_info[0]['ename'] if 'ename' in poi_info[0] else "未知poi")
                    except Exception as e:
                        print("load poi pool info error 2", str(e))
                self.poi_name_pool = poi_name + poi_ename
        except Exception as e:
            print("load pool info error", str(e))
        self.load_transportation_from_dict(poi_dict)

    def load_transportation_from_dict(self, poi_dict):
        transportation_info = []
        depature = poi_dict["departure"]
        try:
            transportation_json = poi_dict["transport_pool"]
            transportation_json = json.loads(transportation_json)
            for key in transportation_json.keys():
                for transporation_way in transportation_json[key]:
                    try:
                        trans_type = transporation_way.get("type", "transportation_tool")
                        t_type = ""
                        key_list = key.split("->")
                        if len(key_list) == 2 and depature != "":
                            if depature.lower() in key_list[0].lower() and depature.lower() not in key_list[1].lower():
                                t_type = "to"
                            elif depature.lower() in key_list[1].lower() and depature.lower() not in key_list[0].lower():
                                t_type = "back"

                        if trans_type == "transportation_tool":
                            transporation_way["key"] = key
                            segments = [self.__parse_segement(segment) for segment in transporation_way['segments']]
                            cardType = [segment['tripType'] for segment in segments]
                            if len(segments) > 0:
                                depature_time = segments[0]["departureTime"]
                                flight_time = segments[-1]["arrivalTime"]
                                time_format = "%Y-%m-%d %H:%M"
                                depature_time_f = datetime.strptime(depature_time, time_format)
                                flight_time_f = datetime.strptime(flight_time, time_format)
                                for segment in segments:
                                    cur_depature_time_str = segment["departureTime"]
                                    cur_flight_time_str = segment["arrivalTime"]
                                    cur_depature_time = datetime.strptime(cur_depature_time_str, time_format)
                                    cur_flight_time = datetime.strptime(cur_flight_time_str, time_format)
                                    if cur_flight_time > flight_time_f:
                                        flight_time = cur_flight_time_str
                                        flight_time_f = cur_flight_time
                                    if cur_depature_time < depature_time_f:
                                        depature_time = cur_depature_time_str
                                        depature_time_f = cur_depature_time
                                if flight_time_f < depature_time_f:
                                    print("wrong_time_transportation", str(depature_time_f), str(flight_time_f))
                            else:
                                depature_time = ""
                                flight_time = ""
                            transporation_way["key"] = key

                            # else:
                            #     print("not find depature in key", depature, key)
                            transporation_way["t_type"] = t_type
                            transporation_way["depature_time"] = depature_time
                            transporation_way["flight_time"] = flight_time
                            transporation_way['planid'] = transporation_way['tripId4V1Hash']
                            transporation_way["segments"] = segments
                            transporation_way["cardType"] = cardType
                            planid = str(transporation_way['planid'])
                            transportation_info.append(transporation_way)
                        elif trans_type == "resource_pool":
                            if transporation_way['id'] == "":
                                continue
                            transporation_way["key"] = key
                            transporation_way["t_type"] = t_type
                            transporation_way["depature_time"] = transporation_way["segments"][0]["departureTime"]
                            transporation_way["flight_time"] = transporation_way["segments"][-1]["arrivalTime"]
                            transporation_way['planid'] = transporation_way['id']
                            transporation_way["cardType"] = transporation_way["segments"][0]["tripType"]
                            planid = str(transporation_way['planid'])
                            transportation_info.append(transporation_way)

                    except Exception as e:
                        print("parse transportation fail", e, transporation_way)

        except Exception as e:
            print("parse transportation fail", e)
        if len(transportation_info) < 1:
            print("not match transportation, transportation_pool empty")
        self.transportation_info = transportation_info
        self.transportation_pool = [a["planid"] for a in transportation_info]

    def get_transportation_info(self, transportation_id: str) -> Dict:
        """
        Get transportation information

        Args:
            transportation_id: Transportation ID

        Returns:
            Transportation information dictionary, if not found return empty dictionary
        """
        # First try to find by planid
        for transportation in self.transportation_info:
            if transportation.get("planid") == transportation_id:
                return transportation

        # If not found by planid, try to find by index in transportation_pool
        try:
            index = int(transportation_id)
            if 0 <= index < len(self.transportation_info):
                return self.transportation_info[index]
        except (ValueError, TypeError):
            pass

        return {}


    @staticmethod
    def __parse_segement(one_segment):
        if one_segment['tripType'] == 'T':
            return {'tripType': 'train', 'fromStationCode': one_segment['train']['fromCode'],
                    'toStationCode': one_segment['train']['toCode'],
                    'departureTime': one_segment['train']['depAt'], 'arrivalTime': one_segment['train']['arrAt'],
                    'trainNo': one_segment['train']['trainNo']
                    }
        elif one_segment['tripType'] == 'SC':
            return {'tripType': 'driving', 'fromStationCode': one_segment['point']['fromCode'],
                    'toStationCode': one_segment['point']['toCode'],
                    'departureTime': one_segment['point']['depAt'], 'arrivalTime': one_segment['point']['arrAt']}
        elif one_segment['tripType'] == 'B':
            return {'tripType': 'bus', 'fromStationCode': one_segment['bus']['fromCode'],
                    'toStationCode': one_segment['bus'].get('toCode', one_segment['bus'].get('toStation', '')),
                    'departureTime': one_segment['bus']['depEndAt'], 'arrivalTime': one_segment['bus']['arrAt']}
        elif one_segment['tripType'] == 'F':
            return {'tripType': 'flight', 'fromStationCode': one_segment['flight']['fromCode'],
                    'toStationCode': one_segment['flight']['toCode'],
                    'departureTime': one_segment['flight']['depAt'], 'arrivalTime': one_segment['flight']['arrAt'],
                    'flightNo': one_segment['flight']['flightNo']
                    }
        elif one_segment['tripType'] == 'S':
            return {'tripType': 'ship', 'fromStationCode': one_segment['ship']['fromCode'],
                    'toStationCode': one_segment['ship']['toCode'],
                    'departureTime': one_segment['ship']['depAt'], 'arrivalTime': one_segment['ship']['arrAt'],
                    'shipName': one_segment['ship']['shipName']
                    }
        else:
            print('Unrecognized tripType %s' % one_segment['tripType'])
            return one_segment


    def calculate_distance(self, lat1, lon1, lat2, lon2):
        # Calculate distance between two points
        # Use Haversine formula to calculate distance between two points
        R = 6371  # Earth radius in kilometers
        lat1, lon1, lat2, lon2 = map(radians, [float(lat1), float(lon1), float(lat2), float(lon2)])
        dlat = lat2 - lat1
        dlon = lon2 - lon1
        a = sin(dlat/2)**2 + cos(lat1) * cos(lat2) * sin(dlon/2)**2
        c = 2 * atan2(sqrt(a), sqrt(1-a))
        distance = R * c
        return distance

    # Choose appropriate urban transportation based on distance (km), then calculate travel time
    def calculate_transportation_time(self, distance: float,type=None) -> float:
        """
        Choose appropriate urban transportation based on distance, then calculate travel time
        """
        speed_mapping = {
            "metro": 35,
            "taxi": 60,
            "bus": 35,
            "train": 35,
            "ship": 35,
            "drive": 60,
            "walking": 5,
            "walk": 5
        }

        if type is None:
            speed = 0
            # Take walking
            if distance < 2:
                speed = 5  # 5km per hour
            # Take subway
            elif distance < 20:
                speed = 35  # 35km per hour
            # Take driving
            else:
                speed = 60  # 60km per hour
            need_time = round(distance / speed, 1)
        else:
            speed = speed_mapping.get(type, 35)
            need_time = round(distance / speed, 1)

        return need_time

    def calculate_transportation_time_by_lat_lon(self, lat1, lon1, lat2, lon2,type=None) -> float:
        """
        Calculate travel time based on latitude and longitude
        """
        distance = self.calculate_distance(lat1, lon1, lat2, lon2)
        return self.calculate_transportation_time(distance,type)




if __name__ == "__main__":
    # Demo: estimate the metro travel time between two coordinates and add it
    # to a start time.
    analyzer = POIAnalyzer()
    origin_lat, origin_lon = 4.4721201, 101.3801441
    target_lat, target_lon = 1.2800945, 103.8509491
    travel_time = analyzer.calculate_transportation_time_by_lat_lon(
        origin_lat, origin_lon, target_lat, target_lon, "metro"
    )
    start_time = "15:06"
    end_time = TimeUtils.time_operation_add(start_time, travel_time)
    print(start_time, end_time)
