#! -*- coding: utf-8
import numpy as np
import torch
import sympy

from .dynamic_graph import *

__all__ = ["Torus"]


class Torus(DynamicGraph):
    """Static communication graph laid out as a 2-D torus (wrap-around grid).

    ``n_nodes`` is factored as ``width * height``, where ``width`` is the
    middle entry of the ascending divisor list of ``n_nodes``. That is the
    smallest divisor that is >= sqrt(n_nodes), so ``height <= width``.
    Node ``i`` sits at column ``i % width`` and row ``i // width``.

    Each node is linked to itself and to its left/right neighbours in the same
    row and its neighbours in the next/previous row, with indices wrapping
    around both axes. Row ``i`` of the mixing matrix spreads weight uniformly
    over the *distinct* members of that neighbourhood. On small grids (a side
    of length 1 or 2) some neighbours coincide with each other or with the
    node itself and are counted once. Every row therefore sums to 1.

    A single mixing matrix is built and handed to ``DynamicGraph`` as a
    one-element list. ``self.width`` and ``self.height`` record the grid shape.

    Args:
        n_nodes: Number of nodes; must be >= 1.
        K: Accepted for signature compatibility but not used by this class.
            NOTE(review): confirm whether any caller expects it to matter.
        penalty, nrepeat, seed: Forwarded unchanged to
            ``DynamicGraph.__init__``; their meaning is defined there.

    Raises:
        ValueError: If ``n_nodes`` is less than 1.
    """

    def __init__(self, n_nodes: int, K: int = 0,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        width, height = self._grid_shape(n_nodes)
        self.width, self.height = width, height

        w_list = [self._mixing_matrix(width, height)]
        super().__init__(w_list,
                         penalty=penalty, nrepeat=nrepeat, seed=seed)

    @staticmethod
    def _grid_shape(n_nodes):
        """Return ``(width, height)`` ints with ``width * height == n_nodes``.

        ``width`` is the median of ``sympy.divisors(n_nodes)`` (ascending
        list), i.e. the smallest divisor that is >= sqrt(n_nodes).

        Raises:
            ValueError: If ``n_nodes`` is less than 1. Without this check,
                ``sympy.divisors(0)`` is empty and indexing it raises an
                unhelpful ``IndexError``.
        """
        if n_nodes < 1:
            raise ValueError(
                f"n_nodes must be a positive integer, got {n_nodes!r}")
        divisors = sympy.divisors(n_nodes)
        width = divisors[len(divisors) // 2]
        return width, n_nodes // width

    @staticmethod
    def _mixing_matrix(width, height):
        """Return the ``(n, n)`` row-normalised torus matrix, ``n = width*height``."""
        n_nodes = width * height
        with torch.no_grad():
            node = torch.arange(n_nodes)
            col, row = node % width, node // width
            # Per node, the column indices of: itself, right, left (same row),
            # same column in the next row, same column in the previous row.
            # Tensor `%` follows Python semantics, so (col - 1) % width wraps
            # 0 -> width - 1 just like the scalar arithmetic would.
            neighbours = torch.stack([
                node,
                (col + 1) % width + row * width,
                (col - 1) % width + row * width,
                col + ((row + 1) % height) * width,
                col + ((row - 1) % height) * width,
            ], dim=1)
            w = torch.zeros(n_nodes, n_nodes)
            # Duplicate indices (small grids) all write the same 1.0, so each
            # distinct neighbour is counted exactly once.
            w.scatter_(1, neighbours, 1.0)
            w = w / w.sum(dim=1, keepdim=True)
        return w
