import numpy as np

# Graph class
class graph:
    '''Labelled graph with pre-computed shortest-path statistics.

    Attributes:
        A, node_attr, Ln, edge_attr, Le, y: the constructor arguments, stored as given.
        N (int): number of nodes, i.e. len(A).
        E: sum of the strict upper triangle of A (the edge count when A is a
            symmetric 0/1 matrix).
        S (list or None): S[l] is the number of nodes carrying label l;
            None when node_attr is None.
        dis (list): NxN matrix of hop distances; the value N marks an unreachable pair.
        D (dict): D[d] is the number of ordered node pairs (u, v) at distance d,
            for d in 0, 1, ..., N-1 (pairs with u == v are counted at distance 0).
        P (dict or None): P[(d, l1, l2)] is the number of ordered pairs (u, v) at
            distance d whose labels are l1 and l2; None when node_attr is None.
    '''

    def __init__(self, A, node_attr, Ln, edge_attr, Le, y=None):
        '''General graph class
        Args:
            A (list): adjacency matrix
            node_attr (list or None): a list of int representing positional vector indicates label of each node (shape: 1xN)
            Ln (int or None): number of node labels
            edge_attr (dict or None): a dict indicates label of each edge ((u, v): label);
                it is filled IN PLACE with None for every (u, v) pair it does not contain
            Le (int or None): number of edge labels
            y: (float or None): objective value of given graph
        '''
        self.A = A
        self.node_attr = node_attr
        self.Ln = Ln
        self.edge_attr = edge_attr
        self.Le = Le
        self.y = y

        # fixed params calculated from graph
        self.N = len(A)    # node number
        self.E = np.sum(np.triu(np.array(A), k=1))    # strict upper triangle only

        # Always define S so that unlabeled graphs expose None instead of
        # lacking the attribute altogether.
        self.S = None
        if self.node_attr is not None:
            # node label counts (shape 1xLn)
            self.S = [self.node_attr.count(l) for l in range(self.Ln)]
        if self.edge_attr is not None:
            # make every (u, v) pair a valid key; pairs without a label map to None
            for u in range(self.N):
                for v in range(self.N):
                    self.edge_attr.setdefault((u, v), None)

        self.calculate_path_stats()

    def calculate_path_stats(self):
        '''Compute all-pairs hop distances and the distance histograms.

        Sets self.dis, self.D and self.P from self.A, self.N, self.node_attr
        and self.Ln. self.P is None when the graph has no node labels.
        '''
        n = self.N
        labelled = self.node_attr is not None

        # Initial distances: 1 for an edge, 0 on the diagonal, n otherwise.
        # A simple path has at most n-1 edges, so n acts as "unreachable".
        dis = []
        for i in range(n):
            row = [1 if self.A[i][j] else n for j in range(n)]
            row[i] = 0
            dis.append(row)
        self.dis = dis    # NxN matrix consisting distances between nodes

        # Floyd-Warshall relaxation over every intermediate node k.
        for k in range(n):
            row_k = dis[k]
            for row_i in dis:
                d_ik = row_i[k]
                if d_ik >= n:
                    # k is unreachable from i, so no path through k can be shorter
                    continue
                for j in range(n):
                    via_k = d_ik + row_k[j]
                    if via_k < row_i[j]:
                        row_i[j] = via_k

        # distance counts dict, keys: 0, 1, ..., N-1
        self.D = {d: 0 for d in range(n)}
        # label matched distance counts dict, keys: (dis, l1, l2)
        self.P = None
        if labelled:
            self.P = {
                (d, l1, l2): 0
                for d in range(n)
                for l1 in range(self.Ln)
                for l2 in range(self.Ln)
            }

        for u in range(n):
            for v in range(n):
                d = dis[u][v]
                if d < n:    # skip unreachable pairs
                    self.D[d] += 1
                    if labelled:
                        self.P[(d, self.node_attr[u], self.node_attr[v])] += 1