from __future__ import annotations
from typing import List, Dict, Any, Optional, Tuple
import json
import math
import os
from typing import List, Dict, Any, Optional
from geopy.distance import geodesic
from fuzzywuzzy import fuzz
import yaml

# Restaurant sub-category ("small_cate") names; lowercased later in this module.
small_cate_list = [
    "Anhui Cuisine",
    "Barbecue",
    "Beijing Cuisine",
    "Buffet",
    "Chaoshan Cuisine",
    "Crayfish",
    "Creative Cuisine",
    "Farmhouse Cuisine",
    "Fast Food",
    "Fujian Cuisine",
    "Guangdong cuisine",
    "Guangxi Cuisine",
    "Guizhou Cuisine",
    "Hainan Cuisine",
    "Hakka Cuisine",
    "Henan Cuisine",
    "Home-style Cooking",
    "Hot Pot",
    "Huaiyang Cuisine",
    "Hubei Cuisine",
    "Hunan Cuisine",
    "Japanese Cuisine",
    "Jiangsu and Zhejiang Cuisine",
    "Jiangxi Cuisine",
    "Korean Cuisine",
    "Malatang",
    "Northeastern Cuisine",
    "Northwestern Cuisine",
    "Other Chinese Cuisine",
    "Other Delicacies",
    "Pizza",
    "Porridge Shop",
    "Private Kitchen",
    "Rice Noodles",
    "Seafood",
    "Shaanxi Cuisine",
    "Shandong Cuisine",
    "Shanxi Cuisine",
    "Sichuan Cuisine",
    "Snacks",
    "Southeast Asian Cuisine",
    "Taiwanese Cuisine",
    "Tea Restaurant",
    "Tianjin Cuisine",
    "Vegetarian Cuisine",
    "Western Cuisine",
    "Wontons and Dumplings",
    "Xinjiang Cuisine",
    "Yunnan and Guizhou Cuisine",
]

# City-centre points ("lon"/"lat") keyed by city name; used as the default
# search centre when no explicit coordinates are supplied.
city_center_coords = {
    "Beijing": {"lon": 116.407387, "lat": 39.904179},
    "Changchun": {"lon": 125.323643, "lat": 43.816996},
    "Changsha": {"lon": 112.938882, "lat": 28.228304},
    "Chengdu": {"lon": 104.066301, "lat": 30.572961},
    "Chongqing": {"lon": 106.551787, "lat": 29.56268},
    "Dalian": {"lon": 121.614786, "lat": 38.913962},
    "Fuzhou": {"lon": 119.296411, "lat": 26.074286},
    "Guangzhou": {"lon": 113.264499, "lat": 23.130061},
    "Guilin": {"lon": 110.179752, "lat": 25.235615},
    "Guiyang": {"lon": 106.628201, "lat": 26.646694},
    "Haikou": {"lon": 110.200162, "lat": 20.046316},
    "Hangzhou": {"lon": 120.209903, "lat": 30.246566},
    "Harbin": {"lon": 126.53505, "lat": 45.802981},
    "Hong Kong": {"lon": 114.170714, "lat": 22.278354},
    "Jinan": {"lon": 117.120128, "lat": 36.652069},
    "Kaifeng": {"lon": 114.314278, "lat": 34.798083},
    "Kunming": {"lon": 102.833669, "lat": 24.88149},
    "Lijiang": {"lon": 100.225936, "lat": 26.855165},
    "Luoyang": {"lon": 112.453895, "lat": 34.619702},
    "Nanchang": {"lon": 115.857972, "lat": 28.682976},
    "Nanjing": {"lon": 118.796624, "lat": 32.059344},
    "Nanning": {"lon": 108.366407, "lat": 22.8177},
    "Ningbo": {"lon": 121.62454, "lat": 29.860258},
    "Qingdao": {"lon": 120.382665, "lat": 36.066938},
    "Sanya": {"lon": 109.511709, "lat": 18.252865},
    "Shanghai": {"lon": 121.473667, "lat": 31.230525},
    "Shenyang": {"lon": 123.464675, "lat": 41.677576},
    "Shenzhen": {"lon": 114.057939, "lat": 22.543527},
    "Suzhou": {"lon": 120.585294, "lat": 31.299758},
    "Taiyuan": {"lon": 112.549656, "lat": 37.870451},
    "Tianjin": {"lon": 117.201509, "lat": 39.085318},
    "Weihai": {"lon": 122.120519, "lat": 37.513315},
    "Wuhan": {"lon": 114.304569, "lat": 30.593354},
    "Wuxi": {"lon": 120.311889, "lat": 31.491064},
    "Xi'an": {"lon": 108.939645, "lat": 34.343207},
    "Xiamen": {"lon": 118.08891, "lat": 24.479627},
    "Xishuangbanna": {"lon": 100.797002, "lat": 22.009037},
    "Yantai": {"lon": 121.447755, "lat": 37.464551},
    "Zhengzhou": {"lon": 113.625351, "lat": 34.746303},
    "Zhuhai": {"lon": 113.576892, "lat": 22.271644},
}

# Lowercase the lookup tables once so they can be compared directly with
# lowercased user input. small_cate_list is updated in place (slice
# assignment) so the list object keeps its identity.
small_cate_list[:] = [cate.lower() for cate in small_cate_list]

city_center_coords = {key.lower(): value for key, value in city_center_coords.items()}

class RestaurantTool:
    """
    Restaurant search tool over a JSON list of restaurant records.

    - Restaurants are indexed once (``rebuild_index``) by normalized city
      (``real_city``) and by id, so lookups do not rescan the dataset.
    - When no coordinates are given, distances are measured from the
      city-centre coordinates in ``location_index`` (same approach as
      HotelTool), with a fuzzy city-name fallback.
    - Lookup misses and empty results are reported as plain strings
      (e.g. "Invalid city name.", "Restaurant not found: ...").
    """

    # Public sort_key -> row field used for ordering. Unknown sort keys fall
    # back to "_stars".
    _SORT_FIELDS = {
        "price": "_price",
        "stars": "_stars",
        "review_count": "_review_count",
        "distance": "distance",
    }

    # -------------------------
    # Initialization (same style as HotelTool)
    # -------------------------
    def __init__(
        self,
        config_path: Optional[str] = None
    ):
        """
        Load the restaurant dataset named in the YAML config and build indexes.

        Args:
            config_path: Path to the YAML config. Defaults to ``config.yaml`` in
                the parent of this file's directory.

        Raises:
            ValueError: If the config has no ``data_path.restaurant`` entry.
        """
        current_dir = os.path.dirname(__file__)
        if config_path is None:
            config_path = os.path.join(os.path.dirname(current_dir), "config.yaml")

        with open(config_path, "r", encoding="utf-8") as f:
            # An empty YAML file loads as None; treat it as an empty mapping so
            # the explicit "missing path" error below is raised instead of an
            # AttributeError.
            cfg = yaml.safe_load(f) or {}

        restaurant_path = (cfg.get("data_path") or {}).get("restaurant")
        if not restaurant_path:
            raise ValueError("Missing restaurant path in config.yaml (data_path.restaurant)")

        # NOTE(review): a relative restaurant_path resolves against the current
        # working directory, not the config file's directory.
        with open(restaurant_path, "r", encoding="utf-8") as f:
            self.restaurants = json.load(f)

        # City name (lowercased) -> {"lon": ..., "lat": ...} centre coordinates.
        self.location_index = city_center_coords
        self.rebuild_index()

    # -------------------------
    # Internal helpers
    # -------------------------
    def _norm_city(self, s: str) -> str:
        """Return *s* stripped and lowercased ("" for None/empty input)."""
        return str(s or "").strip().lower()

    def _fuzzy_match_city(self, city_key: str) -> Optional[str]:
        """
        Fuzzy-match *city_key* against the city names in ``self.location_index``.

        The comparison uses ``fuzz.ratio`` on lowercased strings. Returns the
        best-scoring key when its score is at least 60, otherwise None. On a
        tie the key seen first wins.
        """
        target = city_key.lower()
        best_city, best_score = None, 0
        for cname in self.location_index:
            score = fuzz.ratio(target, cname.lower())
            if score > best_score:
                best_city, best_score = cname, score
        return best_city if best_score >= 60 else None

    def _calculate_distance(self, lat1, lon1, lat2, lon2):
        """Geodesic distance in kilometres between (lat1, lon1) and (lat2, lon2)."""
        return geodesic((lat1, lon1), (lat2, lon2)).km

    @staticmethod
    def _to_float(value, default):
        """Return float(value), or *default* if the conversion fails."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _to_int(value, default):
        """Return int(float(value)) (so "3.0" -> 3), or *default* on failure."""
        try:
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _to_optional_float(value, default):
        """
        Parse an optional numeric argument.

        None and the strings "", "none", "null" (case-insensitive) mean "not
        given" and yield None. Anything else is converted with float(); if
        that fails, *default* is returned instead of raising.
        """
        if value is None:
            return None
        if isinstance(value, str) and value.strip().lower() in ("", "none", "null"):
            return None
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _parse_reservable(value):
        """Map "true"/"false" strings to bools; other strings mean no filter (None)."""
        if not isinstance(value, str):
            return value
        text = value.strip().lower()
        if text == "true":
            return True
        if text == "false":
            return False
        return None

    @staticmethod
    def _parse_categories(categories):
        """
        Normalize the *categories* filter to a set of lowercased names, or None.

        Accepts a list/tuple/set of names, or a JSON string encoding such a list
        (or a single name). None, empty input, or an unparseable string yields
        None, which means "no category filter".
        """
        if isinstance(categories, str):
            try:
                categories = json.loads(categories)
            except ValueError:
                # Deliberate: an unparseable string disables the filter.
                return None
            if isinstance(categories, str):
                categories = [categories]
        if not isinstance(categories, (list, tuple, set)) or not categories:
            return None
        wanted = {str(x).strip().lower() for x in categories if str(x).strip()}
        return wanted or None

    @staticmethod
    def _format_time_ranges(ranges) -> str:
        """
        Join (start, end) pairs as "start-end, start-end".

        Returns "" when *ranges* is not a list/tuple. Entries that are not
        two-element lists/tuples are skipped rather than raising.
        """
        if not isinstance(ranges, (list, tuple)):
            return ""
        return ", ".join(
            f"{item[0]}-{item[1]}"
            for item in ranges
            if isinstance(item, (list, tuple)) and len(item) == 2
        )

    def _resolve_search_center(self, city, longitude, latitude):
        """
        Choose the (lon, lat) point that distances are measured from.

        The order of preference is:
        1. The explicit longitude and latitude, when both are given.
        2. The exact ``location_index`` entry for *city*.
        3. The best fuzzy city-name match.

        Returns (None, None) when none of these applies.
        """
        if longitude is not None and latitude is not None:
            return float(longitude), float(latitude)
        if city in self.location_index:
            entry = self.location_index[city]
            return entry["lon"], entry["lat"]
        if city:
            match_city = self._fuzzy_match_city(city.strip().title())
            if match_city:
                entry = self.location_index[match_city]
                return entry["lon"], entry["lat"]
        return None, None

    def _format_search_row(self, row) -> str:
        """Render one filtered search row as a single result line."""
        data = row["data"]
        open_hours = self._format_time_ranges(data.get("open_hours")) or "N/A"
        return (
            "Restaurant_id: {rid} | {name} | category: {cate} | "
            "avg_price: {price} | stars: {stars}/5 | review_count: {rc} | "
            "open_hours: {oh} | "
            "longitude: {lon}, latitude: {lat} | distance: {dist}km".format(
                rid=data.get("restaurant_id") or data.get("id"),
                name=data.get("name"),
                cate=data.get("small_cate"),
                price=row["avg_price"],
                stars=data.get("stars"),
                rc=int(row["_review_count"]),
                oh=open_hours,
                lon=data.get("longitude"),
                lat=data.get("latitude"),
                dist=row["distance"],
            )
        )

    def rebuild_index(self):
        """
        (Re)build the lookup indexes from ``self.restaurants``.

        - ``self.city_index`` maps a normalized ``real_city`` to a list of
          restaurants. Records without a city are left out.
        - ``self.index_restaurant`` maps str(id) to a restaurant. Each record's
          ``restaurant_id`` (when present) is added as an alias only if that
          string is not already someone's id, so the id shown in search results
          can also be looked up.
        """
        city_index = {}
        by_id = {}
        aliases = []
        for r in self.restaurants:
            by_id[str(r.get("id"))] = r
            alias = r.get("restaurant_id")
            if alias is not None:
                aliases.append((str(alias), r))
            city = self._norm_city(r.get("real_city", ""))
            if city:
                city_index.setdefault(city, []).append(r)
        # Aliases are added after all ids so an alias never shadows a real id,
        # whatever the record order.
        for alias, r in aliases:
            by_id.setdefault(alias, r)
        self.city_index = city_index
        self.index_restaurant = by_id

    # -------------------------
    # Public search API
    # -------------------------
    def search_restaurants(
        self,
        city,
        longitude=None,
        latitude=None,
        distance_threshold=5.0,
        price_min=0.0,
        price_max=9999999.0,
        stars=0.0,
        review_count=0,
        product_rating=0.0,
        environment_rating=0.0,
        service_rating=0.0,
        categories=None,
        reservable=None,
        sort_key="stars",
        sort_order: str = "default",
        page=1,
        page_size=10,
    ):
        """
        Search restaurants in *city* and return one page of results as text.

        Every argument may also be passed as a string. An unparseable number
        falls back to the parameter's default.

        Args:
            city: City name, matched case-insensitively against ``real_city``.
            longitude, latitude: Optional search centre. It is used only when
                both are given; otherwise the city centre is used.
            distance_threshold: Maximum distance in km from the search centre
                (None / "none" = no limit). It applies only when a centre is known.
            price_min, price_max: Inclusive bounds on ``avg_price``. Records
                without a numeric ``avg_price`` are excluded.
            stars, review_count, product_rating, environment_rating,
            service_rating: Minimum values. Missing or non-numeric record
                values count as 0.
            categories: ``small_cate`` names (case-insensitive), as a list or a
                JSON string. None, empty, or unparseable means no filter.
            reservable: True/False (or "true"/"false") to require a matching
                ``reservable`` flag. None means no filter.
            sort_key: "price", "distance", "stars" or "review_count". Any other
                value sorts by stars.
            sort_order: "desc" sorts descending. "default" sorts price and
                distance ascending and everything else descending. Any other
                value sorts ascending.
            page, page_size: 1-based page number and page size, both clamped
                to at least 1.

        Returns:
            "Invalid city name." when the city has no indexed restaurants.
            "No matching restaurants found." when every candidate is filtered
            out. An out-of-range message when *page* is past the last page.
            Otherwise a "Showing a-b of N results." header followed by one line
            per restaurant.
        """
        # --------- Coerce loosely-typed (often string) tool arguments ---------
        longitude = self._to_optional_float(longitude, None)
        latitude = self._to_optional_float(latitude, None)
        distance_threshold = self._to_optional_float(distance_threshold, 5.0)

        price_min = self._to_float(price_min, 0.0)
        price_max = self._to_float(price_max, 9999999.0)
        stars = self._to_float(stars, 0.0)
        review_count = self._to_int(review_count, 0)
        product_rating = self._to_float(product_rating, 0.0)
        environment_rating = self._to_float(environment_rating, 0.0)
        service_rating = self._to_float(service_rating, 0.0)

        reservable = self._parse_reservable(reservable)
        categories = self._parse_categories(categories)

        page = max(1, self._to_int(page, 1))
        page_size = max(1, self._to_int(page_size, 10))

        # -------------- City lookup --------------
        city = self._norm_city(city)
        candidates = self.city_index.get(city)
        if not candidates:
            return "Invalid city name."

        # -------------- Search centre (same logic as HotelTool) --------------
        search_lon, search_lat = self._resolve_search_center(city, longitude, latitude)
        need_distance = search_lon is not None and search_lat is not None

        # -------------- Filtering --------------
        rows = []
        for r in candidates:
            get = r.get

            # --- small_cate filter (a missing category never matches) ---
            if categories is not None:
                if str(get("small_cate") or "").strip().lower() not in categories:
                    continue

            # --- rating thresholds (missing/non-numeric values count as 0) ---
            r_stars = self._to_float(get("stars"), 0.0)
            r_reviews = self._to_float(get("review_count"), 0.0)
            if (
                r_stars < stars
                or r_reviews < review_count
                or self._to_float(get("product_rating"), 0.0) < product_rating
                or self._to_float(get("environment_rating"), 0.0) < environment_rating
                or self._to_float(get("service_rating"), 0.0) < service_rating
            ):
                continue

            # --- reservation requirement ---
            if reservable is not None and get("reservable") != reservable:
                continue

            # --- avg_price bounds (records without a numeric price are skipped) ---
            avg_price = self._to_float(get("avg_price"), None)
            if avg_price is None or not (price_min <= avg_price <= price_max):
                continue

            # --- distance to the search centre, when one is known ---
            distance = float("inf")
            if need_distance:
                r_lat, r_lon = get("latitude"), get("longitude")
                if r_lat is None or r_lon is None:
                    continue
                try:
                    distance = self._calculate_distance(
                        float(search_lat),
                        float(search_lon),
                        float(r_lat),
                        float(r_lon),
                    )
                except Exception:
                    # Best-effort: skip records with unusable coordinates.
                    continue
                if distance_threshold is not None and distance > distance_threshold:
                    continue

            rows.append(
                {
                    "_price": avg_price,
                    "_stars": r_stars,
                    "_review_count": r_reviews,
                    "avg_price": round(avg_price, 2),
                    "distance": round(distance, 2),
                    "data": r,
                }
            )

        if not rows:
            return "No matching restaurants found."

        # -------------- Sorting --------------
        sort_key = str(sort_key or "stars").strip().lower()
        sort_order = str(sort_order or "default").strip().lower()
        sort_field = self._SORT_FIELDS.get(sort_key, "_stars")
        if sort_order == "default":
            # price / distance ascend by default; stars, review_count and
            # unknown keys descend.
            reverse = sort_key not in ("price", "distance")
        else:
            reverse = sort_order == "desc"
        rows.sort(key=lambda row: row[sort_field], reverse=reverse)

        # -------------- Pagination --------------
        total = len(rows)
        start = (page - 1) * page_size
        if start >= total:
            total_pages = (total + page_size - 1) // page_size
            return (
                f"Page {page} is out of range: {total} results "
                f"in {total_pages} page(s) of {page_size}."
            )
        end = min(start + page_size, total)

        formatted = [f"Showing {start + 1}-{end} of {total} results."]
        formatted.extend(self._format_search_row(row) for row in rows[start:end])
        return "\n".join(formatted)

    def get_restaurant_detail_with_products(
        self,
        restaurant_id,
    ):
        """
        Return a restaurant summary followed by its products.

        Args:
            restaurant_id: The restaurant's ``id`` or ``restaurant_id``, compared
                as a string.

        Returns:
            "Restaurant not found: {restaurant_id}" when the id is unknown.
            Otherwise the output has two parts:
            - A two-line summary: basic info, then open_hours, ratings and
              reservation flags.
            - One line per product, sorted by ``people`` ascending, each of the
              form "product_id: ... | people: ... | price: ... |
              available_time_ranges: start-end, ...".
            When the restaurant has no products, a single line "No products are
            available for {name}. Please order directly at the restaurant."
            follows the summary.
        """
        restaurant = self.index_restaurant.get(str(restaurant_id))
        if not restaurant:
            return f"Restaurant not found: {restaurant_id}"

        open_hours = self._format_time_ranges(restaurant.get("open_hours"))
        # --- summary (spans two lines) ---
        summary_line = (
            f"Restaurant_id: {restaurant.get('restaurant_id') or restaurant.get('id')} | "
            f"{restaurant.get('name')} | "
            f"category: {restaurant.get('small_cate')} | "
            f"avg_price: {restaurant.get('avg_price')} | "
            f"stars: {restaurant.get('stars')}/5 | "
            f"review_count: {restaurant.get('review_count')} | "
            f"longitude: {restaurant.get('longitude')}, latitude: {restaurant.get('latitude')}\n"
            f"open_hours: {open_hours} | "
            f"product_rating: {restaurant.get('product_rating')} | "
            f"environment_rating: {restaurant.get('environment_rating')} | "
            f"service_rating: {restaurant.get('service_rating')} | "
            f"reservable: {restaurant.get('reservable')} | "
            f"must_reserve: {restaurant.get('must_reserve')} | "
        )

        # Smallest party size first; a missing/non-numeric "people" sorts as 0.
        products = sorted(
            restaurant.get("products") or [],
            key=lambda p: self._to_float(p.get("people"), 0.0),
        )
        product_lines = [
            f"product_id: {p.get('product_id')} | "
            f"people: {p.get('people')} | "
            f"price: {p.get('price')} | "
            f"available_time_ranges: {self._format_time_ranges(p.get('available_time_ranges'))}"
            for p in products
        ]
        if not product_lines:
            product_lines = [
                f"No products are available for {restaurant.get('name')}. Please order directly at the restaurant."
            ]

        return "\n".join([summary_line] + product_lines)

    def get_restaurant_coordinates(self, restaurant_id):
        """
        Return a restaurant's coordinates as a one-line summary.

        Args:
            restaurant_id: The restaurant's ``id`` or ``restaurant_id``, as a
                string or number (compared as a string).

        Returns:
            "Restaurant_id: ... | name: ... | latitude: ... | longitude: ..."
            when found, otherwise "Restaurant not found: {restaurant_id}".
        """
        restaurant = self.index_restaurant.get(str(restaurant_id))
        if restaurant is None:
            return f"Restaurant not found: {restaurant_id}"

        return (
            f"Restaurant_id: {restaurant.get('restaurant_id') or restaurant.get('id')} | "
            f"name: {restaurant.get('name')} | "
            f"latitude: {restaurant.get('latitude')} | "
            f"longitude: {restaurant.get('longitude')}"
        )


if __name__ == "__main__":
    # Manual smoke test; requires config.yaml and the restaurant data file it
    # points to. get_restaurant_detail_with_products(restaurant_id=...) and
    # get_restaurant_coordinates(restaurant_id=...) can be tried the same way.
    tool = RestaurantTool()
    print(
        tool.search_restaurants(
            city="Guilin",
            categories=["Hunan Cuisine"],
        )
    )
