import torch
import numpy as np
import networkx as nx
from torch_geometric.utils import to_networkx, from_networkx

class PrefixSumK:
    """Prefix-sum-modulo-k task on randomly labelled path graphs.

    Every graph is a path whose node order is a random permutation of the
    node ids. Each node carries two input features ``[value, isRoot]``, where
    ``isRoot`` is 1.0 only on the first node of the path. The target of a node
    is the sum of the values from the root up to and including that node,
    taken modulo ``k``.
    """

    def __init__(self, k=2, inp=2):
        """Configure the task.

        Args:
            k: Modulus of the prefix sum; stored as the number of classes.
            inp: Exclusive upper bound of the node values, i.e. values are
                drawn from ``0 .. inp - 1``.
        """
        super().__init__()
        self.num_classes = k
        self.num_features = 2
        self.name = "PrefixSum mod K"
        self.inp = inp

    def gen_graph(self, s):
        """Build one path graph whose nodes carry the values of ``s``.

        Args:
            s: Sequence of numeric values; ``s[i]`` is assigned to the i-th
                node along the path (counted from the root).

        Returns:
            The object produced by ``from_networkx`` with ``x`` (features
            ``[value, isRoot]``), ``y`` (prefix sum modulo ``num_classes``)
            and ``root`` (id of the first path node) attached.

        Raises:
            IndexError: If ``s`` is empty, since there is no root to pick.
        """
        n = len(s)
        # Random order of the node ids along the path; rand_perm[0] is the root.
        rand_perm = np.arange(n)
        np.random.shuffle(rand_perm)
        G = nx.Graph()
        G.add_nodes_from(range(n))
        # Consecutive entries of the permutation are joined by an edge.
        G.add_edges_from(zip(rand_perm[:-1], rand_perm[1:]))

        root = rand_perm[0]
        x = [[0.0, 0.0] for _ in range(n)]
        x[root][1] = 1.0
        y = [0] * n

        # Walk the path from the root, accumulating the running sum mod k.
        counter = 0
        for value, node in zip(s, rand_perm):
            x[node][0] = value
            counter = (counter + int(value)) % self.num_classes
            y[node] = counter

        dG = from_networkx(G)
        dG.y = torch.tensor(y)
        dG.x = torch.tensor(x)
        dG.root = torch.tensor(root)
        return dG

    def makedata(self, num_graphs=200, num_nodes=8, allow_sizes=False):
        """Generate ``num_graphs`` graphs with pairwise distinct value sequences.

        Args:
            num_graphs: Number of graphs to generate.
            num_nodes: Path length, or the maximum path length when
                ``allow_sizes`` is true.
            allow_sizes: If true, each path length is drawn uniformly from
                ``2 .. num_nodes``; otherwise every path has ``num_nodes``
                nodes.

        Returns:
            A list with one generated graph per distinct value sequence.

        Raises:
            ValueError: If fewer than ``num_graphs`` distinct sequences exist
                for the requested sizes and value range. Without this check
                the rejection sampling below would never terminate.
        """
        alphabet = max(int(self.inp), 0)
        sizes = range(2, num_nodes + 1) if allow_sizes else (num_nodes,)
        capacity = sum(alphabet ** max(size, 0) for size in sizes)
        if num_graphs > capacity:
            raise ValueError(
                f"cannot draw {num_graphs} distinct sequences: only "
                f"{capacity} exist for num_nodes={num_nodes}, inp={self.inp}, "
                f"allow_sizes={allow_sizes}"
            )

        seen = set()  # tuples of already accepted sequences, for O(1) lookup
        sequences = []
        while len(sequences) < num_graphs:
            graph_size = num_nodes
            if allow_sizes:
                graph_size = np.random.randint(2, graph_size + 1)
            ss = [np.random.randint(0, self.inp) * 1.0 for _ in range(graph_size)]
            key = tuple(ss)
            if key not in seen:
                seen.add(key)
                sequences.append(ss)
        return [self.gen_graph(s) for s in sequences]
    
class Trees:
    """Shortest-path task on random trees.

    Two endpoint nodes are drawn per tree and every node on the path between
    them is labelled 1. Node features are one-hot: ``[1, 0]`` for ordinary
    nodes and ``[0, 1]`` for the two endpoints.
    """

    def __init__(self):
        super().__init__()
        self.name = "ShortestPathTrees"
        self.num_features = 2
        self.num_classes = 2

    def gen_graph(self, num_nodes, num):
        """Build one labelled tree.

        Args:
            num_nodes: Number of nodes of the tree.
            num: Seed passed to the tree generator, so the tree shape is
                fixed by this value; the endpoints come from the global
                NumPy random state.

        Returns:
            The object produced by ``from_networkx`` with ``x``, ``y``,
            ``edge_attr`` (all ones, one row per directed edge) and ``root``
            (the first endpoint) attached.
        """
        # NOTE(review): nx.random_tree appears to be deprecated in recent
        # NetworkX releases (random_labeled_tree) — confirm the pinned version.
        nx_graph = nx.random_tree(n=num_nodes, seed=num)
        tree = from_networkx(nx_graph)

        # Endpoints are sampled with replacement, so both may be the same node.
        src, tar = np.random.choice(num_nodes, 2)

        features = torch.zeros(num_nodes, 2)
        features[int(src), 1] = 1
        features[int(tar), 1] = 1
        # Column 0 flags every node that is not an endpoint.
        features[:, 0] = 1 - features[:, 1]

        labels = torch.zeros(num_nodes)
        path_nodes = [int(v) for v in nx.shortest_path(nx_graph, source=src, target=tar)]
        labels[torch.tensor(path_nodes)] = 1

        tree.x = features
        tree.y = labels
        tree.edge_attr = torch.ones(2 * nx_graph.number_of_edges(), 1)
        tree.root = torch.tensor(src)
        return tree

    def makedata(self, num_graphs=200, num_nodes=8, allow_sizes=False):
        """Return ``num_graphs`` trees seeded ``0 .. num_graphs - 1``.

        ``allow_sizes`` is accepted for interface parity with the other
        tasks but is not used: every tree has ``num_nodes`` nodes.
        """
        return [self.gen_graph(num_nodes, seed) for seed in range(num_graphs)]

class DistanceK:
    """Distance-modulo-k task on random connected graphs.

    One origin node is drawn per graph. Node features are one-hot:
    ``[0, 1]`` for the origin and ``[1, 0]`` for every other node. The target
    of a node is its hop distance to the origin, taken modulo ``k``.
    """

    def __init__(self, k=2):
        """Configure the task.

        Args:
            k: Modulus applied to the distances; stored as the number of
                classes.
        """
        super().__init__()
        self.num_features = 2
        self.num_classes = k
        self.name = "Distance"

    def gen_graph(self, num_nodes):
        """Build one labelled graph with ``num_nodes`` nodes.

        Args:
            num_nodes: Number of nodes of the generated graph.

        Returns:
            The object produced by ``from_networkx`` with ``x``, ``edge_attr``
            (all ones, one row per directed edge), ``root`` (the origin) and
            ``y`` (distance to the origin modulo ``num_classes``) attached.
        """
        g = randomgraph(num_nodes)
        origin = np.random.randint(0, num_nodes)

        data = from_networkx(g)
        data.x = torch.tensor(
            [[0.0, 1.0] if node == origin else [1.0, 0.0] for node in range(num_nodes)]
        )

        # Hop distance from the origin to each reachable node. The lookup
        # below needs every node to be reachable; randomgraph links all nodes
        # through its random-walk spanning tree.
        distances = nx.shortest_path_length(g, origin)
        data.edge_attr = torch.ones(g.number_of_edges() * 2, 1)
        data.root = torch.tensor(origin)
        data.y = torch.tensor(
            [distances[node] % self.num_classes for node in range(num_nodes)]
        )
        return data

    def makedata(self, num_graphs=200, num_nodes=8, allow_sizes=False):
        """Return ``num_graphs`` independently generated graphs.

        ``allow_sizes`` is accepted for interface parity with the other
        tasks but is not used: every graph has ``num_nodes`` nodes.
        """
        return [self.gen_graph(num_nodes) for _ in range(num_graphs)]

def randomgraph(n, **args):
    """Generate a random connected undirected graph on nodes ``0 .. n - 1``.

    A spanning tree is grown by a random walk that jumps to uniformly chosen
    nodes and adds an edge whenever it reaches a node for the first time.
    Afterwards ``n // 5`` additional edges between previously unconnected
    node pairs are inserted.

    Args:
        n: Number of nodes. For ``n <= 0`` an empty graph is returned.
        **args: Accepted and ignored.

    Returns:
        The generated ``nx.Graph``.
    """
    g = nx.Graph()
    nodes = list(range(n))
    g.add_nodes_from(nodes)
    if not nodes:
        return g

    current = np.random.choice(nodes)
    visited = {current}
    while len(visited) < n:
        nxt = np.random.choice(nodes)
        if nxt not in visited:
            visited.add(nxt)
            # The graph is undirected, so one insertion covers both directions.
            g.add_edge(current, nxt)
        # The walk moves on even when the chosen node was already visited.
        current = nxt

    for _ in range(n // 5):
        # Two distinct nodes; redraw until the pair is not yet connected.
        i, j = np.random.permutation(n)[:2]
        while g.has_edge(i, j):
            i, j = np.random.permutation(n)[:2]
        g.add_edge(i, j)
    return g

