import numpy as np
import networkx as nx
import torch

class UndirectedTopology:
    """Undirected communication graph over ``n_clients`` nodes.

    Nodes are the integers ``0 .. n_clients - 1``. Edges live in
    ``adj_list``, a list holding one ``set`` of neighbour ids per node.
    The insertion helpers record both directions of every edge, so
    ``j in adj_list[i]`` iff ``i in adj_list[j]`` as long as the graph is
    only modified through them.
    """

    def __init__(self, n_clients: int = None):
        """Create an empty topology.

        Args:
            n_clients: Number of nodes. When ``None``, ``n_clients`` and
                ``adj_list`` are left unset; presumably a subclass assigns
                them itself (TODO confirm). Every other method needs both.
        """
        # When True, mixing_matrix() recomputes on every call instead of
        # reusing the cached matrix in ``self.W``.
        self.dynamic = False
        # Cached mixing matrix; reset to None whenever an edge is inserted.
        self.W = None
        if n_clients is not None:
            self.n_clients = n_clients
            self.adj_list = [set() for _ in range(self.n_clients)]

    def get_all_nodes(self):
        """
        Returns the list of node ids ``[0, 1, ..., n_clients - 1]``.
        """
        return list(range(self.n_clients))

    def get_neighbours(self, node: int):
        """
        Returns the neighbours of a given node.

        This is the internal set itself (not a copy); mutating it bypasses
        the mixing-matrix cache invalidation done by the insertion helpers.
        """
        return self.adj_list[node]

    def get_degree(self, node: int):
        """
        Returns the degree of a given node (size of its neighbour set).
        """
        return len(self.adj_list[node])

    def __insert_adj__(self, node: int, neighbours: list):
        '''
        Inserts `neighbours` into the adjacency list of `node`.

        The reverse direction is recorded as well, so the graph stays
        symmetric even when the input lists each edge on only one side.

        Raises:
            IndexError: if `node` or a neighbour id is outside `adj_list`.
        '''
        neighbours = list(neighbours)  # allow one-shot iterables such as map()
        self.adj_list[node].update(neighbours)
        for nb in neighbours:
            self.adj_list[nb].add(node)
        self.W = None  # graph changed: drop the cached mixing matrix

    def  __insert_edge__(self, x: int, y: int):
        '''
        Inserts the undirected edge `x -- y` (records both `x -> y` and
        `y -> x`).
        '''
        self.adj_list[x].add(y)
        self.adj_list[y].add(x)
        self.W = None  # graph changed: drop the cached mixing matrix

    def write_graph_to_file(self, file: str, type: str = "edges"):
        '''
        Writes the graph to `file`, overwriting it.

        Formats (selected by `type`):
          * "edges": one "i j" line per directed pair, so every undirected
            edge appears twice (as "i j" and as "j i").
          * "adjacency": line i holds the space-separated neighbours of
            node i (an empty line for an isolated node).
        Neighbours are written in ascending order so output is reproducible.

        Raises:
            ValueError: if `type` is not one of the formats above. This is
                checked before opening, so an existing file is not truncated.
        '''
        if type not in ("edges", "adjacency"):
            raise ValueError(f"Unknown graph file type: {type!r}")
        with open(file, "w") as f:
            for i in range(self.n_clients):
                ordered = sorted(self.adj_list[i])
                if type == "edges":
                    f.writelines(f"{i} {j}\n" for j in ordered)
                else:
                    f.write(" ".join(map(str, ordered)) + "\n")

    def read_graph_from_file(self, file: str, type: str = "edges", force_connect: bool = False):
        '''
        Reads edges from `file` (formats as in `write_graph_to_file`) and
        adds them to the graph; edges already present are kept.

        Blank lines in an "edges" file are ignored. `force_connect` is
        accepted for API compatibility but currently has no effect.

        Raises:
            ValueError: if `type` is unknown, or a non-blank "edges" line
                does not hold exactly two integers.
        '''
        if type not in ("edges", "adjacency"):
            raise ValueError(f"Unknown graph file type: {type!r}")
        with open(file, "r") as f:
            if type == "edges":
                for line in f:
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank / trailing lines
                    x, y = map(int, fields)
                    self.__insert_edge__(x, y)
            else:
                for i, line in enumerate(f):
                    self.__insert_adj__(i, map(int, line.split()))

    def to_networkx(self):
        '''
        Converts the graph to a networkx graph.

        All nodes 0..n_clients-1 are added first, so isolated nodes are
        kept and the node order of the result is 0..n_clients-1.
        '''
        G = nx.Graph()
        G.add_nodes_from(range(self.n_clients))
        G.add_edges_from(
            (i, j) for i in range(self.n_clients) for j in self.adj_list[i]
        )
        return G

    def __create_topology(self):
        # NOTE(review): the leading double underscore triggers name mangling,
        # so this is stored as `_UndirectedTopology__create_topology`. A
        # subclass defining its own `__create_topology` does NOT override it.
        raise NotImplementedError("This method should be overridden by subclasses")

    def mixing_matrix(self)->torch.Tensor:
        '''
        Returns the mixing matrix of the graph, using Metropolis-Hastings
        weights.

        Entry (i, j) for an edge i--j (i != j) is
        ``1 / (1 + max(deg(i), deg(j)))``; each diagonal entry is 1 minus
        the rest of its row, so every row sums to 1. Self-loops in the
        adjacency are not given an edge weight; they only count towards
        the degree.

        Unless `self.dynamic` is True, the result is cached in `self.W` and
        returned again on later calls until an edge is inserted. The cached
        tensor is shared, so do not modify it in place.
        '''
        if self.W is not None and not self.dynamic:
            return self.W
        W = np.zeros((self.n_clients, self.n_clients))
        for i in range(self.n_clients):
            deg_i = self.get_degree(i)
            for j in self.adj_list[i]:
                if j == i:
                    continue  # a self-loop must not leak into the 1 - rowsum diagonal
                W[i, j] = 1.0 / (1 + max(deg_i, self.get_degree(j)))
            W[i, i] = 1.0 - np.sum(W[i, :])

        result = torch.tensor(W, dtype=torch.float32)
        if not self.dynamic:
            self.W = result
        return result
    