import numpy as np
import scipy as sp
import networkx as nx


class Topology:
    """
    Topology class: Encapsulates the topology of the graph.

    Honest nodes are indexed ``0 .. nb_honest - 1``; after ``add_byzantine_neighbors``,
    Byzantine nodes are indexed ``nb_honest .. nb_honest + nb_byz - 1``.

    Attributes:
        interaction_matrix: (nb_nodes, nb_edges) - "Encoding -1 in (i, (i,j)) and 1 in (j,(i,j)), 0 elsewhere".
            Be careful: (undirected) edges are represented 2 times each (in a directed way).
            For i, j honest: (j,i) = i * nb_honest + j.
            For i honest, j its j_i-th Byzantine neighbor: (j,i) = nb_honest**2 + nb_byzantine_neighbors * i + j_i,
            and the reverse edge (i,j) is stored nb_byz columns further.
        interaction_matrix_honest: (nb_honest, nb_honest**2) - same encoding restricted to honest nodes.
            Subclasses are expected to fill it (typically through ``add_honest_edge``).
        adjacency_matrix, adjacency_matrix_honest: positive part of the matching interaction matrix.
        laplacian_matrix, laplacian_matrix_honest: caches filled lazily by ``laplacian()`` /
            ``laplacian_honest()``.
    """
    name = "Topology"

    def __init__(self, nb_honest, nb_byzantine_neighbors=None):
        self.nb_honest = nb_honest
        self.nb_byz = 0

        self.interaction_matrix = None
        self.interaction_matrix_honest = None
        self.adjacency_matrix = None
        self.adjacency_matrix_honest = None
        self.laplacian_matrix = None
        self.laplacian_matrix_honest = None

        self._delta_infinite = None
        self.nb_byzantine_neighbors = None

        # Only valid once interaction_matrix_honest exists; see add_byzantine_neighbors.
        if nb_byzantine_neighbors is not None:
            self.add_byzantine_neighbors(nb_byzantine_neighbors)

    def add_byzantine_neighbors(self, nb_byzantine_neighbors: int):
        """
        Attach ``nb_byzantine_neighbors`` Byzantine nodes to every honest node.

        Rebuilds ``interaction_matrix``, with the honest block copied from
        ``interaction_matrix_honest``, and rebuilds ``adjacency_matrix``. Drops the cached
        ``delta_infinite`` and full Laplacian so they are recomputed from the new matrix.

        Raises:
            ValueError: if ``interaction_matrix_honest`` has not been built yet. This happens,
                for example, when ``nb_byzantine_neighbors`` is passed to ``Topology.__init__``
                directly.
        """
        if self.interaction_matrix_honest is None:
            # Copying None into the float matrix would silently fill the honest block with NaN.
            raise ValueError("interaction_matrix_honest must be built before adding Byzantine neighbors")

        n = self.nb_honest
        k = nb_byzantine_neighbors
        self.nb_byzantine_neighbors = k
        self.nb_byz = n * k

        self.interaction_matrix = np.zeros((n + self.nb_byz, n ** 2 + 2 * self.nb_byz))
        self.interaction_matrix[:n, :n ** 2] = self.interaction_matrix_honest

        for h in range(n):
            for b in range(k):
                byz_node = n + k * h + b
                # Directed edge byzantine -> honest h (+1 on the destination row).
                col_in = n ** 2 + k * h + b
                self.interaction_matrix[h, col_in] = 1
                self.interaction_matrix[byz_node, col_in] = -1

                # For consistency, we add directed edges for byzantines too (honest h -> byzantine).
                col_out = col_in + self.nb_byz
                self.interaction_matrix[h, col_out] = -1
                self.interaction_matrix[byz_node, col_out] = 1

        self._delta_infinite = None
        # The cached full Laplacian was built from the previous interaction matrix.
        self.laplacian_matrix = None
        self.adjacency_matrix = np.maximum(self.interaction_matrix, 0)

    def add_honest_edge(self, i, j):
        """
        Register the undirected honest edge {i, j} as two directed columns.

        Column ``i * nb_honest + j`` gets +1 on row i and -1 on row j, and column
        ``j * nb_honest + i`` gets the opposite signs. Writes are assignments, so adding the
        same edge twice is harmless. Avoid ``i == j``: the four writes hit one cell, which
        ends at +1, i.e. a self-loop.
        """
        self.interaction_matrix_honest[i, i * self.nb_honest + j] = 1
        self.interaction_matrix_honest[j, i * self.nb_honest + j] = -1

        self.interaction_matrix_honest[i, j * self.nb_honest + i] = -1
        self.interaction_matrix_honest[j, j * self.nb_honest + i] = 1

    def delta_infinite(self):
        """Return the cached value computed by ``_compute_delta_infinite`` (computed on first use)."""
        if self._delta_infinite is None:
            self._compute_delta_infinite()
        return self._delta_infinite

    def laplacian(self):
        """
        Return (and cache) ``adjacency_matrix @ interaction_matrix.T``.

        With the encoding above, and without self-loops, this is the graph Laplacian of all
        nodes: degrees on the diagonal, -1 for every neighbor pair.
        """
        # ``is None``: ``== None`` on a cached ndarray is elementwise and cannot be used in ``if``.
        if self.laplacian_matrix is None:
            self.laplacian_matrix = self.adjacency_matrix @ self.interaction_matrix.T
        return self.laplacian_matrix

    def laplacian_honest(self):
        """Return (and cache) the Laplacian restricted to honest nodes."""
        if self.laplacian_matrix_honest is None:
            self.laplacian_matrix_honest = self.adjacency_matrix_honest @ self.interaction_matrix_honest.T
        return self.laplacian_matrix_honest

    def _compute_delta_infinite(self):
        """
        Compute ``|| B_h^T (B_h B_h^T)^+ B_byz ||_inf``.

        Here ``B_h`` is ``interaction_matrix_honest`` and ``B_byz`` is the honest-rows /
        non-honest-edge-columns block of ``interaction_matrix``. ``B_byz`` is also stored as
        ``self.interaction_byz``.
        """
        self.interaction_byz = self.interaction_matrix[:self.nb_honest, self.nb_honest ** 2:]

        matrix_delta = self.interaction_matrix_honest.T @ np.linalg.pinv(
            self.interaction_matrix_honest @ self.interaction_matrix_honest.T) @ self.interaction_byz

        # np.infty was removed in NumPy 2.0; np.inf works on every version.
        self._delta_infinite = np.linalg.norm(matrix_delta, ord=np.inf)

    def nb_honest_edges(self):
        """Number of directed honest edges (each undirected edge is counted twice)."""
        return np.sum(self.adjacency_matrix_honest)

    def nb_honest_neighbors_min(self):
        """Smallest number of honest neighbors over honest nodes, capped at ``nb_honest``."""
        # np.asarray(...).ravel() also accepts the np.matrix returned by sparse ``sum``.
        degrees = np.asarray(self.adjacency_matrix_honest.sum(axis=1)).ravel()
        return np.min(degrees, initial=self.nb_honest)

    def nb_byzantine_edges(self):
        """Number of Byzantine nodes, which equals the number of honest-Byzantine links."""
        return self.nb_byz

    def auto_step_size(self):
        """Return ``1 / ||L||_2``, where ``L`` is the full Laplacian."""
        return 1 / np.linalg.norm(self.laplacian(), ord=2)

    def spectral_gap(self):
        """Ratio of the second-smallest to the largest eigenvalue of the honest Laplacian."""
        eig = np.linalg.eigvalsh(self.laplacian_honest())
        return eig[1] / eig[-1]

    def _sparsify(self):
        """
        Store all matrices as CSR sparse matrices.

        This might break some code paths, such as slicing, ``np.maximum`` or
        ``np.linalg.pinv``. Matrices that are still ``None`` (e.g. a Laplacian not computed
        yet) are left untouched.
        """
        from scipy import sparse

        for attr in ("interaction_matrix", "interaction_matrix_honest",
                     "adjacency_matrix", "adjacency_matrix_honest",
                     "laplacian_matrix", "laplacian_matrix_honest"):
            value = getattr(self, attr)
            if value is not None:
                setattr(self, attr, sparse.csr_matrix(value))


class Clique(Topology):
    """Complete graph: every pair of distinct honest nodes is linked."""

    name = "Clique"

    def __init__(self, nb_honest):
        super().__init__(nb_honest)
        size = self.nb_honest
        self.interaction_matrix_honest = np.zeros((size, size ** 2))

        # Each unordered pair {first, second} is registered once; add_honest_edge writes
        # both directed columns, so the encoded graph is undirected.
        for first in range(size):
            for second in range(first + 1, size):
                self.add_honest_edge(first, second)

        honest = self.interaction_matrix_honest
        self.interaction_matrix = honest
        self.adjacency_matrix_honest = np.maximum(0, honest)
        self.adjacency_matrix = self.adjacency_matrix_honest

class Dumbell(Topology):
    """
    Topology consisting of two fully-connected graphs (cliques) with the same number of nodes.

    Nodes ``0 .. nb_honest//2 - 1`` form the first clique and the remaining nodes form the
    second one. Each node ``i`` of the first clique is connected to the second-clique nodes
    ``(i + k) % (nb_honest//2) + nb_honest//2`` for ``k = 0 .. nb_cross_edges - 1``, i.e. in
    a circulant manner. Offsets that wrap onto the same node are merged. An odd
    ``nb_honest`` is rounded up to the next even number.

    Args:
        nb_honest: number of honest nodes (rounded up to an even number).
        nb_cross_edges: number of circulant links from each first-clique node to the
            second clique (default 3, the historical value).
    """

    name = "Dumbell"

    def __init__(self, nb_honest, nb_cross_edges=3):
        # Both cliques must have the same size: odd sizes are rounded up.
        nb_honest = nb_honest + nb_honest % 2
        super().__init__(nb_honest)
        half = self.nb_honest // 2
        self.interaction_matrix_honest = np.zeros((self.nb_honest, self.nb_honest ** 2))

        # We encode here all directed edges. Note that the graph is actually undirected
        for start in (0, half):
            for i in range(start, start + half - 1):
                for j in range(i + 1, start + half):
                    self.add_honest_edge(i, j)

        for k in range(nb_cross_edges):
            for i in range(half):
                self.add_honest_edge(i, (i + k) % half + half)

        self.interaction_matrix = self.interaction_matrix_honest
        self.adjacency_matrix_honest = np.maximum(0, self.interaction_matrix)
        self.adjacency_matrix = self.adjacency_matrix_honest



class UltraTorus(Topology):
    """
    Circulant ("torus") graph where each node is linked to several others.

    Node ``i`` is connected to ``(i +/- d) % nb_honest`` for every odd distance
    ``d = 1, 3, ..., 2 * connective_range - 1``. It is also connected to
    ``(i + k * nb_honest // opposed_neighbors) % nb_honest`` for ``k = 1 .. opposed_neighbors``.

    Offsets that wrap back onto ``i`` itself are skipped. ``add_honest_edge(i, i)`` would
    leave a self-loop that adds a spurious +1 on the Laplacian diagonal.
    """

    name = "Torus with mutliple neighbors"

    def __init__(self, nb_honest, connective_range=1, opposed_neighbors=0):
        super().__init__(nb_honest)
        n = self.nb_honest
        self.interaction_matrix_honest = np.zeros((n, n ** 2))

        def link(i, j):
            # Skip self-loops: they would corrupt the Laplacian (see class docstring).
            if i != j:
                self.add_honest_edge(i, j)

        for i in range(n):
            for dist in range(connective_range):
                d = dist * 2 + 1  # only odd distances are linked
                link(i, (i + d) % n)
                link(i, (i - d) % n)

        for i in range(n):
            for k in range(1, opposed_neighbors + 1):
                # NOTE(review): for k == opposed_neighbors the offset equals nb_honest, i.e. i
                # itself, so only opposed_neighbors - 1 distinct opposed nodes get linked.
                # Confirm this is the intended spacing.
                link(i, (i + k * n // opposed_neighbors) % n)

        self.interaction_matrix = self.interaction_matrix_honest
        self.adjacency_matrix_honest = np.maximum(0, self.interaction_matrix)
        self.adjacency_matrix = self.adjacency_matrix_honest



class CircularWithPrincipal(Topology):
    """
    A circular graph with one node in the middle communicating with everyone.

    Nodes ``0 .. nb_honest - 2`` form a ring, with node ``i`` linked to
    ``(i + 1) % (nb_honest - 1)``. The last node ``nb_honest - 1`` is the principal node and
    is linked to every ring node.
    """

    name = "Circular with Principal graph"

    def __init__(self, nb_honest):
        super().__init__(nb_honest)
        n = self.nb_honest
        center = n - 1
        self.interaction_matrix_honest = np.zeros((n, n ** 2))

        for i in range(center):
            successor = (i + 1) % center
            # With a single ring node (nb_honest == 2) the successor is the node itself;
            # skip it instead of writing a self-loop into the incidence matrix.
            if successor != i:
                self.add_honest_edge(i, successor)
            self.add_honest_edge(i, center)

        self.interaction_matrix = self.interaction_matrix_honest
        self.adjacency_matrix_honest = np.maximum(0, self.interaction_matrix)
        self.adjacency_matrix = self.adjacency_matrix_honest


class TorusWithPrincipal(Topology):
    """
    Ring of ``nb_honest - 1`` nodes linked to their first and second neighbors on each side,
    plus one principal node (the last one) communicating with everyone.

    Ring node ``i`` is linked to ``(i + 1) % (nb_honest - 1)`` (outer circle) and to
    ``(i + 2) % (nb_honest - 1)`` (inner circle), and also to node ``nb_honest - 1``. Offsets
    that wrap onto ``i`` itself (tiny rings) are skipped to avoid self-loops.
    """

    name = "Torus with Principal graph"

    def __init__(self, nb_honest):
        super().__init__(nb_honest)
        n = self.nb_honest
        center = n - 1
        self.interaction_matrix_honest = np.zeros((n, n ** 2))

        for i in range(center):
            self.add_honest_edge(i, center)
            for step in (1, 2):  # step 1: outer circle, step 2: inner circle
                j = (i + step) % center
                if j != i:
                    self.add_honest_edge(i, j)

        self.interaction_matrix = self.interaction_matrix_honest
        self.adjacency_matrix_honest = np.maximum(0, self.interaction_matrix)
        self.adjacency_matrix = self.adjacency_matrix_honest


class ErdosRenyi(Topology):
    """
    Erdos Renyi graph, with probability p to have any edge in the graph.

    Each unordered pair of honest nodes is linked independently with probability ``p``.
    Draws come from a private ``np.random.RandomState(seed)``. This yields the same graph as
    seeding NumPy's legacy global generator with ``seed``, without resetting the global
    random state as a side effect.
    """

    name = "ErdosRenyi"

    def __init__(self, nb_honest, p, seed=0):
        super().__init__(nb_honest)
        self.interaction_matrix_honest = np.zeros((self.nb_honest, self.nb_honest ** 2))
        rng = np.random.RandomState(seed)
        for i in range(self.nb_honest - 1):
            for j in range(i + 1, self.nb_honest):
                if rng.random_sample() < p:
                    self.add_honest_edge(i, j)

        self.interaction_matrix = self.interaction_matrix_honest
        self.adjacency_matrix_honest = np.maximum(0, self.interaction_matrix)
        self.adjacency_matrix = self.adjacency_matrix_honest




class networkxTopology(Topology):
    """
    Allows to import any topology using networkx library.

    If the nodes are labelled exactly ``0 .. n - 1``, each label is used as the node index.
    Any other labelling is mapped to indices following the iteration order of
    ``graph.nodes``. Self-loops are ignored, because they would add a spurious +1 on the
    Laplacian diagonal.
    """

    name = "Networks"

    def __init__(self, graph):
        super().__init__(graph.order())
        self.graph_honest = graph

        labels = list(graph.nodes)
        if set(labels) == set(range(self.nb_honest)):
            index = {label: label for label in labels}
        else:
            index = {label: position for position, label in enumerate(labels)}

        self.interaction_matrix_honest = np.zeros((self.nb_honest, self.nb_honest ** 2))
        for edge in self.graph_honest.edges:
            # edge[0], edge[1] also works for multigraph (u, v, key) tuples.
            i, j = index[edge[0]], index[edge[1]]
            if i != j:
                self.add_honest_edge(i, j)

        # Like the other topologies, the full matrix starts as the honest one; laplacian()
        # would otherwise fail on a None interaction_matrix.
        self.interaction_matrix = self.interaction_matrix_honest
        self.adjacency_matrix_honest = np.maximum(0, self.interaction_matrix_honest)
        self.adjacency_matrix = self.adjacency_matrix_honest




class TwoWorlds(networkxTopology):
    """
    TwoWorlds graph: two cliques of size nb_honest//2, with k connections between each node
    of one clique and nodes of the other clique (in a circulant manner).

    Node ``i`` of the first clique is linked to ``(i + offset) % (nb_honest//2) + nb_honest//2``
    for ``offset = 0 .. k - 1``.

    Raises:
        ValueError: if ``nb_honest`` is odd or ``k > nb_honest // 2``.
    """

    name = "TwoWorlds"

    def __init__(self, nb_honest, k):
        if nb_honest % 2 != 0:
            raise ValueError("nb_honest must be even for TwoWorlds topology")
        elif k > nb_honest // 2:
            raise ValueError(f" Two worlds requires k <= nb_honest//2, here k:{k} and nb_honest:{nb_honest}")

        half = nb_honest // 2
        c1 = nx.complete_graph(half)
        c2 = nx.complete_graph(nb_honest - half)
        c2 = nx.relabel_nodes(c2, {i: i + half for i in range(nb_honest - half)}, copy=False)
        net = nx.union(c1, c2)

        # Use a distinct loop variable: rebinding ``k`` here shrank the number of
        # cross edges for every subsequent node.
        for i in range(half):
            for offset in range(int(k)):
                net.add_edge(i, (i + offset) % half + half)

        super().__init__(net)
