import json
import datetime
import yaml
import os
from fuzzywuzzy import fuzz

class FlightTool:
    """
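    Flight search tool backed by a (departure city, arrival city) index.

    Speed-ups:
    - After initialisation, self.index maps (dep_city_norm, arr_city_norm) -> [flight_dict, ...]
    - search_flights() filters only the candidates found through the index, avoiding an O(N) full scan per query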
    """

    def __init__(self, config_path=None):
        """
        Load the flight data set and build the lookup indexes.

        Args:
            config_path: Path of a YAML config file that defines
                ``data_path.flight`` (the path of the flight JSON file).
                When None, ``config.yaml`` in the parent of the directory
                that contains this module is used.

        Raises:
            ValueError: If the config does not define ``data_path.flight``.
            OSError: If the config file or the flight file cannot be opened.
        """
        if config_path is None:
            # Resolve against the absolute module location: with a relative
            # __file__ that has no directory part, dirname() would yield ""
            # and the default would silently point at the working directory.
            module_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(os.path.dirname(module_dir), "config.yaml")

        with open(config_path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            cfg = yaml.safe_load(f) or {}

        # "data_path" may be absent or explicitly null; both mean "not configured".
        flight_path = (cfg.get("data_path") or {}).get("flight")
        if not flight_path:
            raise ValueError("Missing flight path in config.yaml (data_path.flight)")

        with open(flight_path, "r", encoding="utf-8") as f:
            self.flights = json.load(f)

        # Index flights by (departure city, arrival city).
        self.rebuild_index()
        # Index airport coordinates by airport name.
        self._build_airport_coordinates_index()

    def _build_airport_coordinates_index(self):
        """
        Build the airport-coordinate lookup table from the loaded flights.

        Populates ``self.airport_coords`` as
        ``{"airport_name": {"latitude": float, "longitude": float}, ...}``.
        Both ends of every flight are recorded; a later record for the same
        airport name overwrites an earlier one.
        """
        # (name key, latitude key, longitude key) for each end of a flight;
        # the departure end is processed before the arrival end.
        endpoint_fields = (
            ("Departure Airport", "Departure Airport Latitude", "Departure Airport Longitude"),
            ("Arrival Airport", "Arrival Airport Latitude", "Arrival Airport Longitude"),
        )

        coords_by_airport = self.airport_coords = {}
        for record in self.flights:
            for name_key, lat_key, lon_key in endpoint_fields:
                airport = record.get(name_key)
                lat = record.get(lat_key)
                lon = record.get(lon_key)
                # Skip an endpoint without a name or with a missing coordinate.
                if not airport or lat is None or lon is None:
                    continue
                coords_by_airport[airport] = {
                    "latitude": float(lat),
                    "longitude": float(lon),
                }

    # ---------- 工具函数 ----------

    def _norm_city(self, s: str) -> str:
        """Normalise a city name: trim surrounding whitespace and lowercase it; falsy input gives ""."""
        if not s:
            return ""
        return s.strip().lower()

    def _weekday_abbr(self, date_str: str) -> str:
        """Return the weekday abbreviation (Mon..Sun) for a YYYY-MM-DD date string."""
        # weekday() numbers Monday as 0 and Sunday as 6, matching this tuple.
        names = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
        parsed = datetime.datetime.strptime(date_str, "%Y-%m-%d")
        return names[parsed.weekday()]

    def _parse_hhmm(self, s: str) -> datetime.time:
        """
        Parse a clock time written as 'HH:MM', 'H:M', '0900' or '900' into a time object.

        '24:00' is mapped to 00:00 (midnight of the following day; the value is
        only used for comparisons here).

        Raises:
            ValueError: If the string is empty, not numeric, or outside 00:00-23:59.
        """
        s = (s or "").strip()
        if not s:
            raise ValueError("Empty time string")

        if ":" not in s:
            # Colon-less form such as '900' / '0900': left-pad to four digits,
            # then the first two characters are the hour and the rest the minutes.
            s = s.zfill(4)
            hour_text, minute_text = s[:2], s[2:]
        else:
            hour_text, _, minute_text = s.partition(":")

        hour = int(hour_text)
        minute = int(minute_text)
        if (hour, minute) == (24, 0):
            hour = minute = 0
        if hour < 0 or hour > 23 or minute < 0 or minute > 59:
            raise ValueError(f"Invalid time: {s}")
        return datetime.time(hour, minute)

    def _parse_period(self, period_str: str) -> tuple:
        """
        Parse a 'HH:MM-HH:MM' window into a (start, end) pair of time objects.

        A start of '24:00' becomes 00:00. An end of '24:00' becomes
        23:59:59.999999, so that it can serve as the upper limit of a window
        whose end is exclusive.

        Raises:
            ValueError: If the string does not split into exactly two parts on
                '-', or either part is not a valid time.
        """
        start_text, end_text = [part.strip() for part in period_str.split("-")]

        window_start = (
            datetime.time(0, 0) if start_text == "24:00" else self._parse_hhmm(start_text)
        )
        # The end bound is exclusive, so "24:00" maps to the last instant of the day.
        window_end = (
            datetime.time(23, 59, 59, 999999) if end_text == "24:00" else self._parse_hhmm(end_text)
        )
        return window_start, window_end

    def _price_for_day(self, f: dict, week_abbr: str):
        """
        Return the flight's price for the given weekday.

        The weekday-specific field (e.g. ``Mon_lowest_price``) is preferred;
        when it is absent or None the generic ``Price`` field is returned
        instead (which may itself be None).
        """
        daily_price = f.get(f"{week_abbr}_lowest_price")
        return daily_price if daily_price is not None else f.get("Price")

    # ---------- 索引维护 ----------

    def rebuild_index(self):
        """
        Rebuild the (departure city, arrival city) -> flights index.

        Call this after ``self.flights`` has been modified. City names are
        normalised with ``_norm_city``; flights whose departure or arrival
        city is empty after normalisation are left out of the index.
        """
        routes = {}
        for record in self.flights:
            origin = self._norm_city(record.get("Departure City", ""))
            destination = self._norm_city(record.get("Arrival City", ""))
            if origin and destination:
                route = (origin, destination)
                if route not in routes:
                    routes[route] = []
                routes[route].append(record)
        self.index = routes

    # ---------- 公开接口 ----------

    def search_flights(self, departure_city, arrival_city, date_str, dep_period="00:00-24:00", arr_period="00:00-24:00",
                       sort_key="time", sort_order="asc", page=1, page_size=10):
        """
        Search flights between two cities on a given date.

        Args:
            departure_city: Departure city (case-insensitive, surrounding
                whitespace ignored).
            arrival_city: Arrival city (same matching rules).
            date_str: Travel date as YYYY-MM-DD.
            dep_period: Departure window 'HH:MM-HH:MM' with an exclusive end
                (e.g. "06:00-12:00"). Empty or None means the whole day.
            arr_period: Arrival window, same format and rules.
            sort_key: "price" sorts by the day's price; any other value sorts
                by departure time.
            sort_order: "desc" for descending; any other value is ascending.
            page: 1-based page number (int or numeric string). Values that
                cannot be read as an integer fall back to 1.
            page_size: Results per page (int or numeric string). Values that
                cannot be read as an integer fall back to 10.

        Returns:
            str: A summary line followed by one line per flight, joined with
            newlines. "No matching flights found." when nothing matches, or a
            "Page N is out of range" message when the page lies past the end.

        Raises:
            ValueError: If date_str or a non-empty period string is malformed.
        """

        def _positive_int(value, default):
            """Coerce value to an int >= 1, or return default when it is not numeric."""
            if isinstance(value, str):
                value = value.strip()
            try:
                return max(1, int(value))
            except (TypeError, ValueError, OverflowError):
                return default

        def _period_or_full_day(period):
            """Return the period text, substituting the whole day for empty/None."""
            text = "" if period is None else str(period).strip()
            return text or "00:00-24:00"

        def _numeric_price(value):
            """Return value as a float, or None when it is missing or not numeric."""
            try:
                return float(value)
            except (TypeError, ValueError):
                return None

        page = _positive_int(page, 1)
        page_size = _positive_int(page_size, 10)

        # --- Time windows and weekday ---
        dep_start, dep_end = self._parse_period(_period_or_full_day(dep_period))
        arr_start, arr_end = self._parse_period(_period_or_full_day(arr_period))
        week_abbr = self._weekday_abbr(date_str)
        available_key = f"{week_abbr}_available"

        # Normalise the city names the same way the index keys were built.
        departure_city = self._norm_city(departure_city)
        arrival_city = self._norm_city(arrival_city)

        # --- Candidates come straight from the route index (no full scan) ---
        rows = []
        for flight in self.index.get((departure_city, arrival_city), []):
            # The flight must operate on this weekday.
            if flight.get(available_key) != 1:
                continue

            try:
                dep_time = self._parse_hhmm(str(flight.get("Departure Time", "")))
                arr_time = self._parse_hhmm(str(flight.get("Arrival Time", "")))
            except ValueError:
                continue  # best effort: skip records with unparseable times

            # Both windows are half-open: start <= t < end.
            if not (dep_start <= dep_time < dep_end):
                continue
            if not (arr_start <= arr_time < arr_end):
                continue

            rows.append({
                "_dep_time": dep_time,
                "_arr_time": arr_time,
                "_price": self._price_for_day(flight, week_abbr),
                "data": flight,
            })

        if not rows:
            return "No matching flights found."

        # --- Sorting ---
        descending = str(sort_order).strip().lower() == "desc"
        if str(sort_key).strip().lower() == "price":
            # Flights without a usable price always go last, whatever the
            # direction, so that "desc" does not lead with unpriced flights.
            priced, unpriced = [], []
            for row in rows:
                value = _numeric_price(row["_price"])
                if value is None:
                    unpriced.append(row)
                else:
                    priced.append((value, row))
            priced.sort(key=lambda pair: pair[0], reverse=descending)
            rows = [row for _, row in priced] + unpriced
        else:  # default: by departure time
            rows.sort(key=lambda row: row["_dep_time"], reverse=descending)

        # --- Pagination ---
        total = len(rows)
        start_idx = (page - 1) * page_size
        if start_idx >= total:
            total_pages = (total + page_size - 1) // page_size
            return f"Page {page} is out of range: {total} results on {total_pages} page(s)."
        end_idx = min(start_idx + page_size, total)

        # --- Output formatting ---
        formatted = []
        for row in rows[start_idx:end_idx]:
            f = row["data"]
            price_display = row["_price"]
            formatted.append(
                f"Flight_id: {f.get('Flight_id')} | {f.get('Flight Number')} | {f.get('Airline')} | "
                f"{f.get('Departure Time')}-{f.get('Arrival Time')} | "
                f"{f.get('Departure Airport')} -> {f.get('Arrival Airport')} | "
                f"minimum price: {price_display}"
            )

        summary = f"Showing {start_idx + 1}-{end_idx} of {total} results."
        return "\n".join([summary] + formatted)

    def get_flight_detail_with_products(self, flight_id, date_str, source_platform=["ctrip", "alitrip", "qunar", "direct"], seat_type="Economy class"):
        """
        Describe one flight on one date together with its bookable products.

        Args:
            flight_id: Identifier compared (as a string) with each flight's
                ``Flight_id``.
            date_str: Travel date as YYYY-MM-DD.
            source_platform: Platforms to keep. Accepts a list/tuple/set of
                names, a JSON array string (e.g. '["ctrip"]'), a bare name or
                a comma-separated string. None, an empty string or the name
                "all" disables platform filtering. Matching is
                case-insensitive.
            seat_type: Seat class to keep (case-insensitive).

        Returns:
            str: A summary line (search-style fields plus on-time rate and
            average delay) followed by one line per matching product
            (product_id, seat type, platform, price of the day), joined with
            newlines. "Flight not found: <id>" when the id is unknown, and
            "No flight available on <date> for <id>." when the flight is not
            sold that day, has no lowest price for that day, or no product
            passes the filters with a price for that day.

        Raises:
            ValueError: If date_str is not a valid YYYY-MM-DD date.
        """
        # NOTE: the list default above is shared between calls; it is only
        # read here, never mutated.

        def _norm(x):
            return (str(x) if x is not None else "").strip().lower()

        # --- Normalise the platform filter ---
        platforms = source_platform
        if isinstance(platforms, str):
            text = platforms.strip()
            if not text:
                platforms = None
            else:
                try:
                    platforms = json.loads(text)
                except ValueError:
                    # Not JSON: accept a bare name or a comma-separated list.
                    platforms = [part for part in text.split(",") if part.strip()]

        if platforms is None:
            allowed_platforms = None  # no platform filtering
        else:
            if not isinstance(platforms, (list, tuple, set, frozenset)):
                # A scalar (e.g. the JSON string '"ctrip"') is a single name,
                # not a sequence of characters.
                platforms = [platforms]
            allowed_platforms = {_norm(p) for p in platforms}
            if "all" in allowed_platforms:
                allowed_platforms = None

        # --- Locate the flight ---
        wanted_id = str(flight_id)
        flight = next(
            (f for f in getattr(self, "flights", []) if str(f.get("Flight_id")) == wanted_id),
            None,
        )
        if not flight:
            return f"Flight not found: {flight_id}"

        # --- Availability on the requested day ---
        week_abbr = self._weekday_abbr(date_str)  # e.g. "Mon"
        available_key = f"{week_abbr}_available"
        day_lowest_key = f"{week_abbr}_lowest_price"

        # The flight must be on sale that day and have a lowest price.
        if flight.get(available_key) != 1 or flight.get(day_lowest_key) is None:
            return f"No flight available on {date_str} for {flight_id}."

        price_today = flight.get(day_lowest_key)

        # --- Summary line ---
        summary_line = (
            f"Flight_id: {flight.get('Flight_id')} | {flight.get('Flight Number')} | {flight.get('Airline')} | "
            f"{flight.get('Departure Time')}-{flight.get('Arrival Time')} | "
            f"{flight.get('Departure Airport')} -> {flight.get('Arrival Airport')} | "
            f"minimum price: {price_today} | "
            f"on-time rate: {flight.get('On-Time Performance')} | "
            f"average delay: {flight.get('Average Delay (minutes)')} min"
        )

        # --- Filter products ---
        seat_norm = _norm(seat_type)
        product_lines = []
        for p in flight.get("Product", []) or []:
            # The seat class must match.
            if _norm(p.get("seat_type")) != seat_norm:
                continue
            # The platform must be one of the requested ones, if any were given.
            if allowed_platforms is not None and _norm(p.get("source_platform")) not in allowed_platforms:
                continue
            # A product without a price for this weekday is not sold that day.
            day_price = p.get(week_abbr)
            if day_price is None:
                continue

            product_lines.append(
                f"product_id: {p.get('product_id')} | {p.get('seat_type')} | {p.get('source_platform')} | "
                f"price: {day_price}"
            )

        if not product_lines:
            # No product is priced for that day => treat the flight as not on sale.
            return f"No flight available on {date_str} for {flight_id}."

        return "\n".join([summary_line] + product_lines)

    def get_airport_coordinates(self, airport_name):
        """
        Look up an airport's coordinates by name and return a one-line summary.

        An exact (case-sensitive) name match is tried first. Otherwise every
        indexed name is scored against the query with fuzzywuzzy's
        ``fuzz.partial_ratio`` (both lowercased), and the highest-scoring name
        is used if its score is at least 80.

        Args:
            airport_name: Full or partial airport name.

        Returns:
            str: "Airport: <name> | latitude: <lat> | longitude: <lon>" on a
            match; "No airport coordinates available." when the name is empty
            or no coordinates are indexed; "No matching airport found."
            otherwise.
        """
        if not airport_name or not self.airport_coords:
            return "No airport coordinates available."

        def _summary(name):
            """Format one airport; shared so both match paths emit the same text."""
            coords = self.airport_coords[name]
            return (
                f"Airport: {name} | "
                f"latitude: {coords['latitude']} | longitude: {coords['longitude']}"
            )

        # Exact match
        if airport_name in self.airport_coords:
            return _summary(airport_name)

        # Fuzzy match: keep the first name with the strictly highest score.
        match_threshold = 80
        query = airport_name.lower()
        best_match = None
        highest_score = 0
        for name in self.airport_coords:
            score = fuzz.partial_ratio(query, name.lower())
            if score > highest_score:
                highest_score = score
                best_match = name

        if best_match is not None and highest_score >= match_threshold:
            return _summary(best_match)

        return "No matching airport found."

    

# Usage example
if __name__ == "__main__":
    flight_tool = FlightTool()

    # Search Beijing -> Changchun flights for the whole day, earliest first.
    search_result = flight_tool.search_flights(
        departure_city="Beijing",
        arrival_city="Changchun",
        date_str="2025-10-20",
        dep_period="00:00-24:00",
        sort_key="time",      # "time" or "price"
        sort_order="asc",     # "asc" or "desc"
        page=1,
        page_size=10,
    )
    print(search_result)

    # Detail view restricted to two platforms; seat_type keeps its default.
    detail = flight_tool.get_flight_detail_with_products(
        "Flight_00000001",
        "2025-10-20",
        source_platform=["ctrip", "alitrip"],
    )
    print(detail)

    # Airport-coordinate lookups by partial name (Shanghai, then Guangzhou).
    for query in ("hongqiao", "Guangzhou"):
        print(flight_tool.get_airport_coordinates(query))

