from typing import List, Dict, Any, Optional, Tuple
import json
import math
import os
import yaml
from datetime import datetime
from fuzzywuzzy import fuzz
from geopy.distance import geodesic

# Reference city-center points (longitude / latitude in decimal degrees).
# search_attractions() uses them as the distance origin when the caller gives
# no explicit coordinates.
city_center_coords = {
    "Beijing": {"lon": 116.407387, "lat": 39.904179},
    "Changchun": {"lon": 125.323643, "lat": 43.816996},
    "Changsha": {"lon": 112.938882, "lat": 28.228304},
    "Chengdu": {"lon": 104.066301, "lat": 30.572961},
    "Chongqing": {"lon": 106.551787, "lat": 29.56268},
    "Dalian": {"lon": 121.614786, "lat": 38.913962},
    "Fuzhou": {"lon": 119.296411, "lat": 26.074286},
    "Guangzhou": {"lon": 113.264499, "lat": 23.130061},
    "Guilin": {"lon": 110.179752, "lat": 25.235615},
    "Guiyang": {"lon": 106.628201, "lat": 26.646694},
    "Haikou": {"lon": 110.200162, "lat": 20.046316},
    "Hangzhou": {"lon": 120.209903, "lat": 30.246566},
    "Harbin": {"lon": 126.53505, "lat": 45.802981},
    "Hong Kong": {"lon": 114.170714, "lat": 22.278354},
    "Jinan": {"lon": 117.120128, "lat": 36.652069},
    "Kaifeng": {"lon": 114.314278, "lat": 34.798083},
    "Kunming": {"lon": 102.833669, "lat": 24.88149},
    "Lijiang": {"lon": 100.225936, "lat": 26.855165},
    "Luoyang": {"lon": 112.453895, "lat": 34.619702},
    "Nanchang": {"lon": 115.857972, "lat": 28.682976},
    "Nanjing": {"lon": 118.796624, "lat": 32.059344},
    "Nanning": {"lon": 108.366407, "lat": 22.8177},
    "Ningbo": {"lon": 121.62454, "lat": 29.860258},
    "Qingdao": {"lon": 120.382665, "lat": 36.066938},
    "Sanya": {"lon": 109.511709, "lat": 18.252865},
    "Shanghai": {"lon": 121.473667, "lat": 31.230525},
    "Shenyang": {"lon": 123.464675, "lat": 41.677576},
    "Shenzhen": {"lon": 114.057939, "lat": 22.543527},
    "Suzhou": {"lon": 120.585294, "lat": 31.299758},
    "Taiyuan": {"lon": 112.549656, "lat": 37.870451},
    "Tianjin": {"lon": 117.201509, "lat": 39.085318},
    "Weihai": {"lon": 122.120519, "lat": 37.513315},
    "Wuhan": {"lon": 114.304569, "lat": 30.593354},
    "Wuxi": {"lon": 120.311889, "lat": 31.491064},
    "Xi'an": {"lon": 108.939645, "lat": 34.343207},
    "Xiamen": {"lon": 118.08891, "lat": 24.479627},
    "Xishuangbanna": {"lon": 100.797002, "lat": 22.009037},
    "Yantai": {"lon": 121.447755, "lat": 37.464551},
    "Zhengzhou": {"lon": 113.625351, "lat": 34.746303},
    "Zhuhai": {"lon": 113.576892, "lat": 22.271644},
}

# Re-key by lower-cased city name so lookups can use a normalized (lower-case) city.
city_center_coords = {
    city_name.lower(): center for city_name, center in city_center_coords.items()
}

class AttractionTool:
    """Query a local attraction (POI) dataset and render the results as text.

    The dataset path is read from a YAML config (``data_path.attraction`` or
    ``data_path.attractions``) and must point to a JSON list of attraction
    dicts. All public methods return human-readable strings.
    """

    # Minimum fuzzy-match score (0-100) for a keyword hit to be kept.
    FUZZY_MATCH_THRESHOLD = 60

    # String spellings that tool callers use to mean "no value".
    _NULL_STRINGS = ("", "none", "null")

    # Message returned whenever a filter stage leaves no candidates.
    _NO_RESULTS = "No attractions found."

    def __init__(self, config_path: Optional[str] = None):
        """Load the attraction dataset.

        Args:
            config_path: Path to a YAML config file. Defaults to ``config.yaml``
                in the parent directory of this module's directory.

        Raises:
            ValueError: If the config has no attraction data path.
            OSError: If the config or the data file cannot be opened.
        """
        current_dir = os.path.dirname(__file__)
        if config_path is None:
            config_path = os.path.join(os.path.dirname(current_dir), "config.yaml")

        with open(config_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}

        # `or {}` also covers an explicit `data_path: null` in the YAML.
        data_cfg = cfg.get("data_path") or {}
        attraction_path = data_cfg.get("attraction") or data_cfg.get("attractions")
        if not attraction_path:
            raise ValueError("Missing attraction path in config.yaml (data_path.attraction or data_path.attractions)")

        # NOTE: a relative attraction_path is resolved against the current
        # working directory, not against the config file's directory.
        with open(attraction_path, "r", encoding="utf-8") as f:
            attractions_list: List[Dict[str, Any]] = json.load(f)

        self.attractions_list = attractions_list
        # Index by stringified poiId so lookups accept both int and str IDs.
        self.attractions = {str(a.get("poiId")): a for a in attractions_list}

    # ---------- helpers ----------
    def _fuzzy_score(self, a: str, b: str) -> float:
        """Return the case-insensitive fuzzy similarity of ``a`` and ``b`` (0-100).

        Uses ``fuzz.token_set_ratio``; falls back to ``fuzz.ratio`` if that fails.
        """
        a = (a or "").lower()
        b = (b or "").lower()
        try:
            return float(fuzz.token_set_ratio(a, b))
        except Exception:
            return float(fuzz.ratio(a, b))

    def _match_city(self, a: Dict[str, Any], city: str) -> bool:
        """Return True if the attraction's ``city`` equals ``city`` (case-insensitive)."""
        return str(a.get("city", "")).lower() == (city or "").strip().lower()

    def _calculate_distance(self, lat1, lon1, lat2, lon2):
        """Return the geodesic distance in kilometres between two lat/lon points."""
        return geodesic((lat1, lon1), (lat2, lon2)).km

    def _filter_sight_level(self, item, sight_level):
        """Return True if ``item['sightLevelStr']`` satisfies ``sight_level``.

        "4A" and "4A or higher" accept both 4A and 5A; any other value (e.g.
        "5A") must match exactly, ignoring case and surrounding whitespace.
        Items without a level never match.
        """
        raw_level = item.get("sightLevelStr")
        if raw_level is None:
            return False
        level = str(raw_level).strip().upper()
        wanted = str(sight_level).strip().upper()
        if wanted in ("4A", "4A OR HIGHER"):
            return level in ("4A", "5A")
        return level == wanted

    @classmethod
    def _parse_optional(cls, value, cast):
        """Convert a string argument with ``cast``; "", "none" and "null" become None.

        Non-string values are returned unchanged. ``cast`` errors propagate.
        """
        if not isinstance(value, str):
            return value
        if value.strip().lower() in cls._NULL_STRINGS:
            return None
        return cast(value)

    @staticmethod
    def _to_int(value) -> int:
        """Parse an integer that may be written as a float string (e.g. "100.0")."""
        return int(float(value))

    @staticmethod
    def _parse_positive_int(value, default: int) -> int:
        """Parse a 1-based page number/size; return ``default`` if unparseable."""
        try:
            return max(1, int(float(value)))
        except (TypeError, ValueError, OverflowError):
            return default

    @classmethod
    def _parse_categories(cls, categories) -> Optional[List[str]]:
        """Normalize ``categories`` to a list of non-empty stripped strings, or None.

        Accepts a list, a JSON list string (e.g. '["Museum", "Park"]'), or a
        plain string, which is treated as a comma-separated list.
        """
        if isinstance(categories, str):
            text = categories.strip()
            if text.lower() in cls._NULL_STRINGS:
                return None
            try:
                categories = json.loads(text)
            except ValueError:
                # Not JSON: treat as comma-separated category names.
                categories = text.split(",")
            if isinstance(categories, str):
                # A JSON string literal such as '"Cultural"'.
                categories = [categories]
        if not isinstance(categories, list):
            return None
        cleaned = [str(x).strip() for x in categories if str(x).strip()]
        return cleaned or None

    @staticmethod
    def _join_categories(categories) -> str:
        """Join a categories value into "a, b"; None or empty yields ""."""
        return ", ".join(str(c) for c in (categories or []))

    @staticmethod
    def _is_free(price_val) -> bool:
        """Treat a missing price or a numeric price equal to 0 as free."""
        if price_val is None:
            return True
        try:
            return float(price_val) == 0.0
        except (TypeError, ValueError):
            return False

    # ---------- search pipeline stages ----------
    def _apply_keyword_filter(self, items, attraction_name, categories):
        """Keep items whose name (or category/tag text) fuzzily matches the query.

        ``attraction_name`` takes precedence over ``categories``. Kept items are
        shallow copies annotated with ``_kw_score``. If there is no query,
        ``items`` is returned unchanged.
        """
        if attraction_name:
            query = attraction_name.strip().lower()
            by_name = True
        elif categories:
            query = " ".join(categories).strip().lower()
            by_name = False
        else:
            return items
        if not query:
            return items

        kept: List[Dict[str, Any]] = []
        for a in items:
            if by_name:
                target = str(a.get("poiName") or "")
            else:
                cats = self._join_categories(a.get("categories"))
                tags = str(a.get("tagNameList") or "")
                target = f"{cats} {tags}".strip().lower()
            score = self._fuzzy_score(query, target)
            if score >= self.FUZZY_MATCH_THRESHOLD:
                kept.append({**a, "_kw_score": score})
        return kept

    def _apply_distance_filter(self, items, city, latitude, longitude, distance_threshold):
        """Annotate items with ``_distance`` (km) and drop those beyond the threshold.

        The origin is (latitude, longitude) when both are given, otherwise the
        city's entry in ``city_center_coords``. With no origin, ``items`` is
        returned unchanged. Items without coordinates are kept only when
        ``distance_threshold`` is None.
        """
        if latitude is not None and longitude is not None:
            origin_lat, origin_lon = float(latitude), float(longitude)
        elif city and city in city_center_coords:
            center = city_center_coords[city]
            origin_lat, origin_lon = float(center["lat"]), float(center["lon"])
        else:
            return items

        kept: List[Dict[str, Any]] = []
        for a in items:
            lon_val = a.get("longitude")
            lat_val = a.get("latitude")
            if lon_val is None or lat_val is None:
                if distance_threshold is None:
                    kept.append(a)
                continue
            d = self._calculate_distance(origin_lat, origin_lon, float(lat_val), float(lon_val))
            if distance_threshold is not None and d > float(distance_threshold):
                continue
            kept.append({**a, "_distance": d})
        return kept

    def _apply_attribute_filters(self, items, rating, sight_level, comment_count, free_only):
        """Apply the minimum-rating, sight-level, minimum-comment and free-only filters."""
        if rating is not None:
            try:
                threshold = float(rating)
            except (TypeError, ValueError):
                threshold = 0.0
            items = [a for a in items if float(a.get("commentScore") or 0.0) >= threshold]
        if sight_level is not None:
            items = [a for a in items if self._filter_sight_level(a, sight_level)]
        if comment_count is not None:
            min_comments = int(comment_count)
            items = [a for a in items if int(a.get("commentCount") or 0) >= min_comments]
        if free_only:
            items = [a for a in items if self._is_free(a.get("price"))]
        return items

    @staticmethod
    def _sort_items(items, sort_key, sort_order):
        """Sort items per ``sort_key``/``sort_order``.

        "distance" defaults to ascending; "commentScore"/"heatScore" default to
        descending. Without a sort key, items are ordered by keyword score then
        heatScore (descending by default), or by heatScore alone when there was
        no keyword query. Only "desc" (distance) or "asc" (others) flip the default.
        """
        order = (sort_order or "default").lower()
        descending = order in {"default", "desc", ""}

        if sort_key == "distance":
            # Items without a distance use 1e9, i.e. they go last in ascending order.
            return sorted(items, key=lambda a: float(a.get("_distance", 1e9)), reverse=(order == "desc"))

        if sort_key in {"commentScore", "heatScore"}:
            return sorted(items, key=lambda a: float(a.get(sort_key) or 0.0), reverse=descending)

        if any("_kw_score" in a for a in items):
            # `or 0.0` guards against a heatScore key that is present but null.
            return sorted(
                items,
                key=lambda a: (float(a.get("_kw_score", 0.0)), float(a.get("heatScore") or 0.0)),
                reverse=descending,
            )

        return sorted(items, key=lambda a: float(a.get("heatScore") or 0.0), reverse=descending)

    @staticmethod
    def _format_search_line(attraction: Dict[str, Any]) -> str:
        """Render one attraction as the two-line entry used in search results."""
        open_hours = attraction.get("opening_hours")
        if isinstance(open_hours, dict) and open_hours.get("open") and open_hours.get("close"):
            hours_str = f"{open_hours['open']} – {open_hours['close']}"
        else:
            hours_str = "N/A"

        distance = attraction.get("_distance")
        distance_str = f"distance: {distance:.2f}km" if distance is not None else ""

        price_val = attraction.get("price", "N/A")
        return (
            f"POI ID: {attraction.get('poiId', 'N/A')} | {attraction.get('city', 'N/A')} | "
            f"{attraction.get('poiName', 'N/A')} | level: {attraction.get('sightLevelStr', 'N/A')} | "
            f"longitude: {attraction.get('longitude', 'N/A')}, latitude: {attraction.get('latitude', 'N/A')} | "
            f"{distance_str}\n"
            f"rating: {attraction.get('commentScore', 'N/A')} ({attraction.get('commentCount', 'N/A')} comments) | "
            f"popularity Score: {attraction.get('heatScore', 'N/A')} | "
            f"opening hours: {hours_str} | reference visit time: {attraction.get('reference_time_raw')} | "
            f"ticket price: {price_val if price_val else 'Free'}"
        )

    def _render_page(self, items, page: int, page_size: int) -> str:
        """Render the 1-based ``page`` of ``items`` with a "Showing x-y of n" header."""
        total = len(items)
        start = (page - 1) * page_size
        end = start + page_size
        page_items = items[start:end]

        lines: List[str] = [f"Showing {start+1}-{min(end, total)} of {total} results."]
        lines.extend(self._format_search_line(a) for a in page_items)
        if not page_items:
            lines.append("(No results on this page)")
        return "\n".join(lines)

    # ---------- core search ----------
    def search_attractions(
        self,
        city: str,
        attraction_name: Optional[str] = None,
        categories: Optional[List[str]] = None,
        longitude: Optional[float] = None,
        latitude: Optional[float] = None,
        distance_threshold: Optional[float] = 50,      # km
        rating: Optional[float] = None,               # e.g. 4.5 / 4.0
        sight_level: Optional[str] = None,            # "4A" / "4A or higher" / "5A"
        comment_count: Optional[int] = None,
        free_only: bool = False,
        sort_key: Optional[str] = None,                # "commentScore" / "heatScore" / "distance"
        sort_order: str = "default",                  # "asc" / "desc" / "default"
        page: int = 1,
        page_size: int = 10,
    ) -> str:
        """Search attractions in a city and return one page of results as text.

        Numeric and optional arguments may also be given as strings; "", "none"
        and "null" mean "not set".

        Args:
            city: City name, matched exactly but case-insensitively.
            attraction_name: Fuzzy name query; takes precedence over ``categories``.
            categories: Category names as a list, a JSON list string, or a
                comma-separated string; fuzzy-matched against categories + tags.
            longitude, latitude: Distance origin. If either is missing, the city
                center from ``city_center_coords`` is used when known.
            distance_threshold: Maximum distance in km; None disables the radius
                filter (and keeps items lacking coordinates).
            rating: Minimum ``commentScore``.
            sight_level: "4A"/"4A or higher" (4A or 5A), "5A", or an exact level.
            comment_count: Minimum ``commentCount``.
            free_only: Keep only items whose price is missing or 0 (string
                "true" enables it).
            sort_key: "distance", "commentScore" or "heatScore"; otherwise the
                keyword score (then heatScore) or heatScore alone.
            sort_order: "asc", "desc" or "default" (ascending for distance,
                descending otherwise).
            page, page_size: 1-based pagination; invalid values fall back to 1 / 10.

        Returns:
            "No attractions found." when nothing survives the filters, otherwise
            a header line followed by one entry per attraction on the page.

        Raises:
            ValueError: If a numeric string argument cannot be parsed.
        """
        # Normalize loosely-typed inputs (tool callers often pass strings).
        categories = self._parse_categories(categories)
        latitude = self._parse_optional(latitude, float)
        longitude = self._parse_optional(longitude, float)
        distance_threshold = self._parse_optional(distance_threshold, float)
        rating = self._parse_optional(rating, float)
        comment_count = self._parse_optional(comment_count, self._to_int)
        sight_level = self._parse_optional(sight_level, str)
        if isinstance(free_only, str):
            # Only the literal "true" enables the filter; anything else disables it.
            free_only = free_only.strip().lower() == "true"
        page = self._parse_positive_int(page, 1)
        page_size = self._parse_positive_int(page_size, 10)

        # 1. City filter.
        city = (city or "").strip().lower()
        items = [a for a in self.attractions_list if self._match_city(a, city)]
        if not items:
            return self._NO_RESULTS

        # 2. Fuzzy keyword filter: attraction_name first, else categories.
        items = self._apply_keyword_filter(items, attraction_name, categories)
        if not items:
            return self._NO_RESULTS

        # 3. Distance annotation / radius filter (explicit point, else city center).
        items = self._apply_distance_filter(items, city, latitude, longitude, distance_threshold)
        if not items:
            return self._NO_RESULTS

        # 4. Rating / sight level / comment count / free-only filters.
        items = self._apply_attribute_filters(items, rating, sight_level, comment_count, free_only)
        if not items:
            return self._NO_RESULTS

        # 5. Sort, then 6. paginate and render.
        items = self._sort_items(items, sort_key, sort_order)
        return self._render_page(items, page, page_size)

    def get_attraction_detail_with_products(self, poi_id):
        """Return a text summary of one attraction followed by its ticket products.

        Args:
            poi_id: Attraction ID (int or str; compared as a string).

        Returns:
            The attraction summary plus one line per entry in ``ticket_products``
            (or a "does not require any ticket purchase" line when there are
            none), or "Attraction not found: <poi_id>" for an unknown ID.
        """
        attraction = self.attractions.get(str(poi_id))
        if not attraction:
            return f"Attraction not found: {poi_id}"

        # opening_hours / categories may be present but null in the data.
        hours = attraction.get("opening_hours")
        if not isinstance(hours, dict):
            hours = {}
        categories_str = self._join_categories(attraction.get("categories")) or "N/A"

        summary_line = (
            f"POI ID: {attraction.get('poiId')} | {attraction.get('city')} | {attraction.get('poiName')} | "
            f"level: {attraction.get('sightLevelStr') or 'N/A'} | "
            f"longitude: {attraction.get('longitude')}, latitude: {attraction.get('latitude')}\n"
            f"categories: {categories_str}\n"
            f"rating: {attraction.get('commentScore')} ({attraction.get('commentCount')} comments) | "
            f"popularity score: {attraction.get('heatScore')} | "
            f"opening hours: {hours.get('open', 'N/A')} – {hours.get('close', 'N/A')} | "
            f"reference visit time: {attraction.get('reference_time_raw')} | "
            f"features: {attraction.get('shortFeatures', 'N/A')}"
        )

        product_lines = [
            f"product_id: {product.get('product_id')} | {product.get('type')} | "
            f"price: {product.get('price')}"
            for product in (attraction.get("ticket_products") or [])
        ]
        if not product_lines:
            product_lines.append("This attraction does not require any ticket purchase.")

        return "\n".join([summary_line] + product_lines)

    def get_attraction_coordinates(self, poi_id):
        """Return "Attraction: <name> | latitude: <lat> | longitude: <lon>" for an ID.

        Missing fields are shown as "N/A"; an unknown ID yields
        "No matching attraction found for ID: <poi_id>".
        """
        attraction = self.attractions.get(str(poi_id))
        if not attraction:
            return f"No matching attraction found for ID: {poi_id}"

        poi_name = attraction.get("poiName", "N/A")
        return (
            f"Attraction: {poi_name} | "
            f"latitude: {attraction.get('latitude', 'N/A')} | longitude: {attraction.get('longitude', 'N/A')}"
        )


if __name__ == "__main__":
    # Manual smoke test; needs config.yaml and the attraction dataset on disk.
    tools = AttractionTool()

    # Other search arguments worth trying: attraction_name, categories,
    # sight_level, rating, sort_key, sort_order, distance_threshold.
    search_kwargs = {"city": "beijing", "page": 1, "page_size": 10}
    print(tools.search_attractions(**search_kwargs))

    sample_poi_id = "81704"
    for lookup in (tools.get_attraction_detail_with_products, tools.get_attraction_coordinates):
        print(lookup(poi_id=sample_poi_id))

