import pandas as pd
import numpy as np
from sklearn.neighbors import NearestNeighbors
import networkx as nx
import os
# Resolve paths relative to this script so the working directory does not matter.
script_path = os.path.abspath(__file__)
current_dir = os.path.dirname(script_path)

# ===== 1. Load the data =====
attributes_dir = os.path.join(current_dir, "../dataset/camels/camels_attributes_v2.0/")
topo_data_path = os.path.join(attributes_dir, "camels_topo.txt")

read_options = {
    "sep": ';',          # fields in the file are separated by semicolons
    "header": 0,         # the first row holds the column names
    "engine": 'python',
    "index_col": 0,      # first column becomes the index (presumably gauge_id -- confirm)
    "dtype": str,        # read everything as text so leading zeros are not lost
}
df = pd.read_csv(topo_data_path, **read_options)

# ========= 2. Extract the gauge coordinates =========
# Columns were loaded as strings, so cast to float. Each row is [lat, lon].
coords = df[['gauge_lat', 'gauge_lon']].astype(float).values

# ========= 3. Build the KNN adjacency matrix =========
# The haversine metric works on [latitude, longitude] pairs given in radians;
# np.radians treats the input values as degrees.
coords_rad = np.radians(coords)
nn_model = NearestNeighbors(n_neighbors=2, metric='haversine').fit(coords_rad)

# The query set is the fitted set itself, so a station can show up in its own
# neighbour list; those self matches are filtered out when filling the matrix.
# NOTE(review): with n_neighbors=2 this presumably leaves a single real
# neighbour per station -- confirm that this is the intended k.
neighbor_dist, neighbor_idx = nn_model.kneighbors(coords_rad)

# Convert to kilometres (6371 is presumably the Earth's mean radius in km).
# Not referenced again in the code visible here.
distances_km = neighbor_dist * 6371

n = len(df)
adj_matrix = np.zeros((n, n))

# Fill the adjacency matrix: entry [i, j] is 1 when j is listed as a neighbour
# of i and j is not i itself. `src` is the station row, `slot` the position
# inside that station's neighbour list.
station_pos = np.arange(n)[:, None]
src, slot = np.nonzero(neighbor_idx != station_pos)
adj_matrix[src, neighbor_idx[src, slot]] = 1
# The matrix stays directed: the mirrored write (adj_matrix[j, i] = 1) that
# would symmetrise it is disabled in this version.

# ========= 4. Save the adjacency matrix as CSV =========
# Move the index back into a regular column so it can be accessed as 'gauge_id'.
df = df.reset_index()

# Force the ids to text and left-pad with zeros to a width of 8 characters.
gauge_id_text = df['gauge_id'].astype(str)
station_ids = gauge_id_text.str.zfill(8)

# Label both the rows and the columns with the padded station ids.
adj_df = pd.DataFrame(data=adj_matrix, index=station_ids, columns=station_ids)

# Written relative to the current working directory, with the row labels kept.
adj_df.to_csv("camels_knn_adj_2.csv", index=True)
print("✅ 邻接矩阵已保存：camels_knn_adj_2.csv")

# # ==== 6. 构建 networkx 图对象 ====
# G = nx.Graph()
# # 添加节点
# for i, row in df.iterrows():
#     G.add_node(i, station_id=row['gauge_id'], lat=row['gauge_lat'], lon=row['gauge_lon'])

# # 添加边（从邻接矩阵）
# for i in range(n):
#     for j in range(i + 1, n):
#         if adj_matrix[i, j] > 0:
#             G.add_edge(i, j)

# # ==== 7. 保存 networkx 图 ====
# nx.write_graphml(G, "camel_knn_graph.graphml")      # XML格式，可Gephi等工具加载

# print("邻接矩阵和图文件已保存。")