import itertools

import numpy as np
import numba as nb

# Per-column strides used by preprocess_data: feature column c is shifted by
# 1 + c * <stride>, so consecutive columns start <stride> values apart.
NODE_FEATURES_OFFSET: int = 128  # stride between node-feature columns
EDGE_FEATURES_OFFSET: int = 8  # stride between edge-feature columns

@nb.njit(nogil=True)
def floyd_warshall(A):
    """Compute all-pairs shortest hop counts from an adjacency matrix.

    Any nonzero off-diagonal entry ``A[i, j]`` is treated as a directed edge
    of length 1. Pairs without a direct edge start at 510, which serves as
    the "unreachable" value. Relaxation only ever lowers an entry, so a pair
    that is unreachable, or whose shortest path is longer than 510 hops,
    keeps the value 510.

    Args:
        A: Adjacency matrix. Only ``A.shape[0]`` is read, so it is expected
            to be square (n x n). Diagonal entries are ignored.

    Returns:
        An (n, n) ``np.int16`` matrix of hop counts with zeros on the
        diagonal. Runs in O(n^3) time (three nested loops over n).
    """
    size = A.shape[0]
    dist = np.zeros((size, size), dtype=np.int16)

    # Seed direct edges with 1 and everything else with the 510 sentinel;
    # the diagonal keeps the zero from np.zeros.
    for row in range(size):
        for col in range(size):
            if row != col:
                dist[row, col] = 1 if A[row, col] != 0 else 510

    # Classic Floyd-Warshall relaxation through each intermediate vertex.
    for via in range(size):
        for row in range(size):
            for col in range(size):
                candidate = dist[row, via] + dist[via, col]
                if candidate < dist[row, col]:
                    dist[row, col] = candidate
    return dist

@nb.njit(nogil=True)
def preprocess_data(num_nodes, edges, node_feats, edge_feats):
    """Build dense structural inputs for a single graph.

    Args:
        num_nodes: Number of nodes; sizes the dense outputs.
        edges: Integer array indexed as ``edges[e, 0]`` (source) and
            ``edges[e, 1]`` (target), one row per edge.
        node_feats: Node feature array whose last axis holds feature columns.
        edge_feats: Edge feature array whose row ``e`` belongs to
            ``edges[e]``.

    Returns:
        A tuple ``(node_feats, D, E)``:

        - ``node_feats`` has column c shifted by
          ``1 + c * NODE_FEATURES_OFFSET``. Presumably this puts each column
          in its own id range and reserves 0 (e.g. for padding); confirm
          with the embedding code.
        - ``D`` is the ``floyd_warshall`` hop-count matrix of the adjacency
          built from ``edges``. Only ``A[src, dst]`` is set, so edges are
          directed unless both directions are listed.
        - ``E`` is a ``(num_nodes, num_nodes, F)`` ``int16`` tensor holding
          the shifted edge features (column c shifted by
          ``1 + c * EDGE_FEATURES_OFFSET``) at ``[src, dst]``.
    """
    n_node_cols = node_feats.shape[-1]
    n_edge_cols = edge_feats.shape[-1]

    # Per-column offsets: 1, 1 + OFFSET, 1 + 2*OFFSET, ...
    node_shift = np.arange(1, n_node_cols * NODE_FEATURES_OFFSET + 1,
                           NODE_FEATURES_OFFSET, dtype=np.int16)
    edge_shift = np.arange(1, n_edge_cols * EDGE_FEATURES_OFFSET + 1,
                           EDGE_FEATURES_OFFSET, dtype=np.int16)
    shifted_nodes = node_feats + node_shift
    shifted_edges = edge_feats + edge_shift

    adjacency = np.zeros((num_nodes, num_nodes), dtype=np.int16)
    edge_tensor = np.zeros((num_nodes, num_nodes, n_edge_cols), dtype=np.int16)

    # Scatter each edge into the dense adjacency and edge-feature tensor;
    # a repeated (src, dst) pair keeps the features of its last occurrence.
    for e in range(edges.shape[0]):
        src = edges[e, 0]
        dst = edges[e, 1]
        adjacency[src, dst] = 1
        edge_tensor[src, dst] = shifted_edges[e]

    return shifted_nodes, floyd_warshall(adjacency), edge_tensor

def find_torsion_chains(num_nodes, edges, max_torsions=200):
    """Enumerate torsion (dihedral) index chains ``i-j-k-l`` in a graph.

    A chain is a walk over three edges i-j, j-k, k-l, with edges treated as
    undirected. It must satisfy ``k != i``, ``l != j`` and ``l != i``; the
    last condition rules out chains that close a triangle. A chain and its
    reverse ``l-k-j-i`` describe the same torsion, so only the orientation
    found first is kept.

    NOTE(review): self-loop edges are not filtered. An edge (a, a) can
    produce degenerate chains that repeat a node.

    Args:
        num_nodes: Number of nodes. Node ids in ``edges`` must lie in
            ``range(num_nodes)``.
        edges: Iterable of ``(start, end)`` node-id pairs.
        max_torsions: Keep at most this many chains, taken in enumeration
            order. Defaults to 200, the previous hard-coded cap. Must be a
            non-negative int; ``None`` disables the cap.

    Returns:
        An ``np.int16`` array of shape ``(T, 4)``. If no chain exists, a
        single ``[0, 0, 0, 0]`` placeholder row is returned, so the result
        is never empty.
    """
    # Undirected adjacency. Sets collapse duplicate edges, e.g. when both
    # directions are listed.
    adjacency = [set() for _ in range(num_nodes)]
    for start, end in edges:
        adjacency[start].add(end)
        adjacency[end].add(start)
    # Freeze to lists for iteration.
    adjacency = [list(neighbors) for neighbors in adjacency]

    def _chains():
        # Yield unique chains in enumeration order. `seen` gives O(1)
        # reverse-orientation lookups instead of rescanning the result list.
        seen = set()
        for i in range(num_nodes):
            for j in adjacency[i]:
                for k in adjacency[j]:
                    if k == i:
                        continue
                    for l in adjacency[k]:
                        if l == j or l == i or (l, k, j, i) in seen:
                            continue
                        seen.add((i, j, k, l))
                        yield [i, j, k, l]

    # Each keep/skip decision depends only on earlier chains. Stopping after
    # `max_torsions` therefore yields exactly the prefix that a full
    # enumeration followed by truncation would produce.
    torsion_indices = list(itertools.islice(_chains(), max_torsions))
    if not torsion_indices:
        torsion_indices.append([0, 0, 0, 0])
    return np.array(torsion_indices, dtype=np.int16)

def find_adjacent_triplets(num_nodes, edges):
    """Enumerate angle index triplets ``i-j-k`` centred on node ``j``.

    A triplet is a walk over two edges i-j, j-k (edges treated as undirected)
    with ``k != i``. A triplet and its reverse ``k-j-i`` describe the same
    angle, so only the orientation found first is kept.

    Args:
        num_nodes: Number of nodes. Node ids in ``edges`` must lie in
            ``range(num_nodes)``.
        edges: Iterable of ``(start, end)`` node-id pairs.

    Returns:
        An ``np.int16`` array of shape ``(T, 3)``. If no triplet exists, a
        single ``[0, 0, 0]`` placeholder row is returned, so the result is
        never empty.
    """
    # Undirected adjacency. Sets collapse duplicate edges, e.g. when both
    # directions are listed.
    adjacency = [set() for _ in range(num_nodes)]
    for start, end in edges:
        adjacency[start].add(end)
        adjacency[end].add(start)
    # Freeze to lists for iteration.
    adjacency = [list(neighbors) for neighbors in adjacency]

    triplets = []
    # O(1) reverse-orientation lookup instead of scanning `triplets` each time.
    seen = set()
    for i in range(num_nodes):
        for j in adjacency[i]:
            for k in adjacency[j]:
                if k == i or (k, j, i) in seen:
                    continue
                seen.add((i, j, k))
                triplets.append((i, j, k))

    if not triplets:
        triplets.append((0, 0, 0))
    return np.array(triplets, dtype=np.int16)

class AddStructuralData:
    """Callable transform that adds structural graph inputs to a data record.

    Calling an instance on a dict pops the edge list, node features and edge
    features. In their place it stores:

    - offset node features,
    - a pairwise hop-count distance matrix,
    - a dense per-pair edge-feature tensor,
    - angle (triplet) index arrays,
    - torsion (4-chain) index arrays.

    Every dict key is configurable through the constructor.
    """

    def __init__(
        self,
        num_nodes_key='num_nodes',
        node_features_key='node_features',
        edges_key='edges',
        edge_features_key='edge_features',
        distance_matrix_key='distance_matrix',
        feature_matrix_key='feature_matrix',
        angle_key='angle_indices',
        torsion_key='torsion_indices',
    ):
        # Keys read from the incoming record. node_features_key is also
        # written back with the offset features.
        self.num_nodes_key = num_nodes_key
        self.node_features_key = node_features_key
        self.edges_key = edges_key
        self.edge_features_key = edge_features_key
        # Keys written to the record.
        self.distance_matrix_key = distance_matrix_key
        self.feature_matrix_key = feature_matrix_key
        self.angle_key = angle_key
        self.torsion_key = torsion_key

    def __call__(self, item: dict):
        """Transform ``item`` in place and return it.

        The edges, node features and edge features are removed from
        ``item``. The angle indices, torsion indices, offset node features,
        distance matrix and edge-feature tensor are then inserted, in that
        order.
        """
        n_nodes = int(item[self.num_nodes_key])
        # A tuple display evaluates left to right, so the pop order is
        # edges, node features, edge features.
        edge_list, raw_node_feats, raw_edge_feats = (
            item.pop(self.edges_key),
            item.pop(self.node_features_key),
            item.pop(self.edge_features_key),
        )

        node_feats, dist_mat, edge_feat_mat = preprocess_data(
            n_nodes, edge_list, raw_node_feats, raw_edge_feats)
        torsions = find_torsion_chains(n_nodes, edge_list)
        angles = find_adjacent_triplets(n_nodes, edge_list)

        # Assign in a fixed order so the record's key order stays stable.
        outputs = (
            (self.angle_key, angles),
            (self.torsion_key, torsions),
            (self.node_features_key, node_feats),
            (self.distance_matrix_key, dist_mat),
            (self.feature_matrix_key, edge_feat_mat),
        )
        for key, value in outputs:
            item[key] = value

        return item
