import torch
import numpy as np
from torch_geometric.utils import degree
from torch_geometric.utils.convert import from_networkx, to_networkx
from torch_geometric.data import DataLoader, Data
import networkx as nx
from .util import load_data

'''Dataset generators for the expressivity experiments, code adapted from:
https://github.com/KarolisMart/DropGNN/blob/main/gin-synthetic.py'''

class SymmetrySet:
    """Base class for the synthetic expressivity datasets in this module.

    Subclasses overwrite the hyper-parameter attributes below in ``__init__``
    and implement ``makedata``. ``addports`` and ``makefeatures`` are shared
    helpers that decorate a graph object in place and return it.
    """

    def __init__(self):
        self.hidden_units = 0
        self.num_classes = 0
        self.num_features = 0
        self.num_nodes = 0
        # Every subclass sets this; the default guarantees the attribute exists.
        self.graph_class = False

    def addports(self, data):
        """Attach random port numbers to the edges of ``data``.

        For every node ``n``, its out-edges (columns of ``data.edge_index``
        whose source is ``n``) receive a random permutation of
        ``0 .. out_degree(n) - 1``, assigned in edge order.

        Args:
            data: graph object exposing ``edge_index`` (presumably a 2 x E
                integer tensor), ``num_edges`` and ``num_nodes``.

        Returns:
            The same ``data`` object with ``data.ports`` set to a float tensor
            of shape ``(num_edges, 1)``.
        """
        sources = data.edge_index[0]
        data.ports = torch.zeros(data.num_edges, 1)
        for n in range(data.num_nodes):
            # Positions of n's out-edges, in edge order. Indexing by position
            # (not by (source, target) pair) keeps ports distinct even when the
            # same neighbour appears on several parallel edges.
            out_edges = (sources == n).nonzero(as_tuple=True)[0]
            # One permutation per node, drawn in node order (same RNG usage as
            # before, so seeded runs are unchanged for simple graphs).
            perm = np.random.permutation(out_edges.numel())
            data.ports[out_edges, 0] = torch.as_tensor(perm, dtype=data.ports.dtype)
        return data

    def makefeatures(self, data):
        """Give ``data`` constant node features and a random node id.

        Sets ``data.x`` to ones of shape ``(num_nodes, 1)`` and ``data.id`` to
        a random permutation of ``0 .. num_nodes - 1`` with shape
        ``(num_nodes, 1)``. Returns the same ``data`` object.
        """
        data.x = torch.ones((data.num_nodes, 1))
        data.id = torch.tensor(np.random.permutation(np.arange(data.num_nodes))).unsqueeze(1)
        return data

    def makedata(self, *args, **kwargs):
        """Build and return the dataset; every subclass must override this."""
        raise NotImplementedError(f"{type(self).__name__} must implement makedata()")

class LimitsOne(SymmetrySet):
    """First "limits" node-classification instance (16 fixed nodes, 4 colours).

    Nodes 0-7 (label 0) form two disjoint 4-cycles, 0-1-2-3 and 4-5-6-7;
    nodes 8-15 (label 1) form a single 8-cycle. Node ``v`` carries colour
    ``v % 4`` (one-hot in ``x``), and port numbers are fixed per edge rather
    than drawn by ``addports``.
    """

    def __init__(self):
        super().__init__()
        self.hidden_units = 16
        self.num_classes = 2
        self.num_features = 4
        self.num_nodes = 8
        self.graph_class = False

    def makedata(self, size=1):
        """Return a list holding the fixed graph ``size`` times.

        The entries are ``size`` references to one and the same ``Data``
        object, so they share the single random ``data.id`` permutation.
        """
        num_nodes = 16

        def ring(cycle):
            # Consecutive pairs around the cycle, including the closing pair.
            return list(zip(cycle, cycle[1:] + cycle[:1]))

        src, dst = [], []
        for u, v in ring([0, 1, 2, 3]) + ring([4, 5, 6, 7]) + ring(list(range(8, 16))):
            # Every undirected edge becomes the two columns u->v, v->u.
            src += [u, v]
            dst += [v, u]
        edge_index = torch.tensor([src, dst], dtype=torch.long)

        data = Data(
            x=torch.eye(4)[[0, 1, 2, 3] * 4],  # one-hot colour v % 4
            edge_index=edge_index,
            y=torch.tensor([0] * 8 + [1] * 8),
        )
        data.id = torch.tensor(np.random.permutation(np.arange(num_nodes))).unsqueeze(1)
        data.ports = torch.tensor([1, 1, 2, 2] * 8).unsqueeze(1)  # one per directed edge
        data.root = torch.tensor([8])
        return [data] * size

class LimitsTwo(SymmetrySet):
    """Second "limits" node-classification instance (16 fixed nodes, 4 colours).

    Nodes 0-7 (label 0): two disjoint 4-cycles, 0-1-2-3 and 4-5-6-7, each
    with one diagonal (1-3 and 5-7). Nodes 8-15 (label 1): 4-cycles
    8-9-10-11 and 12-13-14-15 joined by the edges 9-15 and 11-13. Node ``v``
    carries colour ``v % 4`` and port numbers are fixed per edge.
    """

    def __init__(self):
        super().__init__()
        self.hidden_units = 16
        self.num_classes = 2
        self.num_features = 4
        self.num_nodes = 8
        self.graph_class = False

    def makedata(self, size=1):
        """Return a list holding the fixed graph ``size`` times.

        The entries are ``size`` references to one and the same ``Data``
        object, so they share the single random ``data.id`` permutation.
        """
        num_nodes = 16

        undirected = [
            (0, 1), (1, 2), (2, 3), (3, 0),          # 4-cycle
            (4, 5), (5, 6), (6, 7), (7, 4),          # 4-cycle
            (1, 3), (5, 7),                          # diagonals
            (8, 9), (9, 10), (10, 11), (11, 8),      # 4-cycle
            (12, 13), (13, 14), (14, 15), (15, 12),  # 4-cycle
            (9, 15), (11, 13),                       # links between the two
        ]
        src, dst = [], []
        for u, v in undirected:
            # Every undirected edge becomes the two columns u->v, v->u.
            src += [u, v]
            dst += [v, u]
        edge_index = torch.tensor([src, dst], dtype=torch.long)

        # One port label per directed edge, in edge order (40 edges).
        port_labels = ([1, 1, 2, 2, 1, 1, 2, 2] * 2 + [3, 3, 3, 3]) * 2

        data = Data(
            x=torch.eye(4)[[0, 1, 2, 3] * 4],  # one-hot colour v % 4
            edge_index=edge_index,
            y=torch.tensor([0] * 8 + [1] * 8),
        )
        data.id = torch.tensor(np.random.permutation(np.arange(num_nodes))).unsqueeze(1)
        data.ports = torch.tensor(port_labels).unsqueeze(1)
        return [data] * size

class Triangles(SymmetrySet):
    """Node-level (``graph_class = False``) 2-class dataset read from disk.

    The graphs themselves come from ``load_data``; this class only records the
    model/dataset hyper-parameters used with them.
    """

    def __init__(self):
        super().__init__()
        self.graph_class = False
        self.num_nodes = 60
        self.num_features = 1
        self.num_classes = 2
        self.hidden_units = 16

    def makedata(self, path):
        """Return whatever ``load_data(path)`` produces."""
        return load_data(path)

class LCC(SymmetrySet):
    """Node-level (``graph_class = False``) 3-class dataset read from disk.

    The graphs themselves come from ``load_data``; this class only records the
    model/dataset hyper-parameters used with them.
    """

    def __init__(self):
        super().__init__()
        self.graph_class = False
        self.num_nodes = 10
        self.num_features = 1
        self.num_classes = 3
        self.hidden_units = 16

    def makedata(self, path):
        """Return whatever ``load_data(path)`` produces."""
        return load_data(path)

class FourCycles(SymmetrySet):
    """Graph classification: is a particular 4-cycle closed?

    A graph with parameter ``p`` has ``4 * p`` nodes in four groups
    A = [0, p), B = [p, 2p), C = [2p, 3p), D = [3p, 4p). Fixed edges join
    A[i] <-> C[i] and B[i] <-> D[i]. Random "rungs" join A[i] <-> B[top[i]]
    and C[i] <-> D[bottom[i]], with ``top`` and ``bottom`` drawn independently.
    The label is True iff ``top[i] == bottom[i]`` for some ``i``, which closes
    the cycle A[i] - B[t] - D[t] - C[i] - A[i].
    """

    def __init__(self):
        super().__init__()
        self.p = 4
        self.hidden_units = 16
        self.num_classes = 2
        self.num_features = 1
        self.num_nodes = 4 * self.p
        self.graph_class = True

    def gen_graph(self, p):
        """Sample one graph with ``4 * p`` nodes.

        Returns:
            ``(data, label)``: a ``Data`` object with ``edge_index`` and
            ``num_nodes = 4 * p``, and a bool that is True iff some ``i`` has
            ``top[i] == bottom[i]``.
        """
        # Fixed edges A[i] <-> C[i] and B[i] <-> D[i], both directions.
        edge_index = torch.cat(
            [torch.tensor([[i, p + i, 2 * p + i, 3 * p + i],
                           [2 * p + i, 3 * p + i, i, p + i]], dtype=torch.long)
             for i in range(p)],
            dim=-1,
        )

        # NOTE: the two permutation draws are unused (rung targets come from
        # randint, so several rows may share a target). They are kept only so
        # that seeded runs consume the global NumPy RNG exactly as before and
        # regenerate the same datasets.
        np.random.permutation(range(p))
        top = np.random.randint(0, p, size=p).tolist()
        np.random.permutation(range(p))
        bottom = np.random.randint(0, p, size=p).tolist()

        # Rungs in row order: all top rungs first, then all bottom rungs.
        rungs = [torch.tensor([[i, p + t], [p + t, i]], dtype=torch.long)
                 for i, t in enumerate(top)]
        rungs += [torch.tensor([[2 * p + i, 3 * p + t], [3 * p + t, 2 * p + i]], dtype=torch.long)
                  for i, t in enumerate(bottom)]
        edge_index = torch.cat([edge_index] + rungs, dim=-1)

        label = any(t == b for t, b in zip(top, bottom))
        return Data(edge_index=edge_index, num_nodes=4 * p), label

    def makedata(self, size=1000, p=4, vary_sizes=False, verbose=False):
        """Generate a class-balanced list of graphs.

        Args:
            size: total number of graphs requested; ``size // 2`` per class,
                so an odd ``size`` yields ``size - 1`` graphs.
            p: group size passed to ``gen_graph``; ``-1`` means ``self.p``.
            vary_sizes: if True, each graph draws its own group size
                uniformly from ``[2, p]``.
            verbose: if True, print the label and connected-component sizes
                of every sampled graph (debug output that used to be printed
                unconditionally).

        Returns:
            ``size // 2`` positive graphs (``y == 1``) followed by
            ``size // 2`` negative graphs (``y == 0``), each with features,
            ports and ``root = 0`` attached.

        Raises:
            ValueError: if ``p < 2``. With ``p == 1`` every graph is positive,
                so the negative class could never be filled.
        """
        if p == -1:
            p = self.p
        if p < 2:
            raise ValueError(f"p must be >= 2 (or -1 to use self.p), got {p}")
        per_class = size // 2
        trues, falses = [], []
        # Rejection-style sampling: keep drawing until both classes are full.
        while len(trues) < per_class or len(falses) < per_class:
            graph_p = np.random.randint(2, p + 1) if vary_sizes else p
            data, label = self.gen_graph(graph_p)
            data = self.makefeatures(data)
            data = self.addports(data)
            data.y = torch.tensor(1 if label else 0)
            data.root = torch.tensor(0)

            if verbose:
                G = to_networkx(data).to_undirected()
                print(label, nx.is_connected(G), [len(c) for c in nx.connected_components(G)])

            if label and len(trues) < per_class:
                trues.append(data)
            elif not label and len(falses) < per_class:
                falses.append(data)
        return trues + falses

class SkipCircles(SymmetrySet):
    """Graph classification of circulant "skip circle" graphs.

    Every graph is a ring on ``num_nodes`` nodes plus chords
    ``i -> i - skip`` and ``i -> i + skip`` (mod ``num_nodes``). The class
    label is the position of ``skip`` in the skip-length list of ``makedata``.
    """

    def __init__(self):
        super().__init__()
        self.hidden_units = 32
        self.num_classes = 10  # one class per skip length used in makedata
        self.num_features = 1
        self.num_nodes = 41
        self.graph_class = True

    def makedata(self):
        """Return one graph per skip length, labelled by its list position."""
        n = self.num_nodes
        nodes = torch.arange(n)

        # Ring columns (identical for every graph): the closing pair
        # 0 -> n-1, n-1 -> 0 first, then i -> i+1, i+1 -> i for i < n-1.
        ring_src = torch.cat([torch.tensor([0, n - 1]),
                              torch.stack([nodes[:-1], nodes[1:]], dim=1).reshape(-1)])
        ring_dst = torch.cat([torch.tensor([n - 1, 0]),
                              torch.stack([nodes[1:], nodes[:-1]], dim=1).reshape(-1)])

        graphs = []
        for label, skip in enumerate((2, 3, 4, 5, 6, 9, 11, 12, 13, 16)):
            # Chord columns: node i -> i - skip, then i -> i + skip (mod n).
            chord_src = nodes.repeat_interleave(2)
            chord_dst = torch.stack([torch.remainder(nodes - skip, n),
                                     torch.remainder(nodes + skip, n)], dim=1).reshape(-1)
            edge_index = torch.stack([torch.cat([ring_src, chord_src]),
                                      torch.cat([ring_dst, chord_dst])])

            graph = Data(edge_index=edge_index, num_nodes=self.num_nodes)
            graph = self.makefeatures(graph)
            graph = self.addports(graph)
            graph.y = torch.tensor(label)
            graphs.append(graph)
        return graphs