# -*- coding: utf-8 -*-
import torch

from .dynamic_graph import *

# Names exported by a wildcard import of this module
# (`from <this module> import *`).
__all__ = ["Ring", "Ring2"]


class Ring(DynamicGraph):
    """Graph with a single, fixed ring-shaped weight matrix.

    For ``n_nodes >= 3`` every node ``i`` gets equal weight on itself and on
    its two ring neighbours ``(i + 1) % n_nodes`` and ``(i - 1) % n_nodes``;
    each row is then divided by its sum. For fewer than three nodes the
    matrix is the uniform one, ``ones((n_nodes, n_nodes)) / n_nodes``.

    Args:
        n_nodes: Number of nodes; the weight matrix is
            ``(n_nodes, n_nodes)``.
        penalty: Forwarded unchanged to ``DynamicGraph.__init__``.
        nrepeat: Forwarded unchanged to ``DynamicGraph.__init__``.
        seed: Forwarded unchanged to ``DynamicGraph.__init__``.
    """

    def __init__(self, n_nodes,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        if n_nodes < 3:
            # Too few nodes for three distinct ring entries per row:
            # use full connection with uniform weights instead.
            mixing = torch.ones((n_nodes, n_nodes)) / n_nodes
        else:
            mixing = torch.zeros((n_nodes, n_nodes))
            for node in range(n_nodes):
                # Offsets: self, next neighbour, previous neighbour
                # (indices wrap around the ring).
                for offset in (0, 1, -1):
                    mixing[node, (node + offset) % n_nodes] = 1/3
            # Divide each row by its sum.
            row_totals = mixing.sum(dim=-1, keepdim=True)
            mixing = mixing / row_totals

        # The parent class receives the matrix wrapped in a one-element list.
        super().__init__([mixing],
                         penalty=penalty, nrepeat=nrepeat, seed=seed)


class Ring2(DynamicGraph):
    """Ring-shaped weight matrix built with tensor shifts instead of a loop.

    The matrix is ``I + roll(I, +1) + roll(I, -1)`` (column shifts with
    wrap-around), with every row divided by its sum. For ``n_nodes >= 3``
    each row therefore holds equal weight on the node itself and on its two
    ring neighbours.

    Note on small sizes: for ``n_nodes == 2`` both shifts land on the same
    (only) neighbour, so the rows become ``[1/3, 2/3]`` and ``[2/3, 1/3]``;
    for ``n_nodes == 1`` the result is ``[[1.0]]``. No special-casing is
    applied for these sizes.

    Args:
        n_nodes: Number of nodes; the weight matrix is
            ``(n_nodes, n_nodes)``.
        penalty: Forwarded unchanged to ``DynamicGraph.__init__``.
        nrepeat: Forwarded unchanged to ``DynamicGraph.__init__``.
        seed: Forwarded unchanged to ``DynamicGraph.__init__``.
    """

    def __init__(self, n_nodes,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        identity = torch.eye(n_nodes)

        # Shift the identity's columns by one position in each direction,
        # wrapping around. The previous code concatenated column slices
        # without `dim=1`, i.e. along rows, which raised a shape error for
        # any positive n_nodes.
        right = torch.roll(identity, shifts=1, dims=1)
        left = torch.roll(identity, shifts=-1, dims=1)

        w = identity + right + left
        # Divide each row by its sum.
        w = w / w.sum(dim=-1, keepdim=True)

        # The parent class receives the matrix wrapped in a one-element list.
        super().__init__([w],
                         penalty=penalty, nrepeat=nrepeat, seed=seed)
