# -*- coding: utf-8 -*-
import math

import torch

from .dynamic_graph import *

__all__ = ["SparseExpGraph"]


class SparseExpGraph(DynamicGraph):
    """Sparse exponential graph described by a single weight matrix.

    Row ``i`` of the matrix puts the same weight on node ``i`` itself, on
    ``(i + 1) % n_nodes`` and on ``(i + 2**(n_neighbors - 2)) % n_nodes``,
    where ``n_neighbors = int(log2(n_nodes - 1)) + 1``.  The weight is
    ``1 / (number of distinct entries in the row)``, so every row sums to 1.
    For ``n_nodes >= 5`` the three entries are distinct and the weight is
    ``1/3``; for smaller graphs the two neighbour offsets coincide (or the
    exponent would be negative), so they are merged and the weight is
    adjusted accordingly.

    The matrix is handed to ``DynamicGraph`` as a one-element list.

    Args:
        n_nodes: Number of nodes in the graph; must be an integer >= 1.
        penalty: Forwarded unchanged to ``DynamicGraph.__init__``.
        nrepeat: Forwarded unchanged to ``DynamicGraph.__init__``.
        seed: Forwarded unchanged to ``DynamicGraph.__init__``.

    Raises:
        ValueError: If ``n_nodes`` is smaller than 1.
    """

    def __init__(self, n_nodes,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        if n_nodes < 1:
            raise ValueError(
                f"n_nodes must be a positive integer, got {n_nodes!r}"
            )

        # log2 is undefined for a single-node graph; it has no neighbours.
        n_neighbors = int(math.log2(n_nodes - 1)) + 1 if n_nodes > 1 else 0

        # Clamp the exponent so tiny graphs never yield a fractional offset
        # (2**-1), and collapse offsets that coincide or wrap onto the node
        # itself, so each matrix entry is counted exactly once.
        far_exponent = max(n_neighbors - 2, 0)
        offsets = {1 % n_nodes, 2 ** far_exponent % n_nodes} - {0}

        # Uniform weight over the node and its distinct neighbours.
        weight = 1.0 / (1 + len(offsets))

        w = torch.zeros((n_nodes, n_nodes))
        nodes = torch.arange(n_nodes)
        w[nodes, nodes] = weight
        for offset in sorted(offsets):
            w[nodes, (nodes + offset) % n_nodes] = weight

        super().__init__([w],
                         penalty=penalty, nrepeat=nrepeat, seed=seed)
