#! -*- coding: utf-8
import math

import torch

from .dynamic_graph import *

# Public API of this module: only the graph class is exported by ``import *``.
__all__ = ["ExponentialGraph"]


class ExponentialGraph(DynamicGraph):
    """Static exponential communication graph wrapped as a ``DynamicGraph``.

    Node ``i`` keeps a self-loop and has an edge to every node
    ``(i + 2**j) % n_nodes`` for each power of two ``2**j < n_nodes``.
    All of these entries share the weight ``1 / (ceil(log2(n_nodes)) + 1)``.
    The offsets ``2**j`` are distinct and non-zero modulo ``n_nodes``, so
    every row and every column of the single mixing matrix sums to 1.

    Args:
        n_nodes: Number of nodes; must be at least 1.
        penalty: Forwarded unchanged to ``DynamicGraph``.
        nrepeat: Forwarded unchanged to ``DynamicGraph``.
        seed: Forwarded unchanged to ``DynamicGraph``.

    Raises:
        ValueError: If ``n_nodes`` is smaller than 1.
    """

    def __init__(self, n_nodes: int,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        if n_nodes < 1:
            raise ValueError(f"n_nodes must be >= 1, got {n_nodes}")

        # (n_nodes - 1).bit_length() equals ceil(log2(n_nodes)), computed
        # exactly in integer arithmetic (no float rounding, and no domain
        # error for n_nodes == 1). It is also the number of hops 2**j that
        # satisfy 2**j < n_nodes, i.e. j = 0 .. floor(log2(n_nodes - 1)).
        n_hops = (n_nodes - 1).bit_length()
        # Each row holds the self-loop plus n_hops neighbours, all with the
        # same weight, so this weight makes every row sum to 1.
        weight = 1 / (n_hops + 1)

        w = torch.zeros((n_nodes, n_nodes))
        nodes = torch.arange(n_nodes)
        w[nodes, nodes] = weight
        for j in range(n_hops):
            # One vectorized assignment per hop instead of one per node.
            w[nodes, (nodes + 2 ** j) % n_nodes] = weight

        super().__init__([w],
                         penalty=penalty, nrepeat=nrepeat, seed=seed)
