import torch
from dynamic_graph import *

class Torus(DynamicGraph):
    """Static 2-D torus topology on ``p * q`` nodes.

    Nodes are numbered row-major on a ``p``-by-``q`` grid: node ``r * q + c``
    lies in row ``r`` and column ``c`` (see ``split_nodes``). Every row and
    every column is closed into a cycle, so each node is linked to its
    left/right and up/down neighbours with wrap-around.

    The weight matrix ``w`` gives every node a self-weight of 1/5 and a weight
    of 1/5 to each of its four neighbours. For ``p, q > 2`` those four
    neighbours are distinct, so ``w`` is symmetric and each of its rows and
    columns sums to 1.
    """

    def __init__(self, p: int, q: int):
        """Build the torus weight matrix and pass it to ``DynamicGraph``.

        Args:
            p: Number of grid rows; must be greater than 2.
            q: Number of grid columns; must be greater than 2.

        Raises:
            ValueError: If ``p <= 2`` or ``q <= 2``. With a side of length 2
                or less, a node's two neighbours along that side coincide (or
                are the node itself), so the uniform 1/5 weights would not sum
                to 1 and the resulting graph would be silently wrong.
        """
        if p <= 2 or q <= 2:
            raise ValueError(f"Torus requires p > 2 and q > 2, got p={p}, q={q}")

        self.n_nodes = p * q
        self.p = p
        self.q = q

        # Four distinct neighbours plus the self loop share the mass uniformly.
        weight = 1 / 5

        w = torch.zeros((self.n_nodes, self.n_nodes))
        # Every node belongs to some row cycle, so every node gets a self loop.
        w.fill_diagonal_(weight)
        # Wire the horizontal rings (rows), then the vertical rings (columns).
        for cycles in (self.split_nodes(), self.split_nodes2()):
            self._connect_cycles(w, cycles, weight)

        # NOTE(review): DynamicGraph presumably accepts a sequence of weight
        # matrices; the torus is static, so a single matrix is supplied.
        super().__init__([w])

    @staticmethod
    def _connect_cycles(w, cycles, weight):
        """Symmetrically link consecutive nodes of each cycle, wrapping around.

        Args:
            w: Square weight tensor, modified in place.
            cycles: Iterable of node-index lists; each list is one ring.
            weight: Value written to both ``w[a, b]`` and ``w[b, a]``.
        """
        for cycle in cycles:
            # Pair each node with its successor; the last pairs with the first.
            for a, b in zip(cycle, cycle[1:] + cycle[:1]):
                w[a, b] = weight
                w[b, a] = weight

    def split_nodes(self):
        """Return the grid rows: ``p`` lists of ``q`` consecutive node ids."""
        return [list(range(r * self.q, (r + 1) * self.q)) for r in range(self.p)]

    def split_nodes2(self):
        """Return the grid columns: ``q`` lists of ``p`` node ids, top to bottom.

        Column ``c`` holds element ``c`` of every row from ``split_nodes``.
        """
        rows = self.split_nodes()
        return [[row[c] for row in rows] for c in range(self.q)]
        
