from datetime import datetime, timedelta
from geopy.distance import geodesic


# City-center coordinates as (display name, longitude, latitude) rows.
_CITY_CENTER_TABLE = (
    ("Beijing", 116.407387, 39.904179),
    ("Changchun", 125.323643, 43.816996),
    ("Changsha", 112.938882, 28.228304),
    ("Chengdu", 104.066301, 30.572961),
    ("Chongqing", 106.551787, 29.56268),
    ("Dalian", 121.614786, 38.913962),
    ("Fuzhou", 119.296411, 26.074286),
    ("Guangzhou", 113.264499, 23.130061),
    ("Guilin", 110.179752, 25.235615),
    ("Guiyang", 106.628201, 26.646694),
    ("Haikou", 110.200162, 20.046316),
    ("Hangzhou", 120.209903, 30.246566),
    ("Harbin", 126.53505, 45.802981),
    ("Hong Kong", 114.170714, 22.278354),
    ("Jinan", 117.120128, 36.652069),
    ("Kaifeng", 114.314278, 34.798083),
    ("Kunming", 102.833669, 24.88149),
    ("Lijiang", 100.225936, 26.855165),
    ("Luoyang", 112.453895, 34.619702),
    ("Nanchang", 115.857972, 28.682976),
    ("Nanjing", 118.796624, 32.059344),
    ("Nanning", 108.366407, 22.8177),
    ("Ningbo", 121.62454, 29.860258),
    ("Qingdao", 120.382665, 36.066938),
    ("Sanya", 109.511709, 18.252865),
    ("Shanghai", 121.473667, 31.230525),
    ("Shenyang", 123.464675, 41.677576),
    ("Shenzhen", 114.057939, 22.543527),
    ("Suzhou", 120.585294, 31.299758),
    ("Taiyuan", 112.549656, 37.870451),
    ("Tianjin", 117.201509, 39.085318),
    ("Weihai", 122.120519, 37.513315),
    ("Wuhan", 114.304569, 30.593354),
    ("Wuxi", 120.311889, 31.491064),
    ("Xi'an", 108.939645, 34.343207),
    ("Xiamen", 118.08891, 24.479627),
    ("Xishuangbanna", 100.797002, 22.009037),
    ("Yantai", 121.447755, 37.464551),
    ("Zhengzhou", 113.625351, 34.746303),
    ("Zhuhai", 113.576892, 22.271644),
)

# Public lookup table: lower-cased city name -> {"lon": ..., "lat": ...}.
# Keys are lower-cased so callers can match names case-insensitively by
# lower-casing their query; table order is preserved.
city_center_coords = {
    name.lower(): {"lon": lon, "lat": lat}
    for name, lon, lat in _CITY_CENTER_TABLE
}


class GeneralTool:
    """
    Travel-time, city-coordinate and date helper tools.

    Travel time uses a progressive (tiered) speed model:
    - the straight-line geodesic distance is computed with geopy;
    - the distance is split into consecutive tiers, each travelled at its own
      speed, so the estimate grows continuously with distance instead of
      jumping at tier boundaries;
    - times are reported in minutes.
    """

    def __init__(self):
        # (cumulative tier upper bound in km, speed in km/h); upper bounds
        # must be in ascending order.
        self.speed_rules = [
            (1, 5),       # walking
            (10, 30),     # city streets
            (50, 60),     # urban expressway
            (99999, 90)   # highway
        ]

    def calculate_distance_km(self, origin_lat, origin_lng, destination_lat, destination_lng):
        """Return the geodesic distance in km between two (lat, lng) points."""
        origin = (origin_lat, origin_lng)
        destination = (destination_lat, destination_lng)
        return geodesic(origin, destination).km

    def _travel_hours(self, distance_km):
        """Return the travel time in hours for *distance_km* under ``speed_rules``.

        Each tier covers the distance between the previous tier's upper bound
        and its own, at that tier's speed. Any distance beyond the last upper
        bound is charged at the last tier's speed instead of being dropped.
        With no rules configured the result is 0.
        """
        remaining = distance_km
        last_upper = 0
        total_hours = 0

        for upper, speed in self.speed_rules:
            if remaining <= 0:
                break
            segment_length = min(remaining, upper - last_upper)
            total_hours += segment_length / speed
            remaining -= segment_length
            last_upper = upper

        # Only reachable when the configured rules end below the distance;
        # keep moving at the final tier's speed so time is not underestimated.
        if remaining > 0 and self.speed_rules:
            total_hours += remaining / self.speed_rules[-1][1]

        return total_hours

    def estimate_travel_time_minutes(self, origin_lat, origin_lng, destination_lat, destination_lng):
        """Estimate travel time between two points with the tiered speed model.

        :return: tuple ``(minutes, distance_km)``. Both values are returned,
            not only the minutes the name suggests; callers rely on the pair.
        """
        distance_km = self.calculate_distance_km(
            origin_lat, origin_lng,
            destination_lat, destination_lng
        )
        return self._travel_hours(distance_km) * 60, distance_km

    # Public interface

    def get_route_estimate(self, origin_lat, origin_lng, destination_lat, destination_lng):
        """Return a human-readable distance and travel-time summary string."""
        minutes, dist = self.estimate_travel_time_minutes(
            origin_lat, origin_lng, destination_lat, destination_lng
        )
        return f"distance: {dist:.2f} km, estimated travel time: {minutes:.0f} min"

    def get_city_center_coords(self, city_name):
        """Look up a city's center coordinates by name.

        Matching ignores letter case and surrounding whitespace.

        :param city_name: city name, e.g. "Shanghai".
        :return: a string with the longitude and latitude, or a "not found"
            message when the city is not in the table.
        """
        result = city_center_coords.get(city_name.strip().lower())
        if result:
            return f"longitude: {result['lon']}, latitude: {result['lat']}"
        else:
            return f"No city center coordinates found for the given city name: {city_name}."

    def get_date_after(self, date_str, days):
        """Return the date that is *days* days after *date_str*.

        :param date_str: input date string in YYYY-MM-DD form, e.g. "2025-03-01";
            surrounding whitespace is ignored.
        :param days: number of days to add (anything ``int()`` accepts, such as
            ``10`` or ``"10"``; negative values move backwards).
        :return: the resulting date as a YYYY-MM-DD string.
        :raises ValueError: if *date_str* is not YYYY-MM-DD or *days* cannot be
            parsed as an integer.
        """
        days = int(days)
        date = datetime.strptime(date_str.strip(), "%Y-%m-%d")
        new_date = date + timedelta(days=days)
        return new_date.strftime("%Y-%m-%d")
    
# ------------------------------------
# Usage demo (runs when this file is executed directly as a script)
# ------------------------------------

if __name__ == "__main__":
    tool = GeneralTool()

    demos = (
        # Route estimate between two arbitrary points (both in Shanghai).
        lambda: tool.get_route_estimate(
            origin_lat=31.2304, origin_lng=121.4737,
            destination_lat=31.2243, destination_lng=121.4768
        ),
        # City-center lookup (Shanghai).
        lambda: tool.get_city_center_coords("Shanghai"),
        # Date arithmetic: 2025-03-01 plus 10 days.
        lambda: tool.get_date_after("2025-03-01", 10),
    )

    for demo in demos:
        print(demo())