# This implementation is based on https://github.com/weihua916/powerful-gnns and https://github.com/chrsmrrs/k-gnn/tree/master/examples
# Datasets are implemented based on the description in the corresponding papers (see the paper for references)
import numpy as np
import networkx as nx
import torch
from torch_geometric.data import  Data
from torch_geometric.utils import degree
torch.set_printoptions(profile="full")
import torch as th
#%%
# Synthetic datasets

class SymmetrySet:
    """Base class for the synthetic symmetry datasets.

    Subclasses set the dataset hyper-parameters in ``__init__`` and override
    ``makedata`` to return a list of graph objects.
    """

    def __init__(self):
        self.hidden_units = 0  # presumably the model width used for this dataset
        self.num_classes = 0   # number of distinct labels
        self.num_features = 0  # width of the node-feature matrix ``x``
        self.num_nodes = 0     # nodes per graph

    def addports(self, data):
        """Attach a random port numbering to the out-edges of every node.

        For each node ``n``, the edges whose source (``data.edge_index[0]``)
        is ``n`` receive the ports ``0 .. k - 1`` (``k`` = number of such
        edges) in a random order drawn from ``np.random``.

        Args:
            data: graph object with ``edge_index``, ``num_nodes`` and
                ``num_edges``; modified in place.

        Returns:
            ``data``, with ``data.ports`` set to a float tensor of shape
            ``(data.num_edges, 1)``.
        """
        src = data.edge_index[0]
        data.ports = torch.zeros(data.num_edges, 1)
        for n in range(data.num_nodes):
            # Positions (in edge order) of node n's out-edges. Assigning by
            # position rather than by (src, dst) pair keeps the ports a true
            # permutation even when the edge list contains repeated edges.
            edge_ids = (src == n).nonzero(as_tuple=True)[0]
            ports = np.random.permutation(len(edge_ids))
            data.ports[edge_ids, 0] = torch.as_tensor(ports, dtype=data.ports.dtype)
        return data

    def makefeatures(self, data):
        """Set constant node features and a random node identifier.

        ``data.x`` becomes a ``(num_nodes, 1)`` tensor of ones and ``data.id``
        a ``(num_nodes, 1)`` random permutation of ``0 .. num_nodes - 1``.
        Returns ``data`` (modified in place).
        """
        data.x = torch.ones((data.num_nodes, 1))
        data.id = torch.tensor(np.random.permutation(np.arange(data.num_nodes))).unsqueeze(1)
        return data

    def makedata(self):
        """Return the dataset's graphs; overridden by subclasses (base returns None)."""
        pass

class FourCycles(SymmetrySet):
    """4-CYCLES dataset: binary graph classification.

    A graph built by ``gen_graph(p)`` has ``4 * p`` nodes in four groups:
    A = ``0..p-1``, B = ``p..2p-1``, C = ``2p..3p-1``, D = ``3p..4p-1``.
    Fixed edges join ``A_i - C_i`` and ``B_i - D_i``; a random perfect matching
    joins A to B ("top") and another joins C to D ("bottom"). The label is
    True iff both matchings contain the same pair ``(a, t)`` (``A_a - B_t``
    and ``C_a - D_t``), which closes the 4-cycle ``A_a - B_t - D_t - C_a``.
    """

    def __init__(self):
        super().__init__()
        self.p = 4
        self.hidden_units = 16
        self.num_classes = 2
        self.num_features = 1
        self.num_nodes = 4 * self.p
        self.graph_class = True

    @staticmethod
    def _random_matching(p):
        """Return a flat ``p * p`` 0/1 array encoding a random permutation matrix."""
        bits = np.zeros((p * p,))
        perm = np.random.permutation(range(p))
        for row, col in enumerate(perm):
            bits[row * p + col] = 1
        return bits

    def gen_graph(self, p):
        """Build one random 4-CYCLES graph with ``p`` nodes per group.

        Returns:
            ``(data, label)`` where ``data`` is a ``Data`` with a directed
            ``edge_index`` (every undirected edge in both directions) and
            ``num_nodes = 4 * p``, and ``label`` is a Python bool.
        """
        src, dst = [], []
        # Fixed "vertical" edges A_i <-> C_i and B_i <-> D_i.
        for i in range(p):
            src += [i, p + i, 2 * p + i, 3 * p + i]
            dst += [2 * p + i, 3 * p + i, i, p + i]
        # Draw order (top, then bottom) matters for RNG reproducibility.
        top = self._random_matching(p)
        bottom = self._random_matching(p)
        for i, bit in enumerate(top):
            if bit:
                a, b = i // p, p + i % p
                src += [a, b]
                dst += [b, a]
        for i, bit in enumerate(bottom):
            if bit:
                c, d = 2 * p + i // p, 3 * p + i % p
                src += [c, d]
                dst += [d, c]
        edge_index = torch.tensor([src, dst], dtype=torch.long)
        # Node count follows the requested p (not self.num_nodes), so
        # gen_graph stays consistent when called with p != self.p.
        return Data(edge_index=edge_index, num_nodes=4 * p), any(np.logical_and(top, bottom))

    def makedata(self, size=25):
        """Generate ``size`` positive and ``size`` negative graphs.

        Graphs are drawn with ``p = self.p`` until both classes are full and
        surplus graphs are discarded. Features and ports are generated for
        every drawn graph, including discarded ones.

        Returns:
            List of ``2 * size`` graphs, all positives first.
        """
        p = self.p
        trues = []
        falses = []
        while len(trues) < size or len(falses) < size:
            data, label = self.gen_graph(p)
            data = self.makefeatures(data)
            data = self.addports(data)
            data.y = label
            if label and len(trues) < size:
                trues.append(data)
            elif not label and len(falses) < size:
                falses.append(data)
        return trues + falses

class SkipCircles(SymmetrySet):
    """SKIP-CIRCLES dataset: 10-way graph classification.

    Every graph is a cycle over ``num_nodes`` (41) nodes in which each node
    additionally links to the nodes ``skip`` steps before and after it. The
    label is the position of ``skip`` in the list of skip lengths.
    """

    def __init__(self):
        super().__init__()
        self.hidden_units = 32
        self.num_classes = 10  # one class per skip length used in makedata
        self.num_features = 1
        self.num_nodes = 41
        self.graph_class = True
        # NOTE(review): the graphs built here are discarded; the call still
        # draws from the global NumPy RNG (via makefeatures/addports).
        self.makedata()

    def makedata(self):
        """Return one graph per skip length, labelled ``0 .. 9`` in list order."""
        n_nodes = self.num_nodes
        skip_lengths = [2, 3, 4, 5, 6, 9, 11, 12, 13, 16]
        dataset = []
        for label, skip in enumerate(skip_lengths):
            # Cycle-closing edge first, then the remaining cycle edges, each
            # in both directions; then, per node, the edges to the nodes skip
            # steps backwards and forwards.
            pairs = [(0, n_nodes - 1), (n_nodes - 1, 0)]
            for v in range(n_nodes - 1):
                pairs += [(v, v + 1), (v + 1, v)]
            for v in range(n_nodes):
                pairs += [(v, (v - skip) % n_nodes), (v, (v + skip) % n_nodes)]
            edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()

            graph = Data(edge_index=edge_index, num_nodes=n_nodes)
            graph = self.addports(self.makefeatures(graph))
            graph.y = torch.tensor(label)
            dataset.append(graph)
        return dataset

def get_adjacency_matrix(edge_index, num_nodes, dtype=th.float64):
    """Build a dense 0/1 adjacency matrix from an edge list.

    Args:
        edge_index: integer tensor of shape ``(2, E)``; row 0 holds source
            nodes and row 1 target nodes.
        num_nodes: size of the returned square matrix.
        dtype: dtype of the returned matrix. Before this fix it was ignored
            and the result was always float32.

    Returns:
        ``(num_nodes, num_nodes)`` tensor with ``A[i, j] = 1`` for every edge
        ``i -> j`` (repeated edges still give 1) and 0 elsewhere.
    """
    A = th.zeros((num_nodes, num_nodes), dtype=dtype)
    # One vectorized index assignment instead of a Python loop over edges.
    A[edge_index[0], edge_index[1]] = 1
    return A

#%%
# Switches for the plotting cell below: set `visu = True` to run it. It prints
# degree statistics for the dataset named by `toy_name` ('4cycles' or
# 'skipcircles') and, for '4cycles', plots and saves sample graphs.
visu = False
toy_name = '4cycles'

def get_graph_and_pos(C,nx_mode='kawai',seed=0, scale=1, k=0.1):
    """Build a networkx graph from matrix ``C`` and compute a node layout.

    An undirected edge ``(i, j)`` with attribute ``weight=C[i, j]`` is added
    for every strictly positive entry of ``C``. When both ``C[i, j]`` and
    ``C[j, i]`` are positive, the entry visited last in row-major order sets
    the weight.

    Args:
        C: square array of edge weights, indexed ``C[i, j]``.
        nx_mode: layout name: 'spring', 'kawai', 'circular', 'planar',
            'spectral' or 'multi'.
        seed: random seed, used by the 'spring' layout only.
        scale: layout scale, used by the 'spring' layout only ('kawai' always
            uses 10 and 'circular' always uses 1).
        k: optimal node distance for 'spring'; ``None`` lets networkx choose.

    Returns:
        ``(G, pos)`` where ``pos`` is an array whose row ``v`` holds the
        coordinates of node ``v``.

    Raises:
        ValueError: if ``nx_mode`` is not a supported layout. Previously this
            fell through and crashed with UnboundLocalError.
    """
    def spring(g):
        if k is None:
            return nx.spring_layout(g, seed=seed, scale=scale)
        return nx.spring_layout(g, seed=seed, k=k, scale=scale)

    layouts = {
        'spring': spring,
        'kawai': lambda g: nx.kamada_kawai_layout(g, scale=10),
        'circular': lambda g: nx.circular_layout(g, scale=1),
        'planar': lambda g: nx.planar_layout(g),
        'spectral': lambda g: nx.spectral_layout(g),
        'multi': lambda g: nx.multipartite_layout(g),
    }
    if nx_mode not in layouts:
        raise ValueError(f"unknown nx_mode {nx_mode!r}; expected one of {sorted(layouts)}")

    n = C.shape[0]
    G = nx.Graph()
    G.add_nodes_from(range(n))
    for i in range(n):
        for j in range(n):
            if C[i, j] > 0:
                G.add_edge(i, j, weight=C[i, j])

    pos = layouts[nx_mode](G)
    return G, np.vstack([pos[v] for v in range(n)])
    

def plot_attributed_graph(C, pos, h,ecolor='k',ealpha=0.2,ncolor='C0',nalpha=1,nscale=50):
    """Draw a weighted graph on the current pylab axes.

    Args:
        C: square weight matrix. A line is drawn for every ordered pair with
            ``C[i, j] > 0`` whose endpoints both have ``h > 0``, so a symmetric
            ``C`` draws each edge twice. Line opacity is
            ``ealpha * C[i, j] / C.max()``.
        pos: node coordinates, indexed ``pos[i, 0]`` (x) and ``pos[i, 1]`` (y).
        h: per-node weights; marker size is ``h * n * nscale``.
        ecolor, ealpha: edge colour and maximum edge opacity.
        ncolor, nalpha: node colour and opacity.
        nscale: marker-size multiplier.
    """
    # Local import: the module-level `import pylab as pl` only runs when a
    # `visu` cell is enabled, so relying on the global raised NameError.
    import pylab as pl

    n = C.shape[0]
    Cmax = C.max()
    for i in range(n):
        for j in range(n):
            if C[i, j] > 0 and h[i] > 0 and h[j] > 0:
                pl.plot([pos[i, 0], pos[j, 0]], [pos[i, 1], pos[j, 1]],
                        color=ecolor, alpha=ealpha * C[i, j] / Cmax, zorder=1)
    pl.scatter(pos[:, 0], pos[:, 1], s=h * n * nscale, c=ncolor, alpha=nalpha,
               zorder=2, edgecolor='k', vmin=0, vmax=9, cmap='tab10')

if visu:
    import os

    import pylab as pl
    np.random.seed(0)
    th.manual_seed(0)

    if toy_name == '4cycles':
        dataset = FourCycles()
    elif toy_name == 'skipcircles':
        dataset = SkipCircles()
    else:
        raise ValueError(f'unknown toy_name: {toy_name!r}')
    n = dataset.num_nodes
    print(f'Number of nodes: {n}')
    graphs = dataset.makedata()
    # NOTE: each entry is the *largest* node degree of one graph, so the
    # "Mean/Max/Min Degree" lines summarise per-graph maximum degrees.
    degs = [degree(g.edge_index[0], g.num_nodes, dtype=torch.long).max() for g in graphs]
    max_degs = torch.stack(degs)
    print(f'Mean Degree: {max_degs.float().mean()}')
    print(f'Max Degree: {max_degs.max()}')
    print(f'Min Degree: {max_degs.min()}')
    # Count the graphs already generated instead of calling makedata() again,
    # which rebuilt the whole dataset just to measure its length.
    print(f'Number of graphs: {len(graphs)}')

    num_features = dataset.num_features

    adj_graphs = [get_adjacency_matrix(g.edge_index, g.num_nodes) for g in graphs]
    y = np.array([g.y for g in graphs], dtype=np.int64)
    """
    for i in range(0,10):
        A = adj_graphs[i].numpy()
        G, pos= get_graph_and_pos(A, nx_mode='circular')
        h = np.ones(A.shape[0])/A.shape[0]
        pl.figure(i+1, (8,4))
        pl.clf()
        pl.subplot(121)
        pl.title('graph %s / y =%s'%(i, y[i]))
        plot_attributed_graph(A, pos, h, ncolor='C0', ealpha=0.5)
        pl.axis('off')
        pl.axis('equal')
        pl.subplot(122)
        pl.title('adjacency %s / y =%s'%(i, y[i]))
        pl.imshow(A);pl.colorbar()
        pl.axis('off')
        pl.axis('equal')
        pl.show()
    """
    fig_count = 1
    if toy_name == '4cycles':
        # Sample indices to plot for label 0 and label 1; presumably
        # hand-picked for the seeded generation above.
        l = [[30,44], [12,2]]
    # savefig below writes into ./imgs, which may not exist yet.
    os.makedirs('./imgs', exist_ok=True)
    for label in np.unique(y):
        if toy_name == '4cycles':
            samples = l[label]
            for i in samples:
                # Copy before adding self-loops: `.numpy()` shares memory with
                # the tensor, so an in-place `+=` would also modify adj_graphs[i].
                A = adj_graphs[i].numpy().copy()
                A += np.eye(A.shape[0])
                G, pos = get_graph_and_pos(A, nx_mode='spring')
                h = np.ones(A.shape[0])/A.shape[0]
                pl.figure(fig_count, (3,3))
                pl.clf()
                pl.title('4cycles: sample (y=%s)'%(label))
                pl.tight_layout()
                pl.axis('off')
                plot_attributed_graph(A, pos, h, ncolor='C0', ealpha=0.5)
                pl.savefig('./imgs/4cycles_sample%s_label%s.pdf'%(i, label), bbox_inches='tight')
                pl.show()
                fig_count += 1
#%% visu 4cycles
visu = False
if visu:
    import os

    # Imported here so this cell also runs on its own, without relying on
    # `pl` from another cell.
    import pylab as pl
    np.random.seed(0)
    th.manual_seed(0)
    dataset = FourCycles()
    n = dataset.num_nodes
    print(f'Number of nodes: {n}')
    graphs = dataset.makedata()
    # NOTE: each entry is the *largest* node degree of one graph.
    degs = [degree(g.edge_index[0], g.num_nodes, dtype=torch.long).max() for g in graphs]
    max_degs = torch.stack(degs)
    print(f'Mean Degree: {max_degs.float().mean()}')
    print(f'Max Degree: {max_degs.max()}')
    print(f'Min Degree: {max_degs.min()}')
    # Count the graphs already generated instead of regenerating the dataset.
    print(f'Number of graphs: {len(graphs)}')

    num_features = dataset.num_features

    adj_graphs = [get_adjacency_matrix(g.edge_index, g.num_nodes) for g in graphs]
    y = np.array([g.y for g in graphs], dtype=np.int64)

    # Initialised here; it used to be inherited from another cell and raised
    # NameError when that cell had not run.
    fig_count = 1
    os.makedirs('./imgs', exist_ok=True)
    # (label, sample index, layout arguments) for each figure; the indices are
    # presumably hand-picked for the seeded generation above.
    figure_specs = [
        (0, 30, dict(nx_mode='kawai')),
        (0, 44, dict(nx_mode='spring', seed=1, scale=50)),
        (1, 12, dict(nx_mode='spring', seed=2, scale=1, k=0.1)),
        (1, 2, dict(nx_mode='spring', seed=3, scale=1, k=None)),
    ]
    for label, i, layout_kwargs in figure_specs:
        A = adj_graphs[i].numpy()
        G, pos = get_graph_and_pos(A, **layout_kwargs)
        h = np.ones(A.shape[0])/A.shape[0]
        pl.figure(fig_count, (3,3))
        pl.clf()
        pl.title('4-CYCLES: sample (y=%s)'%(label))
        pl.tight_layout()
        pl.axis('off')
        plot_attributed_graph(A, pos, h, ncolor='C0', ealpha=0.5)
        pl.savefig('./imgs/4cycles_sample%s_label%s.pdf'%(i, label), bbox_inches='tight')
        pl.show()
        fig_count += 1
#%% VISU FOR THE DATASET SKIP-CIRCLES


# Switches for the SKIP-CIRCLES plotting cell below: set `visu = True` to run
# it. It prints degree statistics for the dataset named by `toy_name` and plots
# sample graphs of that dataset.
visu = False
toy_name = 'skipcircles'

def get_graph_and_pos(C,nx_mode='kawai',seed=0, scale=1, k=0.1):
    """Build a networkx graph from matrix ``C`` and compute a node layout.

    NOTE(review): this redefines the function of the same name earlier in the
    module; the original versions differ only in the 'circular' scale (10
    here).

    An undirected edge ``(i, j)`` with attribute ``weight=C[i, j]`` is added
    for every strictly positive entry of ``C``. When both ``C[i, j]`` and
    ``C[j, i]`` are positive, the entry visited last in row-major order sets
    the weight.

    Args:
        C: square array of edge weights, indexed ``C[i, j]``.
        nx_mode: layout name: 'spring', 'kawai', 'circular', 'planar',
            'spectral' or 'multi'.
        seed: random seed, used by the 'spring' layout only.
        scale: layout scale, used by the 'spring' layout only ('kawai' and
            'circular' always use 10).
        k: optimal node distance for 'spring'; ``None`` lets networkx choose.

    Returns:
        ``(G, pos)`` where ``pos`` is an array whose row ``v`` holds the
        coordinates of node ``v``.

    Raises:
        ValueError: if ``nx_mode`` is not a supported layout. Previously this
            fell through and crashed with UnboundLocalError.
    """
    def spring(g):
        if k is None:
            return nx.spring_layout(g, seed=seed, scale=scale)
        return nx.spring_layout(g, seed=seed, k=k, scale=scale)

    layouts = {
        'spring': spring,
        'kawai': lambda g: nx.kamada_kawai_layout(g, scale=10),
        'circular': lambda g: nx.circular_layout(g, scale=10),
        'planar': lambda g: nx.planar_layout(g),
        'spectral': lambda g: nx.spectral_layout(g),
        'multi': lambda g: nx.multipartite_layout(g),
    }
    if nx_mode not in layouts:
        raise ValueError(f"unknown nx_mode {nx_mode!r}; expected one of {sorted(layouts)}")

    n = C.shape[0]
    G = nx.Graph()
    G.add_nodes_from(range(n))
    for i in range(n):
        for j in range(n):
            if C[i, j] > 0:
                G.add_edge(i, j, weight=C[i, j])

    pos = layouts[nx_mode](G)
    return G, np.vstack([pos[v] for v in range(n)])
    

def plot_attributed_graph(C, pos, h,ecolor='k',ealpha=0.2,ncolor='C0',nalpha=1,nscale=30):
    """Draw a weighted graph on the current pylab axes.

    NOTE(review): this redefines the function of the same name earlier in the
    module; the original versions differ only in the ``nscale`` default (30
    here).

    Args:
        C: square weight matrix. A line is drawn for every ordered pair with
            ``C[i, j] > 0`` whose endpoints both have ``h > 0``, so a symmetric
            ``C`` draws each edge twice. Line opacity is
            ``ealpha * C[i, j] / C.max()``.
        pos: node coordinates, indexed ``pos[i, 0]`` (x) and ``pos[i, 1]`` (y).
        h: per-node weights; marker size is ``h * n * nscale``.
        ecolor, ealpha: edge colour and maximum edge opacity.
        ncolor, nalpha: node colour and opacity.
        nscale: marker-size multiplier.
    """
    # Local import: the module-level `import pylab as pl` only runs when a
    # `visu` cell is enabled, so relying on the global raised NameError.
    import pylab as pl

    n = C.shape[0]
    Cmax = C.max()
    for i in range(n):
        for j in range(n):
            if C[i, j] > 0 and h[i] > 0 and h[j] > 0:
                pl.plot([pos[i, 0], pos[j, 0]], [pos[i, 1], pos[j, 1]],
                        color=ecolor, alpha=ealpha * C[i, j] / Cmax, zorder=1)
    pl.scatter(pos[:, 0], pos[:, 1], s=h * n * nscale, c=ncolor, alpha=nalpha,
               zorder=2, edgecolor='k', vmin=0, vmax=9, cmap='tab10')

if visu:
    import os

    import pylab as pl
    np.random.seed(0)
    th.manual_seed(0)

    if toy_name == '4cycles':
        dataset = FourCycles()
    elif toy_name == 'skipcircles':
        dataset = SkipCircles()
    else:
        raise ValueError(f'unknown toy_name: {toy_name!r}')
    n = dataset.num_nodes
    print(f'Number of nodes: {n}')
    graphs = dataset.makedata()
    # NOTE: each entry is the *largest* node degree of one graph, so the
    # "Mean/Max/Min Degree" lines summarise per-graph maximum degrees.
    degs = [degree(g.edge_index[0], g.num_nodes, dtype=torch.long).max() for g in graphs]
    max_degs = torch.stack(degs)
    print(f'Mean Degree: {max_degs.float().mean()}')
    print(f'Max Degree: {max_degs.max()}')
    print(f'Min Degree: {max_degs.min()}')
    # Count the graphs already generated instead of calling makedata() again,
    # which rebuilt the whole dataset just to measure its length.
    print(f'Number of graphs: {len(graphs)}')

    num_features = dataset.num_features

    adj_graphs = [get_adjacency_matrix(g.edge_index, g.num_nodes) for g in graphs]
    y = np.array([g.y for g in graphs], dtype=np.int64)
    """
    for i in range(0,10):
        A = adj_graphs[i].numpy()
        G, pos= get_graph_and_pos(A, nx_mode='circular')
        h = np.ones(A.shape[0])/A.shape[0]
        pl.figure(i+1, (8,4))
        pl.clf()
        pl.subplot(121)
        pl.title('graph %s / y =%s'%(i, y[i]))
        plot_attributed_graph(A, pos, h, ncolor='C0', ealpha=0.5)
        pl.axis('off')
        pl.axis('equal')
        pl.subplot(122)
        pl.title('adjacency %s / y =%s'%(i, y[i]))
        pl.imshow(A);pl.colorbar()
        pl.axis('off')
        pl.axis('equal')
        pl.show()
    """
    fig_count = 1
    if toy_name == '4cycles':
        # Sample indices to plot for label 0 and label 1; presumably
        # hand-picked for the seeded generation above.
        l = [[30,44], [12,2]]
    # savefig below writes into ./imgs, which may not exist yet.
    os.makedirs('./imgs', exist_ok=True)

    for label in np.unique(y):
        if toy_name == '4cycles':
            samples = l[label]
            for i in samples:
                # Copy before adding self-loops: `.numpy()` shares memory with
                # the tensor, so an in-place `+=` would also modify adj_graphs[i].
                A = adj_graphs[i].numpy().copy()
                A += np.eye(A.shape[0])
                G, pos = get_graph_and_pos(A, nx_mode='spring')
                h = np.ones(A.shape[0])/A.shape[0]
                pl.figure(fig_count, (3,3))
                pl.clf()
                pl.title('4cycles: sample (y=%s)'%(label))
                pl.tight_layout()
                pl.axis('off')
                plot_attributed_graph(A, pos, h, ncolor='C0', ealpha=0.5)
                # Saving is disabled for 4cycles in this cell.
                #pl.savefig('./imgs/4cycles_sample%s_label%s.pdf'%(i, label), bbox_inches='tight')
                pl.show()
                fig_count += 1

        else:
            # One randomly drawn graph index per label.
            idx_ = np.argwhere(y==label)[:, 0]
            samples = np.random.choice(idx_, size=1)
            for i in samples:
                A = adj_graphs[i].numpy().copy()
                A += np.eye(A.shape[0])
                G, pos = get_graph_and_pos(A, nx_mode='circular')
                h = np.ones(A.shape[0])/A.shape[0]
                pl.figure(fig_count, (3,3))
                pl.clf()
                pl.title('SKIP-CIRCLES: sample (y=%s)'%(label))
                pl.tight_layout()
                pl.axis('off')
                plot_attributed_graph(A, pos, h, ncolor='C0', ealpha=0.5)
                pl.savefig('./imgs/skipcircles_sample%s_label%s.pdf'%(i, label), bbox_inches='tight')
                pl.show()
                fig_count += 1

