from model.shiftgcnutils import tools


# Graph topology used by HumanAct12Graph: node count plus directed edge lists.
n_node = 24

# One (i, i) self-loop per node.
self_link = list(zip(range(n_node), range(n_node)))

# Directed (source, target) edges, grouped as chains of consecutive node indices.
inward = [
    # 0 -> 1 -> 4 -> 7 -> 10
    (0, 1), (1, 4), (4, 7), (7, 10),
    # 0 -> 2 -> 5 -> 8 -> 11
    (0, 2), (2, 5), (5, 8), (8, 11),
    # 0 -> 3 -> 6 -> 9 -> 12 -> 15
    (0, 3), (3, 6), (6, 9), (9, 12), (12, 15),
    # 9 -> 13 -> 16 -> 18 -> 20 -> 22
    (9, 13), (13, 16), (16, 18), (18, 20), (20, 22),
    # 9 -> 14 -> 17 -> 19 -> 21 -> 23
    (9, 14), (14, 17), (17, 19), (19, 21), (21, 23),
]

# The same edges with source and target swapped, in the same order.
outward = [(dst, src) for src, dst in inward]

# Every edge in both directions: all inward edges followed by all outward ones.
neighbor = [*inward, *outward]


class HumanAct12Graph:
    """Graph description built from this module's node count and edge lists.

    Attributes:
        A: Adjacency data returned by ``tools.get_spatial_graph`` for the
            requested labeling mode (presumably a stack of adjacency
            matrices -- TODO confirm against ``tools``).
        n_node: Number of nodes in the graph.
        self_link: ``(i, i)`` self-loop pairs, one per node.
        inward: Directed ``(source, target)`` edge pairs.
        outward: The ``inward`` pairs with source and target swapped.
        neighbor: ``inward`` followed by ``outward``.
    """

    def __init__(self, labeling_mode='spatial'):
        """Build the adjacency data and expose the module-level topology.

        Args:
            labeling_mode: Labeling strategy used to build ``A``. Only
                ``'spatial'`` is supported.

        Raises:
            ValueError: If ``labeling_mode`` is ``None`` or unsupported.
        """
        # ``None`` means "return the cached matrix" in get_adjacency_matrix,
        # but nothing is cached yet during construction; fail with a clear
        # message instead of an AttributeError on the missing ``self.A``.
        if labeling_mode is None:
            raise ValueError(
                "labeling_mode must not be None when constructing "
                "HumanAct12Graph; expected 'spatial'"
            )
        self.A = self.get_adjacency_matrix(labeling_mode)
        self.n_node = n_node
        self.self_link = self_link
        self.inward = inward
        self.outward = outward
        self.neighbor = neighbor

    def get_adjacency_matrix(self, labeling_mode=None):
        """Return adjacency data for the graph.

        Args:
            labeling_mode: ``None`` returns the value cached in ``self.A``
                at construction time; ``'spatial'`` builds a fresh value via
                ``tools.get_spatial_graph``.

        Returns:
            The cached ``self.A`` or a newly built adjacency value.

        Raises:
            ValueError: If ``labeling_mode`` is neither ``None`` nor
                ``'spatial'``.
        """
        if labeling_mode is None:
            return self.A
        if labeling_mode == 'spatial':
            return tools.get_spatial_graph(n_node, self_link, inward, outward)
        raise ValueError(
            "Unsupported labeling_mode {!r}; expected 'spatial'".format(
                labeling_mode
            )
        )


if __name__ == '__main__':
    # Manual check: show each slice of the adjacency data as a grayscale
    # image, one window at a time, then print the raw values.
    # NOTE(review): `os` is not used in the visible code; kept so that
    # running the script behaves exactly as before.
    import os
    import matplotlib.pyplot as plt

    graph = HumanAct12Graph('spatial')
    A = graph.get_adjacency_matrix()
    for layer in A:
        plt.imshow(layer, cmap='gray')
        plt.show()
    print(A)
