# Decide whether a tree contains more 0s or 1s (i.e. the majority value), set up as graph classification.
import torch
import numpy as np
import networkx as nx
from torch_geometric.utils import to_networkx, from_networkx

class MajorityTree():
    """Graph-classification task: does a random tree contain more 1s than 0s?

    Each node carries one random binary feature. The graph label is 1 when
    strictly more than half of the nodes are 1, otherwise 0. Feature vectors
    with exactly half 1s are resampled, so ties never reach the dataset.
    """

    def __init__(self):
        super().__init__()
        self.name = "GraphMode"
        self.num_features = 1
        self.num_classes = 2

    def gen_graph(self, num_nodes):
        """Sample a random tree with binary node features and a majority label.

        Returns a torch_geometric ``Data`` with ``x`` of shape ``(num_nodes, 1)``,
        a long ``y`` of shape ``(1,)`` and ``root`` set to 0.
        """
        data = from_networkx(nx.random_tree(num_nodes))
        half = num_nodes / 2
        while True:
            # randint's upper bound is exclusive, so every entry is 0 or 1.
            bits = torch.randint(0, 2, (num_nodes, 1)).float()
            # Reject exact ties so every graph has a well-defined majority.
            if bits.sum() != half:
                break
        data.x = bits
        majority = 1 if bits.sum() > half else 0
        data.y = torch.tensor([majority], dtype=torch.long)
        data.root = 0
        return data

    def makedata(self, num_graphs=200, num_nodes=8):
        """Return ``num_graphs`` independently sampled trees of ``num_nodes`` nodes."""
        graphs = []
        for _ in range(num_graphs):
            graphs.append(self.gen_graph(num_nodes))
        return graphs


class MajoritySubTree():
    """Node-classification task: does each node's subtree hold more 1s than 0s?

    The tree is rooted at node 0. Node features have two columns: column 0 is
    a random bit and column 1 marks the root (1 at node 0, 0 elsewhere). A
    node's label is 1 when strictly more than half of the nodes in its subtree
    have bit 1; draws are labelled 0.
    """

    def __init__(self):
        super().__init__()
        self.num_features = 2
        self.num_classes = 2
        self.name = "SubTreeMode"

    def dfs(self, graph, vis, node, subTreeSize, x, ans):
        """Label every node in the subtree of ``node`` with a recursive DFS.

        Args:
            graph: object with a ``neighbors(node)`` method (e.g. a networkx graph).
            vis: per-node visited flags; updated in place.
            node: node to start from; its unvisited neighbours become its children.
            subTreeSize: per-node list; filled in place with subtree sizes.
            x: node feature tensor; only ``x[v][0]`` is read. It is NOT modified.
            ans: per-node tensor; ``ans[v]`` is set to 1 if the subtree of ``v``
                has a strict majority of 1s, else 0.

        Returns:
            0-d tensor holding the number of 1s in the subtree of ``node``.

        Recursion depth equals the height of the explored tree.
        """
        vis[node] = True
        subTreeSize[node] = 1
        # clone() detaches the accumulator from ``x``: ``x[node][0]`` is a view,
        # so the in-place ``+=`` below would otherwise overwrite the caller's
        # features with subtree counts.
        subTreeOnes = x[node][0].clone()
        for child in graph.neighbors(node):
            if not vis[child]:
                subTreeOnes += self.dfs(graph, vis, child, subTreeSize, x, ans)
                subTreeSize[node] += subTreeSize[child]
        # Strict majority; a draw (exactly half 1s) yields label 0.
        ans[node] = 1 if subTreeOnes > subTreeSize[node] / 2 else 0
        return subTreeOnes

    def gen_graph(self, num_nodes):
        """Sample a random tree rooted at node 0 with per-node subtree-majority labels.

        Returns a torch_geometric ``Data`` with ``x`` of shape ``(num_nodes, 2)``,
        long ``y`` of shape ``(num_nodes,)`` and ``root_index`` set to 0.
        Draws are also marked label 0.
        """
        g = nx.random_tree(num_nodes)
        data = from_networkx(g)
        data.x = torch.randint(0, 2, (num_nodes, 2)).float()  # exclusive range [low, high)
        # Column 1 is a root indicator: 1 at node 0, 0 everywhere else.
        data.x[:, 1] = torch.zeros(num_nodes)
        data.x[0, 1] = 1
        ans = torch.zeros(num_nodes)
        # dfs no longer mutates x, so data.x can be passed directly.
        self.dfs(g, [False] * num_nodes, 0, [0] * num_nodes, data.x, ans)
        data.root_index = 0
        data.y = ans.to(torch.long)
        return data

    def makedata(self, num_graphs=200, num_nodes=8):
        """Return ``num_graphs`` independently sampled trees of ``num_nodes`` nodes."""
        return [self.gen_graph(num_nodes) for _ in range(num_graphs)]
    

# Mark the k nodes with the largest feature values in a random tree
class TopK():
    """Node-classification task: mark the ``k`` largest-valued nodes of a random tree.

    Each node gets one feature drawn uniformly between 0 and 1. The ``k`` nodes
    with the largest feature values are labelled 1 and all others 0.
    """

    def __init__(self, k=2):
        super().__init__()
        self.k = k
        self.name = "topK"
        self.num_classes = 2
        self.num_features = 1

    def gen_graph(self, num_nodes):
        """Sample a random tree and label its top-``k`` nodes.

        Returns a torch_geometric ``Data`` with ``x`` of shape ``(num_nodes, 1)``,
        float ``y`` of shape ``(num_nodes,)`` and ``root_index`` set to 0.
        ``self.k`` must not exceed ``num_nodes`` (``torch.topk`` rejects that).
        """
        data = from_networkx(nx.random_tree(num_nodes))
        scores = torch.FloatTensor(num_nodes).uniform_(0, 1)
        _, winners = torch.topk(scores, self.k)
        labels = torch.zeros(num_nodes)
        labels[winners] = 1
        data.x = scores.unsqueeze(1)
        data.y = labels
        data.root_index = 0
        return data

    def makedata(self, num_graphs=200, num_nodes=12):
        """Return ``num_graphs`` independently sampled trees of ``num_nodes`` nodes."""
        return list(map(lambda _: self.gen_graph(num_nodes), range(num_graphs)))
    



class BroadcastK():
    """Node-level task: copy a discrete root value to every leaf of a random tree.

    Node 0 is the root. Node features have two columns: column 0 is 1 for every
    node except the root, where it holds an integer drawn uniformly from
    ``[0, k)``; column 1 marks the root (1 at node 0, 0 elsewhere). Leaves
    (degree-1 nodes other than the root) are labelled with the root's value,
    and every other node is labelled 0. ``leaves_index`` lists the leaf ids.
    """

    def __init__(self, k=2):
        super().__init__()
        self.k = k
        # x has two columns (value, root indicator); see gen_graph.
        self.num_features = 2
        # The root value, and hence every leaf label, is one of k integers.
        self.num_classes = k
        self.name = "Broadcast"

    def gen_graph(self, num_nodes):
        """Sample one random tree on ``num_nodes`` nodes with features and labels.

        Returns:
            A torch_geometric ``Data`` with ``x`` of shape ``(num_nodes, 2)``,
            float ``y`` of shape ``(num_nodes,)`` and ``leaves_index``, a long
            tensor of leaf node ids.
        """
        tree = nx.random_tree(num_nodes)
        # The root is excluded even if it has degree 1. dtype=long keeps the
        # tensor usable as an index when it is empty (single-node tree).
        leaves = torch.tensor(
            [v for v in tree.nodes() if tree.degree[v] == 1 and v != 0],
            dtype=torch.long,
        )
        data = from_networkx(tree)
        root_value = float(torch.randint(0, self.k, (1,)).item())
        data.x, data.y = self._features_and_labels(num_nodes, leaves, root_value)
        data.leaves_index = leaves
        return data

    @staticmethod
    def _features_and_labels(num_nodes, leaves, root_value):
        """Return ``(x, y)`` for a tree rooted at node 0 that holds ``root_value``."""
        x = torch.ones(num_nodes, 2)
        x[:, 1] = 0.0
        x[0, 1] = 1.0  # root indicator
        x[0, 0] = root_value
        y = torch.zeros(num_nodes)
        y[leaves] = root_value
        return x, y

    def makedata(self, num_graphs=200, num_nodes=12):
        """Return ``num_graphs`` independently sampled trees of ``num_nodes`` nodes."""
        return [self.gen_graph(num_nodes) for _ in range(num_graphs)]

class Broadcast():
    """Node-level task: copy a continuous root value to every leaf of a random tree.

    Node 0 is the root. Node features have two columns: column 0 is 1 for every
    node except the root, where it holds a value drawn uniformly between 0 and
    1; column 1 marks the root (1 at node 0, 0 elsewhere). Leaves (degree-1
    nodes other than the root) are labelled with the root's value, and every
    other node is labelled 0. ``leaves_index`` lists the leaf ids.
    """

    def __init__(self):
        super().__init__()
        # x has two columns (value, root indicator); see gen_graph.
        self.num_features = 2
        self.name = "Broadcast"

    def gen_graph(self, num_nodes):
        """Sample one random tree on ``num_nodes`` nodes with features and labels.

        Returns:
            A torch_geometric ``Data`` with ``x`` of shape ``(num_nodes, 2)``,
            float ``y`` of shape ``(num_nodes,)`` and ``leaves_index``, a long
            tensor of leaf node ids.
        """
        tree = nx.random_tree(num_nodes)
        # The root is excluded even if it has degree 1. dtype=long keeps the
        # tensor usable as an index when it is empty (single-node tree).
        leaves = torch.tensor(
            [v for v in tree.nodes() if tree.degree[v] == 1 and v != 0],
            dtype=torch.long,
        )
        data = from_networkx(tree)
        root_value = torch.empty(1).uniform_(0, 1).item()
        data.x, data.y = self._features_and_labels(num_nodes, leaves, root_value)
        data.leaves_index = leaves
        return data

    @staticmethod
    def _features_and_labels(num_nodes, leaves, root_value):
        """Return ``(x, y)`` for a tree rooted at node 0 that holds ``root_value``."""
        x = torch.ones(num_nodes, 2)
        x[:, 1] = 0.0
        x[0, 1] = 1.0  # root indicator
        x[0, 0] = root_value
        y = torch.zeros(num_nodes)
        y[leaves] = root_value
        return x, y

    def makedata(self, num_graphs=200, num_nodes=12):
        """Return ``num_graphs`` independently sampled trees of ``num_nodes`` nodes."""
        return [self.gen_graph(num_nodes) for _ in range(num_graphs)]