# -*- coding: utf-8 -*-
import math

import torch

from .dynamic_graph import *

__all__ = ["GeneralizedKPeerExponentialGraph"]


class GeneralizedKPeerExponentialGraph(DynamicGraph):
    """Time-varying graph made of ``S`` mixing matrices over ``n_nodes`` nodes.

    ``S`` is the smallest integer with ``(K + 1) ** S >= n_nodes``. For each
    ``step`` in ``range(S)`` one ``(n_nodes, n_nodes)`` matrix is built in
    which row ``i`` holds the value ``1 / S`` at every column
    ``(i + j * S ** step) % n_nodes`` for ``j`` in ``range(S)``; all other
    entries are zero. The resulting list is handed to ``DynamicGraph``.

    NOTE(review): the hop base, the number of entries per row and the weight
    denominator all use ``S`` rather than ``K + 1``. That is what the code
    has always done and it is preserved here, but it looks like it may have
    been meant to be ``K + 1`` -- confirm against the intended topology.
    Rows only sum to one when the ``S`` column offsets are distinct modulo
    ``n_nodes``.

    Args:
        n_nodes: Number of nodes; must be at least 2.
        K: Must be positive. The default of ``0`` is not a usable value and
            is rejected, so callers have to pass ``K`` explicitly.
        penalty: Forwarded unchanged to ``DynamicGraph.__init__``.
        nrepeat: Forwarded unchanged to ``DynamicGraph.__init__``.
        seed: Forwarded unchanged to ``DynamicGraph.__init__``.

    Raises:
        AssertionError: If ``K`` is not positive.
        ValueError: If ``n_nodes`` is smaller than 2.
    """

    def __init__(self, n_nodes: int, K: int = 0,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        # Raised explicitly (instead of a bare ``assert``) so the check
        # survives ``python -O`` while keeping the exception type callers
        # have seen so far. Without it the loop below would never terminate
        # for K <= 0.
        if not K > 0:
            raise AssertionError(f"K must be positive, got {K}")
        if n_nodes < 2:
            # With a single node S would be 0 and 1 / S is undefined.
            raise ValueError(f"n_nodes must be at least 2, got {n_nodes}")

        # Exact integer ceil(log_{K+1}(n_nodes)). The float expression
        # math.ceil(math.log(n, K + 1)) overshoots on exact powers, e.g.
        # math.log(125, 5) == 3.0000000000000004 would yield 4 instead of 3.
        n_steps = 0
        reach = 1
        while reach < n_nodes:
            reach *= K + 1
            n_steps += 1
        self.S = n_steps

        with torch.no_grad():
            weight = 1.0 / self.S
            nodes = torch.arange(n_nodes)
            w_list = []
            for step in range(self.S):
                hop = self.S ** step
                w = torch.zeros((n_nodes, n_nodes))
                # For every row i at once, set column (i + j * hop) % n_nodes.
                for j in range(self.S):
                    w[nodes, (nodes + j * hop) % n_nodes] = weight
                w_list.append(w)

        super().__init__(w_list,
                         penalty=penalty, nrepeat=nrepeat, seed=seed)
