import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from pyqtorch.matrices import generate_ising_from_graph, sum_N
import torch
from itertools import combinations
from torch_geometric.utils import to_dense_adj, from_networkx
from torch_sparse import SparseTensor



def open_graph(fname):
    """Load the first graph stored in a 0/1 adjacency-matrix text file.

    Each line of the file is read as one matrix row made of single-digit
    characters (no separators). The matrix order is taken from the length
    of the first row, and only the first ``N`` rows are used, so any
    further content in the file (e.g. additional graphs) is ignored.

    As a side effect, prints whether the loaded graph is strongly regular.

    Args:
        fname: Path of the text file to read.

    Returns:
        A copy of the ``networkx`` graph built from the adjacency matrix.

    Raises:
        ValueError: If the file is empty, starts with a blank line, or the
            first block of rows is not a square matrix.
    """
    with open(fname) as f:
        rows = [line.strip() for line in f]

    if not rows or not rows[0]:
        raise ValueError(f"{fname}: no adjacency matrix found")

    # Infer the number of vertices from the first row instead of relying on
    # a hard-coded size that only matches one family of input files.
    N = len(rows[0])
    block = rows[:N]
    if len(block) != N or any(len(row) != N for row in block):
        raise ValueError(f"{fname}: expected a {N}x{N} adjacency matrix")

    A = np.array([[int(bit) for bit in row] for row in block])
    G = nx.from_numpy_array(A)

    print(nx.is_strongly_regular(G))
    return G.copy()

def compute_random_walk_xy_2_matrix(data):
    """Build a column-normalised transition matrix over unordered node pairs.

    Every unordered pair ``(u, v)`` with ``u < v`` of node indices is a
    state. For two states ``i = (u_i, v_i)`` and ``j = (u_j, v_j)`` the
    count ``H[i, j]`` adds one for each of the two endpoint alignments
    (``u_i~u_j, v_i~v_j`` and ``u_i~v_j, v_i~u_j``) whose product of
    "signed relations" is below ``-0.5``. The signed relation of two nodes
    is ``Adj[a, b] - (a == b)``; for a simple graph without self-loops this
    is +1 for neighbours, -1 for identical nodes and 0 otherwise, so an
    alignment counts when one endpoint is shared and the other moves along
    an edge.

    Args:
        data: Graph object exposing ``num_nodes`` and ``edge_index`` (as
            consumed by ``to_dense_adj``).

    Returns:
        A tuple ``(M, index)`` where ``M`` is the float matrix ``H`` with
        column ``j`` divided by ``max(sum(H[j, :]), 1)``, and ``index`` is
        a ``(2, num_pairs)`` long tensor listing the node pair of each
        state. Graphs with fewer than two nodes yield empty tensors.
    """
    N = data.num_nodes
    device = 'cpu'
    Adj = to_dense_adj(data.edge_index, max_num_nodes=N).squeeze(0).to(device)

    # One row per unordered pair; the reshape keeps the (0, 2) shape when
    # there are no pairs at all.
    index = torch.tensor(
        list(combinations(range(N), 2)), dtype=torch.long, device=device
    ).reshape(-1, 2)
    first, second = index[:, 0], index[:, 1]

    def signed_relation(src, dst):
        # Pairwise table over (state i, state j) built by broadcasting, so
        # no (Ne, Ne, 2, 2) index tensor has to be materialised.
        same = (src[:, None] == dst[None, :]) * 1
        return Adj[src[:, None], dst[None, :]] - same

    straight = signed_relation(first, first) * signed_relation(second, second)
    crossed = signed_relation(first, second) * signed_relation(second, first)
    H = ((straight < -.5).long() + (crossed < -.5).long()).float()

    # Scale column j by 1 / max(row-sum j, 1); this equals
    # H @ diag(1 / max(row_sums, 1)) without building the dense diagonal.
    scale = 1 / torch.clamp(H.sum(dim=1), min=1.0)
    M = H * scale.unsqueeze(0)

    return M, index.t()

def add_edge_attributes(G):
    """Attach a random-walk based value ``'p'`` to every edge of ``G``.

    A uniform distribution over unordered node pairs is propagated with the
    pair transition matrix from ``compute_random_walk_xy_2_matrix``. The
    distribution is advanced 50 times by ``M ** 100``; the value reached
    after the last application, scaled by ``1e7``, is stored on each edge
    under the attribute ``'p'``.

    Note that ``G`` itself is modified in place (the attribute is set on
    it) in addition to a copy being returned.

    Args:
        G: A ``networkx`` graph.

    Returns:
        A tuple ``(graph, list_P, M)``: a copy of ``G`` carrying the ``'p'``
        edge attribute, the list of the 50 intermediate distributions, and
        the transition matrix ``M``.
    """
    data = from_networkx(G)
    M, idx = compute_random_walk_xy_2_matrix(data)

    # Start from the uniform distribution over node pairs.
    P = torch.ones(M.shape[0])
    P = P / torch.sum(P)

    # Each iteration advances the walk by 100 steps.
    M_pow = torch.matrix_power(M, 100)
    list_P = []
    for _ in range(50):
        P = torch.matmul(M_pow, P)
        list_P.append(P.clone())

    # from_networkx numbers nodes by their position in G's node order, so
    # translate pair indices back to the actual node labels before looking
    # the edge up in G.
    nodes = list(G.nodes())
    edge_values = {}
    for u, (i, j) in enumerate(idx.t().tolist()):
        a, b = nodes[i], nodes[j]
        if G.has_edge(a, b):
            edge_values[(a, b)] = P[u] * 1e7
        if G.has_edge(b, a):
            edge_values[(b, a)] = P[u] * 1e7
    nx.set_edge_attributes(G, edge_values, 'p')
    return G.copy(), list_P, M


if __name__ == '__main__':
    # Load the graphs
    G1 = open_graph("SRG_16_6_2_2_1.txt")
    G2 = open_graph("SRG_16_6_2_2_2.txt")

    # Create a permutation of G1.
    # relabel_nodes returns a new graph by default, so its result must be
    # kept. The relabelled copy preserves G1's node *order*, and the
    # conversion to tensors indexes nodes by position, so the nodes are
    # re-inserted in sorted label order to make the permutation visible.
    perm = np.random.permutation(G1.number_of_nodes())
    perm_dict = {i: int(j) for i, j in enumerate(perm)}
    relabelled = nx.relabel_nodes(G1, perm_dict)
    G3 = nx.Graph()
    G3.add_nodes_from(sorted(relabelled.nodes()))
    G3.add_edges_from(relabelled.edges(data=True))

    # Compute distance matrix
    g1, p1, m1 = add_edge_attributes(G1)
    g2, p2, m2 = add_edge_attributes(G2)
    g3, p3, m3 = add_edge_attributes(G3)

    # Entry k of each list corresponds to (k + 1) * 100 walk steps.
    steps = np.arange(1, len(p1) + 1) * 100

    # L2 distance between the sorted pair distributions of G1 and G2.
    to_plot = [
        np.linalg.norm(np.sort(np.abs(a.numpy())) - np.sort(np.abs(b.numpy())))
        for a, b in zip(p1, p2)
    ]
    plt.plot(steps, to_plot, marker='o', label='G1 - G2')

    # Same distance between G1 and its permuted copy.
    to_plot = [
        np.linalg.norm(np.sort(np.abs(a.numpy())) - np.sort(np.abs(b.numpy())))
        for a, b in zip(p1, p3)
    ]
    plt.plot(steps, to_plot, marker='o', label='G1 - G1 permuted')

    plt.xlabel('Number of steps', fontsize='x-large')
    plt.ylabel('L2 norm between sorted distance matrix', fontsize='x-large')
    plt.legend()
    # bbox_inches takes 'tight'; pad_inches takes the numeric padding.
    plt.savefig("isomorphism.pdf", bbox_inches='tight', pad_inches=0)

    # Make a union graph so distances are hashed the same way
    G = nx.union(g1, g2, rename=('1-', '2-'))
    # GD-WL test
    result = nx.weisfeiler_lehman_subgraph_hashes(G, edge_attr='p')
