from torch_geometric.data import Data
import numpy as np
import torch
import networkx as nx
import pandas as pd


def nodes_feature(X:np.ndarray):
    """Summarise every column of ``X`` as the feature vector of one graph node.

    Each column of ``X`` becomes a node.  Its feature vector holds seven
    column statistics, in this order: min, max, mean, std (numpy default,
    ddof=0), and the 25th, 50th and 75th percentiles.  The only edges are one
    self-loop per node, all with edge type 0.

    Args:
        X: 2-D array; rows are samples, columns are the features that become
            nodes.

    Returns:
        A tuple ``(node_features, edge_index, edge_type)``:

        - ``node_features``: numpy array of shape ``(X.shape[1], 7)``.
        - ``edge_index``: long tensor of shape ``(2, X.shape[1])`` holding the
          self-loops ``(i, i)``.
        - ``edge_type``: long tensor of zeros with shape ``(X.shape[1],)``.
    """
    n_nodes = X.shape[1]
    # Rows are statistics, columns are features; transposed on return so that
    # each row describes one node.
    stats = np.vstack([
        X.min(axis=0),
        X.max(axis=0),
        X.mean(axis=0),
        X.std(axis=0),
        np.percentile(X, [25, 50, 75], axis=0),
    ])
    # Self-loop (i, i) for every node.  arange yields exact integer ids
    # directly instead of generating floats with linspace and casting them.
    node_ids = torch.arange(n_nodes, dtype=torch.long)
    edge_index = torch.stack([node_ids, node_ids])
    edge_type = torch.zeros(n_nodes, dtype=torch.long)
    return stats.transpose(), edge_index, edge_type


def get_graph(X,features):
    """Build a PyG ``Data`` graph whose nodes are the columns of ``X``.

    Node features, self-loop edges and edge types are taken from
    ``nodes_feature(X)``; node features are converted to a float tensor.
    ``y`` is the float tensor ``[0, 1, ..., len(features) - 2]``.

    NOTE(review): ``y`` has ``len(features) - 1`` entries while the graph has
    ``X.shape[1]`` nodes; presumably ``features`` lists one extra column
    (e.g. a target) that ``X`` does not contain -- confirm with callers.
    """
    node_x, edge_index, edge_type = nodes_feature(X)
    label_count = len(features) - 1
    labels = torch.Tensor(list(range(label_count)))
    return Data(
        x=torch.tensor(node_x, dtype=torch.float),
        y=labels,
        edge_index=edge_index,
        edge_type=edge_type,
    )


def pyg2nx(g:Data):
    """Convert a PyG ``Data`` graph into a ``networkx.DiGraph``.

    Node ``i`` (for every row ``i`` of ``g.x``) gets the attribute
    ``node_feature = g.x[i]``.  Each column ``(u, v)`` of ``g.edge_index``
    becomes the directed edge ``u -> v``.  The edge's ``weight`` attribute is
    the matching entry of ``g.edge_type``.

    Args:
        g: Graph with ``x``, ``edge_index`` (expected shape
            ``(2, num_edges)``; a flattened empty tensor is also accepted) and
            ``edge_type`` (one entry per edge) attributes.  Tensors may live
            on any device.

    Returns:
        The equivalent ``nx.DiGraph``.

    Raises:
        IndexError: if ``g.edge_type`` has fewer entries than there are edges.
    """
    G = nx.DiGraph()
    # One node per row of g.x; the stored feature is the row g.x[i] itself.
    for i in range(g.x.shape[0]):
        G.add_node(i, node_feature=g.x[i])

    # Each column of edge_index holds the ids of the edge's two endpoint
    # nodes.  Move to CPU before converting: numpy/.tolist() conversion of
    # CUDA tensors fails.  reshape(2, -1) also accepts an empty, flattened
    # edge_index of shape (0,).
    edge_index = (torch.as_tensor(g.edge_index).detach().cpu()
                  .to(torch.long).reshape(2, -1))
    weights = torch.as_tensor(g.edge_type).detach().cpu().reshape(-1).tolist()
    for k, (src, dst) in enumerate(edge_index.t().tolist()):
        # Indexing (not zip) keeps the original failure mode: IndexError when
        # there are fewer edge types than edges.
        G.add_edge(src, dst, weight=weights[k])
    return G

def nx2pyg(G):
    """Convert a networkx graph back into a PyG ``Data`` object.

    Intended as the inverse of ``pyg2nx``.  Every node must carry a
    ``node_feature`` attribute, and every edge a ``weight`` attribute, which
    becomes its edge type.  Nodes are renumbered ``0 .. len(G) - 1`` in
    ``G.nodes`` iteration order, and ``edge_index`` refers to those positions.

    Args:
        G: networkx graph whose nodes have ``node_feature`` and whose edges
            have ``weight``.

    Returns:
        ``Data`` with:

        - ``x``: float tensor, one row per node.
        - ``y``: float tensor of node positions ``0 .. len(G) - 1``.
        - ``edge_index``: long tensor of shape ``(2, num_edges)``; this is
          ``(2, 0)`` for a graph without edges.
        - ``edge_type``: long tensor of shape ``(num_edges,)``.

    Raises:
        KeyError: if a node lacks ``node_feature`` or an edge lacks ``weight``.
    """
    nodes_list = list(G.nodes)
    # Label -> position map built once, so each edge endpoint is an O(1)
    # lookup instead of a scan over the whole node array (O(N * E) before).
    position = {node: i for i, node in enumerate(nodes_list)}

    feature_matrix = []
    for node in nodes_list:
        feature = G.nodes[node]['node_feature']
        if torch.is_tensor(feature):
            # numpy cannot convert CUDA tensors or tensors that require grad.
            feature = feature.detach().cpu()
        feature_matrix.append(np.array(feature))

    edges = []
    edge_type = []
    for src, dst, attrs in G.edges.data():
        edges.append([position[src], position[dst]])
        edge_type.append(attrs['weight'])

    # Build indices directly as int64.  The former float32 round-trip
    # (torch.Tensor(...).to(long)) corrupted ids above 2**24.  reshape keeps
    # an edgeless graph at shape (2, 0) instead of (0,).
    edge_index = (torch.tensor(edges, dtype=torch.long)
                  .reshape(-1, 2).t().contiguous())
    # astype truncates float weights toward zero, like the former float->long
    # tensor cast, but without float32 precision loss for integers.
    edge_types = torch.from_numpy(
        np.asarray(edge_type).astype(np.int64).reshape(-1))

    g = Data(x=torch.tensor(np.array(feature_matrix), dtype=torch.float),
             y=torch.Tensor(list(range(len(nodes_list)))),
             edge_index=edge_index,
             edge_type=edge_types)
    return g