# hpc_utils.py
from collections import defaultdict

import numpy as np
import torch
from scipy.sparse.linalg import eigs
from scipy.sparse.linalg import eigsh

def create_line_graph(edge_list, num_nodes_a, num_nodes_b):
    """Build the dense line graph of a bipartite edge list.

    Each input edge ``(u, v, s)`` becomes one line-graph node, in input order.
    Two line-graph nodes are adjacent when their edges share the side-``a``
    endpoint ``u`` (with different ``v``) or share the side-``b`` endpoint
    ``v`` (with different ``u``). Repeated ``(u, v)`` pairs are never linked
    to each other, but each copy gets the same (symmetric) neighbourhood.

    Args:
        edge_list: Sequence of ``(u, v, s)`` triples, where ``u`` is the
            side-``a`` endpoint, ``v`` the side-``b`` endpoint and ``s`` a
            numeric sign/weight.
        num_nodes_a: Not used here; kept for interface compatibility.
        num_nodes_b: Not used here; kept for interface compatibility.

    Returns:
        adj_matrix: ``(E, E)`` float32 tensor of 0/1 entries, symmetric with
            a zero diagonal (``E = len(edge_list)``).
        edge_signs: ``(E,)`` float32 tensor of the ``s`` values in input order.
        edge_to_idx: dict mapping ``(u, v)`` to its index in ``edge_list``;
            for repeated pairs the last occurrence wins.
    """
    edge_to_idx = {(u, v): i for i, (u, v, _) in enumerate(edge_list)}
    endpoints = [(u, v) for u, v, _ in edge_list]
    num_line_nodes = len(endpoints)

    # Group edge *indices* by endpoint. Resolving neighbours through
    # edge_to_idx instead would drop all but the last copy of a repeated
    # (u, v) pair and make the adjacency asymmetric.
    edges_at_a = defaultdict(list)
    edges_at_b = defaultdict(list)
    for i, (u, v) in enumerate(endpoints):
        edges_at_a[u].append(i)
        edges_at_b[v].append(i)

    rows, cols = [], []
    for i, (u1, v1) in enumerate(endpoints):
        # Edges sharing u1 but leading to a different v.
        for j in edges_at_a[u1]:
            if endpoints[j][1] != v1:
                rows.append(i)
                cols.append(j)
        # Edges sharing v1 but coming from a different u.
        for j in edges_at_b[v1]:
            if endpoints[j][0] != u1:
                rows.append(i)
                cols.append(j)

    adj_matrix = torch.zeros((num_line_nodes, num_line_nodes))
    if rows:
        # One vectorized scatter instead of a torch op per entry.
        adj_matrix[torch.tensor(rows), torch.tensor(cols)] = 1

    edge_signs = torch.tensor([s for _, _, s in edge_list], dtype=torch.float)

    return adj_matrix, edge_signs, edge_to_idx

def get_relational_spectral_encoding(line_adj, edge_signs, k=8):
    """Project onto the k lowest eigenvectors of the signed line-graph Laplacian.

    Builds ``A_s = line_adj * outer(edge_signs, edge_signs)`` and the signed
    Laplacian ``L_s = diag(sum_j |A_s[i, j]|) - A_s``. It then takes an
    orthonormal basis ``V`` of eigenvectors for the ``k`` smallest eigenvalues
    and returns the projector ``V @ V.T``.

    NOTE: for a 0/1 ``line_adj`` and +/-1 signs, ``L_s = S L S`` with
    ``S = diag(edge_signs)``, so the spectrum matches the unsigned Laplacian.

    Args:
        line_adj: ``(N, N)`` adjacency tensor, assumed symmetric (only its
            symmetric part is used).
        edge_signs: ``(N,)`` tensor of per-line-node signs.
        k: Number of eigenvectors; values above ``N`` are clamped to ``N``.

    Returns:
        ``(N, N)`` float32 tensor on ``line_adj``'s device.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    num_nodes = line_adj.shape[0]
    if num_nodes == 0:
        return torch.zeros((0, 0), device=line_adj.device)
    k = min(k, num_nodes)

    A_s = line_adj * torch.outer(edge_signs, edge_signs)
    L_s = torch.diag(torch.abs(A_s).sum(dim=1)) - A_s

    # The solvers need host NumPy arrays. float64 keeps the low end of the
    # spectrum stable. Symmetrizing makes the dense and Lanczos paths agree.
    L_np = L_s.detach().cpu().numpy().astype(np.float64)
    L_np = 0.5 * (L_np + L_np.T)

    if k < num_nodes - 1:
        # Symmetric Lanczos: real, orthonormal eigenvectors (unlike eigs).
        _, eigenvectors = eigsh(L_np, k=k, which="SA")
    else:
        # ARPACK cannot return (almost) all eigenpairs; use the dense solver,
        # whose eigenvalues are in ascending order.
        _, all_vectors = np.linalg.eigh(L_np)
        eigenvectors = all_vectors[:, :k]

    V = torch.from_numpy(np.ascontiguousarray(eigenvectors)).float()
    return torch.mm(V, V.T).to(line_adj.device)
    
def get_topological_motif_encoding(line_adj, max_path_len=3):
    """Flag node pairs connected by a walk of each length 1..max_path_len.

    ``out[i, j, l]`` is 1.0 iff ``(line_adj ** (l + 1))[i, j] > 0``, where
    ``**`` is the matrix power. For a non-negative adjacency this means a walk
    (not necessarily a simple path) of length ``l + 1`` exists from ``i`` to
    ``j``.

    NOTE(review): the raw powers hold walk counts, which grow geometrically.
    On very dense graphs or long walks float32 can overflow to inf, and
    inf * 0 gives NaN, which compares False.

    Args:
        line_adj: ``(N, N)`` adjacency tensor.
        max_path_len: Longest walk length to encode; values < 1 yield an
            empty last axis.

    Returns:
        ``(N, N, max_path_len)`` float32 tensor of 0/1 indicators.
    """
    num_nodes = line_adj.shape[0]
    if max_path_len < 1:
        return torch.zeros((num_nodes, num_nodes, 0), device=line_adj.device)

    # Compute exactly max_path_len powers; no extra, discarded matmul.
    path_matrices = [line_adj]
    for _ in range(max_path_len - 1):
        path_matrices.append(torch.mm(path_matrices[-1], line_adj))

    path_indicator = torch.stack(path_matrices, dim=-1) > 0
    return path_indicator.float()