import argparse
import os
import pandas as pd
import networkx as nx
from scipy.sparse import load_npz
from torch_geometric.utils import from_scipy_sparse_matrix

def _louvain_sorted(G, resolution, seed):
    """Run Louvain community detection on ``G`` and return the communities, largest first.

    ``resolution`` and ``seed`` are forwarded to networkx; fixing ``seed`` fixes the
    partition. Communities of equal size keep the order networkx returned them in,
    because ``sorted`` is stable.
    """
    communities = nx.community.louvain_communities(G, resolution=resolution, seed=seed)
    return sorted(communities, key=len, reverse=True)


def _write_subgraph(nodes, df, edges, out_feat, out_edge):
    """Write one induced subgraph as a feature CSV plus a relabelled edge list.

    Nodes are relabelled ``0..len(nodes)-1`` in ascending order of their original id,
    so row ``k`` of the CSV corresponds to new node ``k``. ``reset_index`` keeps the
    original id as a column, so the mapping can be recovered.

    Args:
        nodes: collection of original node ids forming the subgraph.
        df: node-feature table indexed by original node id.
        edges: list of ``(u, v)`` pairs in original ids.
        out_feat: path of the feature CSV to write.
        out_edge: path of the edge-list file to write (one ``"u v"`` pair per line).

    Returns:
        ``(num_nodes, num_internal_edges)``.

    Raises:
        KeyError: raised by pandas when a node id is missing from ``df``'s index.
    """
    comm_nodes = sorted(nodes)
    old_to_new = {old: new for new, old in enumerate(comm_nodes)}

    df.loc[comm_nodes].reset_index().to_csv(out_feat, index=False)

    # Keep only edges whose two endpoints both lie inside the subgraph, in input order.
    # Membership is checked against the dict (same keys as ``nodes``), so each lookup
    # is O(1) whatever container the caller passed.
    sub_edges = [
        (old_to_new[u], old_to_new[v])
        for u, v in edges
        if u in old_to_new and v in old_to_new
    ]
    with open(out_edge, 'w') as fo:
        fo.writelines(f'{u_new} {v_new}\n' for u_new, v_new in sub_edges)

    return len(comm_nodes), len(sub_edges)


def split_topk_subgraphs(G, df, edges, base_path, dataset, topk, resolution, seed):
    """Keep the ``topk`` largest Louvain communities and write each one as its own subgraph.

    Community ``i`` (1-based, by descending size) is written to
    ``{base_path}/{dataset}_top{i}.csv`` and ``{base_path}/{dataset}_top{i}_edges.txt``.
    If fewer than ``topk`` communities exist, all of them are written.

    Raises:
        ValueError: if ``topk`` is less than 1. A negative value would otherwise
            silently drop the smallest communities instead of keeping the largest.
    """
    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")

    communities_sorted = _louvain_sorted(G, resolution, seed)
    print("所有社区节点数量（从大到小）：")
    for idx, comm in enumerate(communities_sorted):
        print(f"社区{idx}：{len(comm)} 个节点")
    print(f"检测到 {len(communities_sorted)} 个社区，保留前 {topk} 个最大社区。")

    for i, comm in enumerate(communities_sorted[:topk], start=1):
        out_feat = os.path.join(base_path, f'{dataset}_top{i}.csv')
        out_edge = os.path.join(base_path, f'{dataset}_top{i}_edges.txt')
        n_nodes, n_edges = _write_subgraph(comm, df, edges, out_feat, out_edge)
        print(f'Top{i} 子图: 新节点0~{n_nodes-1} ({n_nodes}个)，内部边{n_edges} 条 → 保存 {os.path.basename(out_feat)}, {os.path.basename(out_edge)}')


def split_two_subgraphs(G, df, edges, base_path, dataset, resolution, seed):
    """Split ``G`` into its largest Louvain community and all remaining nodes of ``G``.

    Part 1 (the largest community) and part 2 (every other node of ``G``) are written to
    ``{base_path}/{dataset}_{i}.csv`` and ``{base_path}/{dataset}_{i}_edges.txt``.
    Edges crossing the two parts are dropped.

    NOTE(review): "remaining" means remaining nodes of ``G``. Rows of ``df`` whose ids
    are not nodes of ``G`` (e.g. nodes with no edges, if ``G`` is built from ``edges``
    only) land in neither part. Confirm that this is intended.

    Raises:
        ValueError: if ``G`` has no nodes, so there is no largest community.
    """
    communities_sorted = _louvain_sorted(G, resolution, seed)
    if not communities_sorted:
        raise ValueError("graph has no nodes; cannot pick a largest community")

    largest = set(communities_sorted[0])
    rest = set(G.nodes()) - largest
    print(f"最大社区节点数量: {len(largest)}; 其余节点数量: {len(rest)}")

    for i, part in enumerate((largest, rest), start=1):
        out_feat = os.path.join(base_path, f'{dataset}_{i}.csv')
        out_edge = os.path.join(base_path, f'{dataset}_{i}_edges.txt')
        n_nodes, n_edges = _write_subgraph(part, df, edges, out_feat, out_edge)
        print(f'Split{i} 子图: 新节点0~{n_nodes-1} ({n_nodes}个)，内部边{n_edges} 条 → 保存 {os.path.basename(out_feat)}, {os.path.basename(out_edge)}')

# Dataset root used when --base_dir is not given (the path this script originally hard-coded).
_DEFAULT_DATA_ROOT = '/home/disk2/lhr/fairDomainAdaption/mine/dataset'


def _load_edges(base_path, dataset):
    """Read the edge list of ``dataset`` from the directory ``base_path``.

    ``{dataset}_edges.npz`` (a scipy sparse matrix) takes precedence over
    ``{dataset}_edges.txt``. The text file holds whitespace-separated integer pairs
    ``u v``; lines that do not split into exactly two fields are skipped.

    Returns:
        list of ``(u, v)`` int tuples. For the ``.npz`` form this presumably has one
        pair per stored matrix entry, so a symmetric adjacency yields both directions.
        Verify against the data.

    Raises:
        FileNotFoundError: if neither edge file exists.
    """
    npz_path = os.path.join(base_path, f'{dataset}_edges.npz')
    txt_path = os.path.join(base_path, f'{dataset}_edges.txt')

    if os.path.exists(npz_path):
        edge_index, _ = from_scipy_sparse_matrix(load_npz(npz_path))
        rows, cols = edge_index.cpu().numpy()
        return list(zip(rows.tolist(), cols.tolist()))

    if os.path.exists(txt_path):
        edges = []
        with open(txt_path, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) == 2:
                    edges.append((int(parts[0]), int(parts[1])))
        return edges

    raise FileNotFoundError(f"Neither .npz nor .txt edge file found in {base_path}")


def main():
    """Command-line entry point: load a dataset, detect communities, write subgraphs.

    Reads ``{base_dir}/{dataset}/{dataset}.csv`` (node features) and the dataset's edge
    list from the same directory. It builds an undirected graph from the edges and then
    calls ``split_topk_subgraphs`` or ``split_two_subgraphs``, depending on ``--mode``.
    Output files are written to that same dataset directory.
    """
    parser = argparse.ArgumentParser(description="Extract top-k communities into subgraphs")
    parser.add_argument('--dataset', type=str, required=True,
                        help="Base name of your dataset, e.g. if files are foo.csv and foo_edges.txt or foo_edges.npz, use --dataset foo")
    parser.add_argument('--resolution', type=float, default=0.8,
                        help="Louvain resolution parameter (越小社区越少)")
    parser.add_argument('--topk', type=int, default=3,
                        help="保留最大的 topk 个社区")
    parser.add_argument('--seed', type=int, default=42,
                        help="随机数种子，用于固定 Louvain 结果")
    parser.add_argument('--mode', choices=['topk','two'], default='two',
                        help="选择子图划分方式：topk（前 k 大社区）或 two（最大社区 vs 其余）")
    # Optional: the default reproduces the previously hard-coded location.
    parser.add_argument('--base_dir', type=str, default=_DEFAULT_DATA_ROOT,
                        help="Root directory containing one sub-directory per dataset")
    args = parser.parse_args()

    base_path = os.path.join(args.base_dir, args.dataset)

    # 1. Node features. A "user_id" column, if present, becomes the index; otherwise the
    #    default positional index is kept. Either way the index is named "user_id", so it
    #    turns back into a "user_id" column whenever the frame is reset_index()-ed.
    feat_csv = os.path.join(base_path, f'{args.dataset}.csv')
    df = pd.read_csv(feat_csv)
    if "user_id" in df.columns:
        df = df.set_index("user_id")
    df.index.name = 'user_id'

    # 2. Original edge list (.npz preferred over .txt).
    edges = _load_edges(base_path, args.dataset)

    # 3. The graph is built from edges only, so feature rows that never appear in an
    #    edge are not nodes of G and take no part in community detection.
    G = nx.Graph()
    G.add_edges_from(edges)

    # 4. Split according to the chosen mode.
    if args.mode == 'topk':
        split_topk_subgraphs(G, df, edges, base_path, args.dataset, args.topk, args.resolution, args.seed)
    else:
        split_two_subgraphs(G, df, edges, base_path, args.dataset, args.resolution, args.seed)

# Run the command-line entry point only when executed as a script, not when imported.
if __name__ == "__main__":
    main()