import torch
import networkx as nx
import numpy as np

from utils_proof.str import *

def frr_to_graph(frr):
    """Build a directed dependency graph from multi-line FRR program text.

    Line selection: a line is kept only if it is non-empty, does not start
    with '#', and contains every one of '=', '[', ']', '(' and ')'. Anything
    after a '#' is then stripped, and lines for which ``is_variable`` is true
    are dropped. The survivors are normalised by the ``utils_proof.str``
    helpers: ``convert_function_calls`` per line, then ``convert_to_function``,
    ``unpack_statements``, ``convert_to_function`` and ``eliminate_loop`` on
    the whole list.

    Graph construction: each statement is classified by ``extract_parts``,
    which returns ``(flag, output)``.

    * ``'init'``: ``output`` is added as a node.
    * ``'operation'``: ``output`` is ``(target, description, sources)``. The
      statement is skipped entirely if ``description`` contains 'given' or
      'triple'. Otherwise an edge ``source -> target`` labelled
      ``description`` is added for each source. If an edge already exists,
      its first label is kept. If ``target`` is already a node, '_NeW' is
      appended to its name first.

    NOTE(review): the '_NeW' rename happens only once, so a third statement
    with the same target reuses the existing '<target>_NeW' node. Confirm
    this is intended.

    Args:
        frr: program text. The meaning of "FRR" is not defined in this
            module (TODO confirm).

    Returns:
        networkx.DiGraph whose edges carry a ``label`` attribute.
    """
    required_chars = ('=', '[', ']', '(', ')')

    def _is_candidate(line):
        # Non-empty, not a whole-line comment, and has every required char.
        return bool(line) and line[0] != '#' and all(ch in line for ch in required_chars)

    statements = [line for line in frr.split('\n') if _is_candidate(line)]
    statements = [line.split('#')[0].strip() for line in statements]
    statements = [line for line in statements if not is_variable(line)]
    # A merge_lines(...) stage existed here but is deliberately disabled.
    statements = [convert_function_calls(line) for line in statements]
    for stage in (convert_to_function, unpack_statements, convert_to_function, eliminate_loop):
        statements = stage(statements)

    graph = nx.DiGraph()
    for statement in statements:
        flag, output = extract_parts(statement)
        if flag == 'init':
            graph.add_node(output)
            continue
        if flag != 'operation':
            continue

        target, description, sources = output
        if 'given' in description or 'triple' in description:
            continue
        # Avoid merging a re-assigned target with its earlier definition.
        if target in graph.nodes():
            target = target + '_NeW'
        for source in sources:
            if is_num(source):
                graph.add_node(source)
            if not graph.has_edge(source, target):
                graph.add_edge(source, target, label=description)
    return graph

def get_weight_matrix_2(G, device, self_loop):
    """Return the row-normalised adjacency matrix of ``G`` as a float32 tensor.

    Args:
        G: networkx graph. Rows and columns follow ``G.nodes`` order. Entry
            ``[i, j]`` is the ``'weight'`` attribute of edge i -> j, or 1
            when the edge has no such attribute.
        device: torch device on which the result is allocated.
        self_loop: if true, the identity is added before normalising, so each
            node also receives weight in its own row.

    Returns:
        (N, N) ``torch.float32`` tensor. Each non-empty row sums to ~1. Rows
        for nodes with no out-edges (and no added self-loop) stay all zero.
        An empty graph yields a (0, 0) tensor.
    """
    # to_numpy_array gives the same dense matrix as adjacency_matrix(...).toarray().
    # It needs no scipy and returns a (0, 0) array for a graph with no nodes,
    # where adjacency_matrix raises instead.
    A = nx.to_numpy_array(G)
    if self_loop:
        A = A + np.eye(A.shape[0])
    # The epsilon keeps all-zero rows at zero instead of producing NaN.
    A = A / (A.sum(axis=1, keepdims=True) + 1e-10)
    return torch.tensor(A, device=device, dtype=torch.float32)

def get_adjacency_matrix(G, device):
    """Return the dense adjacency matrix of ``G`` as a float32 tensor.

    Args:
        G: networkx graph. Rows and columns follow ``G.nodes`` order. Entry
            ``[i, j]`` is the ``'weight'`` attribute of edge i -> j, or 1
            when the edge has no such attribute.
        device: torch device on which the result is allocated.

    Returns:
        (N, N) ``torch.float32`` tensor. An empty graph yields a (0, 0)
        tensor.
    """
    # Same matrix as adjacency_matrix(...).toarray(), but needs no scipy and
    # does not raise on a graph with no nodes.
    A = nx.to_numpy_array(G)
    return torch.tensor(A, device=device, dtype=torch.float32)

def get_ppr_graph(digraph):
    """Return a rewired copy of ``digraph`` whose edges all point at its roots.

    A root is a node with in-degree 0 in ``digraph``. The copy keeps every
    node, including node attributes, but drops all original edges. It then
    adds an edge ``u -> r`` for every node ``u`` and every root ``r``. This
    includes the self-loop ``r -> r``. The input graph is left unmodified.

    NOTE(review): the name suggests this is the restart/teleport graph for a
    personalized PageRank that restarts at the roots. Confirm with callers.
    """
    rewired = digraph.copy()
    roots = [node for node, deg in rewired.in_degree() if deg == 0]
    every_node = list(rewired.nodes)

    rewired.remove_edges_from(list(rewired.edges))
    rewired.add_edges_from((u, r) for u in every_node for r in roots)
    return rewired

def max_singular_vectors(A):
    """Return the left and right singular vectors of the largest singular value.

    Args:
        A: tensor of shape (..., m, n). Any leading dimensions are treated as
            a batch.

    Returns:
        ``(alpha_max, beta_max)`` with shapes (..., m) and (..., n).
        ``alpha_max`` is the first column of ``U``. ``beta_max`` is the first
        row of ``Vh`` (the conjugate of the right singular vector for complex
        input). ``torch.linalg.svd`` sorts singular values in descending
        order, so index 0 is the largest.

        Singular vectors are unique only up to sign (or phase), but the pair
        is consistent: ``S[..., 0] * outer(alpha_max, beta_max)`` is the best
        rank-1 approximation of ``A``.
    """
    # Only the leading pair is used, so the reduced SVD avoids building the
    # full (m, m) / (n, n) bases. That matters a lot for tall or wide inputs.
    U, _, Vh = torch.linalg.svd(A, full_matrices=False)
    # Index the trailing matrix dims so batched input (..., m, n) also works.
    alpha_max = U[..., :, 0]
    beta_max = Vh[..., 0, :]
    return alpha_max, beta_max