#! -*- coding: utf-8
import numpy as np
import torch
import sympy

from .dynamic_graph import *

__all__ = ["Grid"]


class Grid(DynamicGraph):
    """Static 2-D grid (4-neighbour lattice) communication graph.

    The ``n_nodes`` nodes are laid out row-major on a ``width x height``
    grid, where ``width`` is the smallest divisor of ``n_nodes`` that is
    >= sqrt(n_nodes) (so ``width >= height``; a prime ``n_nodes`` yields a
    single-row path). Node ``n`` sits at column ``n % width`` and row
    ``n // width`` and is linked to its left/right/up/down neighbours.

    The single weight matrix handed to ``DynamicGraph`` gives every node
    equal weight over itself and its neighbours, so each row sums to 1.
    Because boundary and corner nodes have fewer neighbours than interior
    nodes, the matrix is row-stochastic but in general neither symmetric
    nor column-stochastic.
    """

    def __init__(self, n_nodes: int, K: int = 0,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        """Build the grid weight matrix and initialise the base graph.

        Args:
            n_nodes: Number of nodes; must be >= 1.
            K: Accepted but not used by this class (presumably kept for
                signature compatibility with sibling graph classes --
                confirm).
            penalty: Forwarded unchanged to ``DynamicGraph.__init__``.
            nrepeat: Forwarded unchanged to ``DynamicGraph.__init__``.
            seed: Forwarded unchanged to ``DynamicGraph.__init__``.

        Raises:
            ValueError: If ``n_nodes`` < 1 (previously this surfaced as an
                opaque IndexError / RuntimeError from deeper calls).
        """
        if n_nodes < 1:
            raise ValueError(
                f"n_nodes must be a positive integer, got {n_nodes!r}")

        divisors = sympy.divisors(n_nodes)
        # Middle divisor: smallest divisor >= sqrt(n_nodes), giving the most
        # nearly square factorisation n_nodes == width * height.
        width = divisors[len(divisors) // 2]
        height = n_nodes // width
        self.width, self.height = width, height

        nodes = torch.arange(n_nodes)
        col, row = nodes % width, nodes // width

        # Every node keeps a self-loop.
        adj = torch.eye(n_nodes)
        # Horizontal links: node n <-> n + 1 unless n is in the last column.
        left = nodes[col < width - 1]
        adj[left, left + 1] = 1.0
        adj[left + 1, left] = 1.0
        # Vertical links: node n <-> n + width unless n is in the last row.
        top = nodes[row < height - 1]
        adj[top, top + width] = 1.0
        adj[top + width, top] = 1.0

        # Uniform weight over each node's closed neighbourhood, so every row
        # sums to 1.
        weights = adj / adj.sum(dim=1, keepdim=True)

        # NOTE(review): DynamicGraph receives a list of weight matrices; a
        # one-element list presumably means the topology never changes --
        # confirm against DynamicGraph.
        super().__init__([weights],
                         penalty=penalty, nrepeat=nrepeat, seed=seed)
