import copy
from collections import deque

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, NeighborSampler
from torch_geometric.utils import subgraph, to_undirected, remove_isolated_nodes, dropout_adj, remove_self_loops, k_hop_subgraph, to_edge_index
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_sparse import SparseTensor


def add_remaining_selfloop_for_isolated_nodes(edge_index, num_nodes):
    """Append a self-loop ``(i, i)`` for every node that appears in no edge.

    Nodes that already occur in ``edge_index`` (as source or target) get no
    extra edge; existing edges, including existing self-loops, are kept.

    Args:
        edge_index: dense ``(2, E)`` tensor of edges.
        num_nodes: expected node count. The count actually used is the larger
            of this value and ``maybe_num_nodes(edge_index)``, so every node
            referenced by ``edge_index`` is in range.

    Returns:
        ``(2, E + K)`` tensor: the original edges followed by one self-loop for
        each of the ``K`` isolated nodes, in ascending node order.
    """
    num_nodes = max(maybe_num_nodes(edge_index), num_nodes)
    device = edge_index.device
    # The mask must live on edge_index's device: writing into a CPU mask with
    # CUDA indices raises a device-mismatch error.
    isolated = torch.ones(num_nodes, dtype=torch.bool, device=device)
    isolated[edge_index.reshape(-1)] = False
    loops = torch.arange(num_nodes, dtype=torch.long, device=device)[isolated]
    return torch.cat([edge_index, loops.unsqueeze(0).repeat(2, 1)], dim=1)
    
def collect_subgraphs(selected_id, graph, walk_steps=20, restart_ratio=0.5):
    """Sample one random-walk-with-restart subgraph per selected node.

    For each start node a walk of ``walk_steps`` steps is run. At every step
    the walker moves to a neighbour drawn by ``SparseTensor.sample`` or, with
    probability ``restart_ratio``, jumps back to its start node. The visited
    nodes form the subgraph. The traversed edges, excluding restart jumps, are
    made undirected.

    Args:
        selected_id: 1-D tensor of start-node indices.
        graph: PyG ``Data`` with ``x``, ``y`` and ``edge_index``. ``edge_index``
            is either a ``(2, E)`` tensor or a ``SparseTensor`` that is used
            directly as the adjacency to sample from.
        walk_steps: number of walk steps per start node.
        restart_ratio: probability of restarting at each step.

    Returns:
        list of ``Data`` views built by ``adjust_idx``, one per start node, in
        the order of ``selected_id``.
    """
    graph = copy.deepcopy(graph)  # views are built from a private copy
    edge_index = graph.edge_index
    node_num = graph.x.shape[0]
    start_nodes = selected_id  # only the selected nodes seed subgraphs
    graph_num = start_nodes.shape[0]

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
    else:
        value = torch.arange(edge_index.size(1))
        adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                             value=value,
                             sparse_sizes=(node_num, node_num)).t()

    # NOTE(review): the sampler does not appear to special-case nodes without
    # neighbours; consider applying add_remaining_selfloop_for_isolated_nodes
    # to edge_index first if the graph can contain isolated nodes.
    current_nodes = start_nodes.clone()
    history = [start_nodes.clone()]
    signs = [torch.ones(graph_num, dtype=torch.bool)]
    for _ in range(walk_steps):
        seed = torch.rand([graph_num])
        # sample() returns shape (graph_num, 1). Flatten it explicitly: a bare
        # squeeze() turns a single start node into a 0-d tensor, which breaks
        # the masked assignment and the concatenation with the history.
        nei = adj_t.sample(1, current_nodes).view(-1)
        sign = seed < restart_ratio
        nei[sign] = start_nodes[sign]
        history.append(nei)
        signs.append(sign)
        current_nodes = nei
    history = torch.stack(history, dim=1)  # (graph_num, walk_steps + 1)
    signs = torch.stack(signs, dim=1)

    graph_list = []
    for path, sign in zip(history, signs):
        node_idx = path.unique()
        sub_edges = torch.stack([path[:-1], path[1:]]).long()
        # sign[t] marks step t as a restart jump, which is not a real edge.
        sub_edges = sub_edges[:, ~sign[1:]]
        if sub_edges.shape[1] != 0:
            sub_edges = to_undirected(sub_edges)
        graph_list.append(adjust_idx(sub_edges, node_idx, graph, path[0].item()))
    return graph_list

def constrained_collect_subgraphs(selected_id, graph, walk_steps=20, restart_ratio=0.5):
    """Random-walk-with-restart subgraphs confined to each start's 2-hop area.

    For each start node a walk of ``walk_steps`` steps is run. At every step
    the walker moves to a neighbour drawn by ``SparseTensor.sample``. The step
    becomes a restart (a jump back to the start node) in two cases: with
    probability ``restart_ratio``, or whenever the sampled node lies outside
    the start node's 2-hop neighbourhood. The visited nodes form the subgraph.
    The traversed edges, excluding restart jumps, are made undirected.

    Args:
        selected_id: 1-D tensor of start-node indices.
        graph: PyG ``Data`` with ``x``, ``y`` and a dense ``(2, E)``
            ``edge_index`` (needed by ``k_hop_subgraph``).
        walk_steps: number of walk steps per start node.
        restart_ratio: probability of restarting at each step.

    Returns:
        list of ``Data`` views built by ``adjust_idx``, one per start node, in
        the order of ``selected_id``.
    """
    graph = copy.deepcopy(graph)  # views are built from a private copy
    edge_index = graph.edge_index
    node_num = graph.x.shape[0]
    start_nodes = selected_id  # only the selected nodes seed subgraphs
    graph_num = start_nodes.shape[0]

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
    else:
        value = torch.arange(edge_index.size(1))
        adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                             value=value,
                             sparse_sizes=(node_num, node_num)).t()

    # Allowed node set of each walk: the 2-hop neighbourhood of its start node
    # (always containing the start node itself).
    # The node is passed as a Python int because iterating selected_id yields
    # 0-d tensors, which k_hop_subgraph (PyG 2.x) cannot torch.cat with its
    # 1-D hop results. num_nodes is passed so an isolated node whose id is
    # above edge_index.max() gets {node} instead of an out-of-bounds error.
    allowed = [
        set(k_hop_subgraph(int(node), 2, edge_index, num_nodes=node_num)[0].tolist())
        for node in selected_id
    ]

    current_nodes = start_nodes.clone()
    history = [start_nodes.clone()]
    signs = [torch.ones(graph_num, dtype=torch.bool)]
    for _ in range(walk_steps):
        seed = torch.rand([graph_num])
        # sample() returns shape (graph_num, 1). Flatten it explicitly: a bare
        # squeeze() turns a single start node into a 0-d tensor.
        nei = adj_t.sample(1, current_nodes).view(-1)
        sign = seed < restart_ratio
        # A step that would leave the allowed region is also a restart.
        sign |= torch.tensor(
            [n not in allowed[j] for j, n in enumerate(nei.tolist())],
            dtype=torch.bool,
        )
        nei[sign] = start_nodes[sign]
        history.append(nei)
        signs.append(sign)
        current_nodes = nei
    history = torch.stack(history, dim=1)  # (graph_num, walk_steps + 1)
    signs = torch.stack(signs, dim=1)

    graph_list = []
    for path, sign in zip(history, signs):
        node_idx = path.unique()
        sub_edges = torch.stack([path[:-1], path[1:]]).long()
        # sign[t] marks step t as a restart jump, which is not a real edge.
        sub_edges = sub_edges[:, ~sign[1:]]
        if sub_edges.shape[1] != 0:
            sub_edges = to_undirected(sub_edges)
        graph_list.append(adjust_idx(sub_edges, node_idx, graph, path[0].item()))
    return graph_list

# def constrained_ego_graphs_sampler(selected_id, graph, nodeNum=40, restart_ratio=0.5):
#     graph = copy.deepcopy(graph)  # modified on the copy
#     edge_index = graph.edge_index
#     node_num = graph.x.shape[0]
#     start_nodes = selected_id  # only sampling selected nodes as subgraphs
#     graph_num = start_nodes.shape[0]

#     value = torch.arange(edge_index.size(1))

#     if type(edge_index) == SparseTensor:
#         adj_t = edge_index
#     else:
#         adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
#                              value=value,
#                              sparse_sizes=(node_num, node_num)).t()
    
#     # Directly get one-hop neighbors for each starting node.
#     nei = adj_t.sample(1, start_nodes).squeeze()
    
#     # For simplicity, we do not apply restarts here since we are only interested in direct neighbors.
#     history = torch.stack((start_nodes, nei), dim=0)
#     signs = torch.zeros_like(history, dtype=torch.bool)  # No restarts means no signs needed.

#     history = history.T
#     signs = signs.T

#     graph_list = []
#     for i in range(graph_num):
#         path = history[i]
#         sign = signs[i]
#         node_idx = path.unique()
#         sources = path[:-1].numpy().tolist()
#         targets = path[1:].numpy().tolist()
#         sub_edges = torch.IntTensor([sources, targets]).long()
#         sub_edges = sub_edges.T[~sign[1:]].T
#         # Assuming to_undirected and adjust_idx functions exist and work properly
#         if sub_edges.shape[1] != 0:
#             sub_edges = to_undirected(sub_edges)
#         view = adjust_idx(sub_edges, node_idx, graph, path[0].item())

#         graph_list.append(view)
#     return graph_list

def constrained_ego_graphs_sampler(selected_id, graph, nodeNum=40):
    """Build a size-capped ego graph of at most 2 hops around each selected node.

    Nodes are collected breadth-first from the centre, starting with its
    direct neighbours. 2-hop neighbours are then taken round-robin: each 1-hop
    node contributes one new 2-hop node and is sent to the back of the queue.
    Collection stops once ``nodeNum`` nodes are gathered or the 2-hop
    neighbourhood is exhausted. The subgraph keeps every adjacency entry
    between collected nodes, made undirected.

    Args:
        selected_id: 1-D tensor of centre-node indices.
        graph: PyG ``Data`` with ``x``, ``y`` and ``edge_index``. ``edge_index``
            is either a ``(2, E)`` tensor or a ``SparseTensor`` used directly
            as the adjacency.
        nodeNum: maximum number of nodes per ego graph, centre included.

    Returns:
        list of ``Data`` views built by ``adjust_idx``, one per centre node.
    """
    graph = copy.deepcopy(graph)  # views are built from a private copy
    edge_index = graph.edge_index
    node_num = graph.x.shape[0]

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
    else:
        value = torch.arange(edge_index.size(1))
        adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                             value=value,
                             sparse_sizes=(node_num, node_num)).t()

    # CSR layout: the neighbours of v are col[rowptr[v]:rowptr[v + 1]]. These
    # are the same entries, in the same order, as the mask col[row == v], but
    # are found in O(deg(v)) instead of an O(E) scan per lookup.
    rowptr, col, _ = adj_t.csr()

    def neighbors_of(node):
        """Return the adjacency-row entries of ``node`` as Python ints."""
        return col[rowptr[node].item():rowptr[node + 1].item()].tolist()

    ego_graphs = []
    for index in selected_id:
        center = index.item()
        node_idx = [center]
        distances = {center: 0}
        queue = deque([center])
        while queue and len(node_idx) < nodeNum:
            current = queue.popleft()
            if distances[current] == 2:
                continue  # 2-hop nodes are never expanded
            for neighbor in neighbors_of(current):
                if neighbor == current or neighbor in distances:
                    continue
                distances[neighbor] = distances[current] + 1
                queue.append(neighbor)
                node_idx.append(neighbor)
                if len(node_idx) == nodeNum:
                    break
                if distances[neighbor] == 2:
                    # Round-robin over the 1-hop nodes: take one new 2-hop
                    # node, then requeue `current` behind the others.
                    queue.append(current)
                    break

        # Keep every adjacency entry whose two endpoints were both collected.
        members = set(node_idx)
        sources, targets = [], []
        for source_node in node_idx:
            for target_node in neighbors_of(source_node):
                if target_node in members:
                    sources.append(source_node)
                    targets.append(target_node)

        sub_edges = torch.tensor([sources, targets], dtype=torch.long)
        if sub_edges.shape[1] != 0:
            sub_edges = to_undirected(sub_edges)
        view = adjust_idx(sub_edges, torch.tensor(node_idx), graph, center)
        ego_graphs.append(view)

    return ego_graphs

# def constrained_ego_graphs_sampler(selected_id, graph, nodeNum=40):
#     edge_index = graph.edge_index 
#     node_num = graph.x.shape[0] 
    
#     # Build the transposed sparse matrix (CSR format)
#     adj_t = SparseTensor(row=edge_index[0], col=edge_index[1], 
#                         sparse_sizes=(node_num, node_num)).t()
    
#     ego_graphs = []
#     for index in selected_id:
#         index = index.item() 
#         node_idx = [index]
#         distances = {index: 0}
#         queue = [index]
        
#         # BFS sampling (limited to two hops)
#         while queue and len(node_idx) < nodeNum:
#             current = queue.pop(0) 
#             row_start = adj_t.storage.rowptr()[current] 
#             row_end = adj_t.storage.rowptr()[current  + 1]
#             neighbors = adj_t.storage.col()[row_start:row_end] 
            
#             for neighbor in neighbors.tolist(): 
#                 if neighbor not in distances:
#                     distances[neighbor] = distances[current] + 1 
#                     if distances[neighbor] <= 2:  # limit to two hops
#                         node_idx.append(neighbor) 
#                         queue.append(neighbor) 
#                         if len(node_idx) == nodeNum:
#                             break
 
#         # Build the subgraph edges
#         row, col = adj_t.storage.row(),  adj_t.storage.col() 
#         mask = torch.isin(row,  torch.tensor(node_idx))  & torch.isin(col,  torch.tensor(node_idx)) 
#         sub_edges = torch.stack([row[mask],  col[mask]])
 
#         # Make the graph undirected
#         if sub_edges.size(1)  > 0:
#             sub_edges = to_undirected(sub_edges)
        
#         # Re-index the nodes
#         node_idx = torch.tensor(node_idx) 
#         view = adjust_idx(sub_edges, node_idx, graph, index)
#         ego_graphs.append(view) 
 
#     return ego_graphs
        

        
                


        
def adjust_idx(edge_index, node_idx, full_g, center_idx):
    '''Re-index a subgraph to local node ids and wrap it in a ``Data`` view.

    A subgraph keeps only some nodes of the full graph. Its ``edge_index``,
    still in global node ids, is rewritten so that each id becomes that node's
    position in ``node_idx``. This keeps the edges aligned with the rows of the
    gathered features.

    Args:
        edge_index: ``(2, E)`` tensor of edges in global node ids; every id
            must occur in ``node_idx``.
        node_idx: 1-D tensor of the kept global ids; entry ``i`` becomes local
            node ``i``.
        full_g: full graph providing ``x`` and ``y``.
        center_idx: global id (Python int) of the subgraph's centre node.

    Returns:
        ``Data`` containing:

        * the re-indexed ``edge_index``;
        * ``x=full_g.x[node_idx]``;
        * ``original_idx=node_idx``;
        * ``y=full_g.y[center_idx]``;
        * the centre's local id, stored in both ``center`` and
          ``root_n_index``.

    Raises:
        KeyError: if ``center_idx`` does not occur in ``node_idx``.
    '''
    # .tolist() works for CPU and CUDA tensors alike; .numpy() is CPU-only.
    node_idx_map = {j: i for i, j in enumerate(node_idx.tolist())}
    sources_idx = list(map(node_idx_map.get, edge_index[0].tolist()))
    target_idx = list(map(node_idx_map.get, edge_index[1].tolist()))
    local_edge_index = torch.tensor([sources_idx, target_idx], dtype=torch.long)
    center = node_idx_map[center_idx]
    # Variant that also carries an attention mask:
    # x_view = Data(edge_index=edge_index, x=full_g.x[node_idx], attention_mask=full_g.attention_mask[node_idx], center=node_idx_map[center_idx], original_idx=node_idx, y=full_g.y[center_idx], root_n_index=node_idx_map[center_idx])
    return Data(edge_index=local_edge_index, x=full_g.x[node_idx], center=center,
                original_idx=node_idx, y=full_g.y[center_idx], root_n_index=center)

def ego_graphs_sampler(node_idx, data, hop=2, sparse=False):
    """Extract the ``hop``-hop ego graph around every node in ``node_idx``.

    Args:
        node_idx: iterable of centre node ids (e.g. a 1-D tensor).
        data: PyG ``Data`` with ``x``, ``y`` and ``edge_index``.
        hop: neighbourhood radius passed to ``k_hop_subgraph``.
        sparse: if True, ``data.edge_index`` is a sparse adjacency that is
            first converted with ``to_edge_index``.

    Returns:
        list of ``Data`` objects, one per centre. Each contains:

        * the relabelled, undirected ``edge_index``;
        * ``x=data.x[subset]``;
        * ``original_idx=subset`` (global ids);
        * ``y=data.y[centre]``;
        * the centre's local position, in ``center`` and ``root_n_index``.
    """
    if sparse:
        edge_index, _ = to_edge_index(data.edge_index)
    else:
        edge_index = data.edge_index
    # Give k_hop_subgraph the full node count. Otherwise it infers it from
    # edge_index alone, and a centre whose id exceeds every id in edge_index
    # (an isolated node) indexes out of bounds.
    num_nodes = max(data.x.size(0), maybe_num_nodes(edge_index))

    ego_graphs = []
    for idx in node_idx:
        subset, sub_edge_index, mapping, _ = k_hop_subgraph(
            [idx], hop, edge_index, relabel_nodes=True, num_nodes=num_nodes)
        sub_edge_index = to_undirected(sub_edge_index)
        # Also stored as ``root_n_index``: PyG increments attributes whose
        # name contains "index" by the number of nodes when batching, so the
        # centre position stays valid after collation.
        # (Attention-mask variant: add attention_mask=data.attention_mask[subset].)
        ego_graphs.append(Data(x=data.x[subset], edge_index=sub_edge_index,
                               center=mapping, root_n_index=mapping,
                               y=data.y[idx], original_idx=subset))
    return ego_graphs