import networkx as nx
import numpy as np

from decentralizepy.graphs.Graph import Graph


class ExponentialTwo(Graph):
    """
    The class for generating an exponential two topology

    """

    def getnxGraph(self, size):
        """
        Taken from https://github.com/devos50/decentralized-learning/blob/15dcf86676b0de91ef055a2080318e1462828e28/simulations/dl/__init__.py#L7

        """
        assert size > 0
        x = np.array([1.0 if i & (i - 1) == 0 else 0 for i in range(size)])
        x /= x.sum()
        topo = np.empty((size, size))
        for i in range(size):
            topo[i] = np.roll(x, i)
        G = nx.from_numpy_array(topo, create_using=nx.DiGraph)
        return G

    def __init__(self, n_procs, *args, **kwargs):
        """
        Constructor. Generates an exponential two graph topology

        Parameters
        ----------
        n_procs : int
            total number of nodes in the graph

        """
        super().__init__(n_procs)
        assert n_procs > 0, "Number of processes must be positive"
        G = self.getnxGraph(n_procs)
        adj = G.adjacency()
        for i, l in adj:
            self.adj_list[i] = set()  # new set
            for k in l:
                if k != i:
                    self.adj_list[i].add(k)
        # if not nx.is_connected(G):
        #     self.connect_graph()
