import json
import datetime
import yaml
import os
from fuzzywuzzy import fuzz

class TrainTool:
    """
    Train search tool backed by a (departure city, arrival city) secondary index.

    Speed-ups:
    - ``rebuild_index()`` builds ``self.index``:
      ``{(dep_city_norm, arr_city_norm): [train_dict, ...]}``.
    - ``search_trains()`` filters only the candidates hit in that index instead of
      scanning every train (O(N)) on each call.
    - ``self.station_coords`` maps a station name to its latitude/longitude.
    """

    # Platforms searched by get_train_detail_with_products() by default.
    # A tuple (not a list) so the shared default argument cannot be mutated.
    DEFAULT_PLATFORMS = ("ctrip", "alitrip", "qunar", "direct")

    def __init__(self, config_path=None):
        """
        Load train data and build the lookup indexes.

        Args:
            config_path: YAML config whose ``data_path.train`` entry points at the
                trains JSON file. Defaults to ``config.yaml`` in the parent
                directory of this file. If the config cannot be read or lacks the
                entry, ``<project_root>/data_final/trains.json`` is used instead.
                A relative train path is resolved against the project root (the
                grandparent directory of this file).

        Raises:
            OSError / json.JSONDecodeError: if the trains file cannot be read/parsed.
        """
        # Absolute directory of this script. abspath() guards against a relative
        # __file__, for which dirname() could otherwise return "".
        current_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.dirname(os.path.dirname(current_dir))
        if config_path is None:
            config_path = os.path.join(os.path.dirname(current_dir), "config.yaml")

        # Best-effort config read: any failure falls back to the default data path.
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                cfg = yaml.safe_load(f)
            train_path = cfg.get("data_path", {}).get("train")
            if not train_path:
                raise ValueError("Missing train path in config.yaml (data_path.train)")
        except Exception:
            train_path = os.path.join(project_root, "data_final", "trains.json")

        if not os.path.isabs(train_path):
            train_path = os.path.normpath(os.path.join(project_root, train_path))

        with open(train_path, "r", encoding="utf-8") as f:
            self.trains = json.load(f)

        # Builds both the route index and the station-coordinate index.
        self.rebuild_index()

    def _build_station_coordinates_index(self):
        """
        Build ``self.station_coords`` from the departure/arrival stations of every train.

        Format: ``{"station_name": {"latitude": float, "longitude": float}, ...}``.
        Later trains overwrite earlier entries for the same station; stations with
        missing or non-numeric coordinates are skipped.
        """
        self.station_coords = {}
        for train in self.trains:
            # Departure first, then arrival -- same overwrite order as before.
            for prefix in ("Departure", "Arrival"):
                self._add_station_coords(
                    train.get(f"{prefix} Station"),
                    train.get(f"{prefix} Station Latitude"),
                    train.get(f"{prefix} Station Longitude"),
                )

    def _add_station_coords(self, station, lat, lon):
        """Record one station's coordinates; skip missing or non-numeric values."""
        if not station or lat is None or lon is None:
            return
        try:
            coords = {"latitude": float(lat), "longitude": float(lon)}
        except (TypeError, ValueError):
            return  # malformed coordinate in the data: skip it rather than fail init
        self.station_coords[station] = coords

    # ---------- helpers ----------

    def _norm_city(self, s: str) -> str:
        """Normalise a city string: trim whitespace and lower-case (None -> "")."""
        return (s or "").strip().lower()

    @staticmethod
    def _norm_text(value) -> str:
        """Normalise a free-text field for comparison: None -> "", else str, trimmed, lower-cased."""
        return (str(value) if value is not None else "").strip().lower()

    @staticmethod
    def _to_positive_int(value, default: int) -> int:
        """Coerce an int or numeric string to an int >= 1; fall back to *default* on bad input."""
        if isinstance(value, str):
            value = value.strip()
            if not value:
                return default
        try:
            return max(1, int(value))
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _to_float(value):
        """Return *value* as a float, or None when it is missing, non-numeric or NaN."""
        if value is None:
            return None
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return None if number != number else number  # NaN is the only float != itself

    def _weekday_abbr(self, date_str: str) -> str:
        """Convert 'YYYY-MM-DD' to a weekday abbreviation (Mon/Tue/...)."""
        date_obj = datetime.datetime.strptime(date_str, "%Y-%m-%d").date()
        week_abbr = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
        return week_abbr[date_obj.weekday()]

    def _parse_hhmm(self, s: str) -> datetime.time:
        """
        Parse 'HH:MM', 'H:M', '0900', '900', ... into a ``datetime.time``.

        '24:00' is mapped to 00:00 (next-day midnight; only used for comparisons).

        Raises:
            ValueError: on empty or invalid input.
        """
        s = (s or "").strip()
        if not s:
            raise ValueError("Empty time string")
        if ":" in s:
            h_str, m_str = s.split(":", 1)
        else:
            # Accept '900' / '0900'.
            s = s.zfill(4)
            h_str, m_str = s[:2], s[2:]
        h, m = int(h_str), int(m_str)
        if h == 24 and m == 0:
            h, m = 0, 0
        if not (0 <= h <= 23 and 0 <= m <= 59):
            raise ValueError(f"Invalid time: {s}")
        return datetime.time(h, m)

    def _parse_period(self, period_str: str) -> tuple:
        """
        Parse an 'HH:MM-HH:MM' window into (start, end) ``datetime.time`` values.

        The end bound is exclusive; an end of '24:00' becomes 23:59:59.999999 so
        every time of day is included. A start of '24:00' is treated as 00:00.
        An empty or None period means the whole day.

        Raises:
            ValueError: if the window is not two '-'-separated valid times.
        """
        text = str(period_str).strip() if period_str is not None else ""
        if not text:
            text = "00:00-24:00"
        parts = [p.strip() for p in text.split("-")]
        if len(parts) != 2:
            raise ValueError(f"Invalid period (expected 'HH:MM-HH:MM'): {period_str}")
        start_s, end_s = parts
        start = self._parse_hhmm(start_s)  # _parse_hhmm maps '24:00' to 00:00
        if end_s == "24:00":
            end = datetime.time(23, 59, 59, 999999)
        else:
            end = self._parse_hhmm(end_s)
        return start, end

    # ---------- index maintenance ----------

    def rebuild_index(self):
        """
        Rebuild the (departure city, arrival city) index and the station-coordinate index.

        Call this after ``self.trains`` is modified (trains added/removed/edited).
        Trains missing either city are left out of the route index.
        """
        index = {}
        for t in self.trains:
            dep = self._norm_city(t.get("Departure City", ""))
            arr = self._norm_city(t.get("Arrival City", ""))
            if dep and arr:
                index.setdefault((dep, arr), []).append(t)
        self.index = index
        # Keep station coordinates in sync with the same train list.
        self._build_station_coordinates_index()

    # ---------- public API ----------

    def search_trains(self, departure_city, arrival_city, date_str, dep_period="00:00-24:00", arr_period="00:00-24:00",
                      sort_key="time", sort_order="asc", page=1, page_size=10):
        """
        Search trains between two cities.

        Args:
            departure_city: departure city (matched case-insensitively, trimmed).
            arrival_city: arrival city (same normalisation).
            date_str: travel date (YYYY-MM-DD). Currently not used for filtering.
            dep_period: departure window, e.g. "06:00-12:00" (end exclusive,
                "24:00" = end of day; empty = whole day).
            arr_period: arrival window, same format, e.g. "00:00-24:00".
            sort_key: "time" (departure time, default) or "price".
            sort_order: "asc" (default) or "desc". Trains without a usable price
                are always listed after priced trains when sorting by price.
            page: 1-based page number (int or numeric string; invalid -> 1).
            page_size: results per page (int or numeric string; invalid -> 10).

        Returns:
            str: newline-joined text whose first line is a summary and each later
            line one train; "No matching trains found." when nothing matches; an
            out-of-range message when *page* is past the last page.

        Raises:
            ValueError: if dep_period / arr_period is malformed.
        """
        page = self._to_positive_int(page, 1)
        page_size = self._to_positive_int(page_size, 10)

        dep_start, dep_end = self._parse_period(dep_period)
        arr_start, arr_end = self._parse_period(arr_period)

        # Candidates come straight from the index (same normalisation as rebuild_index).
        route_key = (self._norm_city(departure_city), self._norm_city(arrival_city))
        candidate_trains = self.index.get(route_key, [])

        rows = []
        for t in candidate_trains:
            try:
                dep_time = self._parse_hhmm(str(t.get("Departure Time", "")))
                arr_time = self._parse_hhmm(str(t.get("Arrival Time", "")))
            except ValueError:
                continue  # skip trains with malformed times

            # Windows are half-open: start <= time < end.
            if not (dep_start <= dep_time < dep_end):
                continue
            if not (arr_start <= arr_time < arr_end):
                continue

            rows.append({
                "_dep_time": dep_time,
                "_price": t.get("Price"),                     # raw value, shown as-is
                "_price_value": self._to_float(t.get("Price")),  # numeric value for sorting
                "data": t,
            })

        if not rows:
            return "No matching trains found."

        # --- sorting ---
        reverse = str(sort_order).strip().lower() == "desc"
        if str(sort_key).strip().lower() == "price":
            priced = [r for r in rows if r["_price_value"] is not None]
            unpriced = [r for r in rows if r["_price_value"] is None]
            priced.sort(key=lambda r: r["_price_value"], reverse=reverse)
            rows = priced + unpriced  # missing prices never float to the top
        else:  # default: departure time
            rows.sort(key=lambda r: r["_dep_time"], reverse=reverse)

        # --- pagination ---
        total = len(rows)
        start_idx = (page - 1) * page_size
        if start_idx >= total:
            last_page = (total + page_size - 1) // page_size
            return f"Page {page} is out of range: {total} results across {last_page} page(s)."
        end_idx = min(start_idx + page_size, total)

        # --- formatting ---
        formatted = []
        for r in rows[start_idx:end_idx]:
            t = r["data"]
            formatted.append(
                f"Train_id: {t.get('Train_id')} | {t.get('Train Number')} | "
                f"{t.get('Departure Time')}-{t.get('Arrival Time')} | "
                f"{t.get('Departure Station')} -> {t.get('Arrival Station')} | "
                f"minimum price: {r['_price']}"
            )

        summary = f"Showing {start_idx + 1}-{end_idx} of {total} results."
        return "\n".join([summary] + formatted)

    @classmethod
    def _platform_filter(cls, source_platform):
        """
        Turn the caller's platform filter into a set of normalised names, or None for "no filter".

        Accepts a list/tuple/set of names, a JSON array string ('["ctrip", "qunar"]'),
        a JSON string literal, or a plain comma-separated string ("ctrip,qunar").
        None, an empty value, JSON null, or a value containing "all" disables filtering.
        """
        if source_platform is None:
            return None
        if isinstance(source_platform, str):
            text = source_platform.strip()
            if not text:
                return None
            try:
                parsed = json.loads(text)
            except ValueError:  # not JSON: treat as comma-separated names
                parsed = text.split(",")
            if parsed is None:
                return None
            # A JSON string literal is one platform, not a sequence of characters.
            source_platform = [parsed] if isinstance(parsed, str) else parsed
        try:
            items = list(source_platform)
        except TypeError:
            items = [source_platform]
        names = {cls._norm_text(p) for p in items}
        if not names or "all" in names:
            return None
        return names

    def get_train_detail_with_products(self, train_id, date_str, source_platform=DEFAULT_PLATFORMS, seat_type="Second class"):
        """
        Describe one train and its bookable products.

        Args:
            train_id: value compared (as str) against each train's "Train_id".
            date_str: travel date; currently not used.
            source_platform: platforms to include (see ``_platform_filter`` for the
                accepted forms). Defaults to ctrip/alitrip/qunar/direct; None, an
                empty value or "all" includes every platform.
            seat_type: seat class to match (case-insensitive, trimmed).

        Returns:
            - str: first line is a search-style summary, each later line one
              matching product (product_id, seat_type, source_platform, price);
            - str "No matching products found for train {train_id}." when no
              product passes the filters;
            - list ``["Train not found: {train_id}"]`` when the train does not
              exist (kept as a list for backward compatibility).
        """
        platforms = self._platform_filter(source_platform)

        # --- find the train ---
        wanted_id = str(train_id)
        train = next(
            (t for t in getattr(self, "trains", []) if str(t.get("Train_id")) == wanted_id),
            None,
        )
        if not train:
            return [f"Train not found: {train_id}"]

        # --- summary line ---
        summary_line = (
            f"Train_id: {train.get('Train_id')} | {train.get('Train Number')} | "
            f"{train.get('Departure Time')}-{train.get('Arrival Time')} | "
            f"{train.get('Departure Station')} -> {train.get('Arrival Station')}"
        )

        # --- product filtering: seat must match; platform must match unless unfiltered ---
        seat_norm = self._norm_text(seat_type)
        product_lines = [
            f"product_id: {p.get('product_id')} | {p.get('seat_type')} | {p.get('source_platform')} | "
            f"price: {p.get('price')} |"
            for p in (train.get("Product") or [])
            if self._norm_text(p.get("seat_type")) == seat_norm
            and (platforms is None or self._norm_text(p.get("source_platform")) in platforms)
        ]

        if not product_lines:
            return f"No matching products found for train {train_id}."

        return "\n".join([summary_line] + product_lines)

    def _format_station(self, name):
        """Render one indexed station in the summary style used by get_station_coordinates()."""
        coords = self.station_coords[name]
        return (
            f"Station: {name} | "
            f"latitude: {coords['latitude']} | longitude: {coords['longitude']}"
        )

    def get_station_coordinates(self, station_name, threshold=80):
        """
        Look up a station's coordinates, falling back to fuzzy matching.

        Matching order: exact name, then case-insensitive exact name, then the
        station with the highest ``fuzz.partial_ratio`` score (first wins on
        ties) provided that score is at least *threshold* (0-100).

        Returns:
            str: "Station: <name> | latitude: <lat> | longitude: <lon>";
            "No station coordinates available." for an empty query or an empty
            index; otherwise "No matching station found.".
        """
        query = str(station_name).strip() if station_name is not None else ""
        if not query or not self.station_coords:
            return "No station coordinates available."

        if query in self.station_coords:
            return self._format_station(query)

        # Case-insensitive exact match beats any fuzzy substring match.
        query_lower = query.lower()
        for name in self.station_coords:
            if name.lower() == query_lower:
                return self._format_station(name)

        best_match, highest_score = None, 0
        for name in self.station_coords:
            score = fuzz.partial_ratio(query_lower, name.lower())
            if score > highest_score:
                best_match, highest_score = name, score

        if best_match is not None and highest_score >= threshold:
            return self._format_station(best_match)

        return "No matching station found."


# 使用示例
if __name__ == "__main__":
    # Manual smoke test; only runs when this file is executed directly.
    demo_tool = TrainTool()
    demo_date = "2025-10-20"

    # 1) City-to-city search, earliest departures first.
    search_result = demo_tool.search_trains(
        "Beijing", "Shanghai", demo_date,
        dep_period="00:00-24:00",
        sort_key="time",   # "time" or "price"
        sort_order="asc",  # "asc" or "desc"
        page=1,
        page_size=10,
    )
    print(search_result)

    # 2) Train detail together with its bookable products.
    print("测试列车详情")
    detail = demo_tool.get_train_detail_with_products(
        train_id="Train_00000001",
        date_str=demo_date,
        source_platform=["ctrip", "alitrip"],
    )
    print(detail)

    # 3) Station coordinate lookup (Beijing South station, Tianjin station).
    for station in ("Beijing South Railway Station", "Tianjin Station"):
        print(demo_tool.get_station_coordinates(station))

# City-centre reference coordinates keyed by English city name (alphabetical).
# Each entry holds a longitude ("lon") and latitude ("lat") in decimal degrees.
# Note: these keys differ from TrainTool.station_coords ("latitude"/"longitude").
city_center_coords = {
    "Beijing": {"lon": 116.407387, "lat": 39.904179},
    "Changchun": {"lon": 125.323643, "lat": 43.816996},
    "Changsha": {"lon": 112.938882, "lat": 28.228304},
    "Chengdu": {"lon": 104.066301, "lat": 30.572961},
    "Chongqing": {"lon": 106.551787, "lat": 29.56268},
    "Dalian": {"lon": 121.614786, "lat": 38.913962},
    "Fuzhou": {"lon": 119.296411, "lat": 26.074286},
    "Guangzhou": {"lon": 113.264499, "lat": 23.130061},
    "Guilin": {"lon": 110.179752, "lat": 25.235615},
    "Guiyang": {"lon": 106.628201, "lat": 26.646694},
    "Haikou": {"lon": 110.200162, "lat": 20.046316},
    "Hangzhou": {"lon": 120.209903, "lat": 30.246566},
    "Harbin": {"lon": 126.53505, "lat": 45.802981},
    "Hong Kong": {"lon": 114.170714, "lat": 22.278354},
    "Jinan": {"lon": 117.120128, "lat": 36.652069},
    "Kaifeng": {"lon": 114.314278, "lat": 34.798083},
    "Kunming": {"lon": 102.833669, "lat": 24.88149},
    "Lijiang": {"lon": 100.225936, "lat": 26.855165},
    "Luoyang": {"lon": 112.453895, "lat": 34.619702},
    "Nanchang": {"lon": 115.857972, "lat": 28.682976},
    "Nanjing": {"lon": 118.796624, "lat": 32.059344},
    "Nanning": {"lon": 108.366407, "lat": 22.8177},
    "Ningbo": {"lon": 121.62454, "lat": 29.860258},
    "Qingdao": {"lon": 120.382665, "lat": 36.066938},
    "Sanya": {"lon": 109.511709, "lat": 18.252865},
    "Shanghai": {"lon": 121.473667, "lat": 31.230525},
    "Shenyang": {"lon": 123.464675, "lat": 41.677576},
    "Shenzhen": {"lon": 114.057939, "lat": 22.543527},
    "Suzhou": {"lon": 120.585294, "lat": 31.299758},
    "Taiyuan": {"lon": 112.549656, "lat": 37.870451},
    "Tianjin": {"lon": 117.201509, "lat": 39.085318},
    "Weihai": {"lon": 122.120519, "lat": 37.513315},
    "Wuhan": {"lon": 114.304569, "lat": 30.593354},
    "Wuxi": {"lon": 120.311889, "lat": 31.491064},
    "Xi'an": {"lon": 108.939645, "lat": 34.343207},
    "Xiamen": {"lon": 118.08891, "lat": 24.479627},
    "Xishuangbanna": {"lon": 100.797002, "lat": 22.009037},
    "Yantai": {"lon": 121.447755, "lat": 37.464551},
    "Zhengzhou": {"lon": 113.625351, "lat": 34.746303},
    "Zhuhai": {"lon": 113.576892, "lat": 22.271644},
}
