#! -*- coding: utf-8
import numpy as np
import torch

from .dynamic_graph import *

# Public API of this module: only the graph class is exported on star-import.
__all__ = ["RandomKGraph"]


class RandomKGraph(DynamicGraph):
    """Random graph in which every node is linked to at most ``K`` peers.

    Edges are added greedily. Nodes are visited in index order, and each one
    is joined to randomly chosen nodes that (a) still have fewer than ``K``
    neighbours and (b) are not already adjacent to it. This continues until
    the node has ``K`` neighbours or no such node remains.

    The greedy pass does not guarantee a K-regular graph: some nodes may end
    with fewer than ``K`` neighbours. Every node also keeps a self-loop, and
    each row of the resulting weight matrix is normalised to sum to 1. The
    single matrix is passed to ``DynamicGraph`` as a one-element list.

    Args:
        n_nodes: Number of nodes.
        K: Maximum number of neighbours per node (self-loop not counted).
            ``K <= 0`` adds no edges, so the weights are the identity matrix.
        penalty: Forwarded unchanged to ``DynamicGraph``.
        nrepeat: Forwarded unchanged to ``DynamicGraph``.
        seed: Seeds the edge sampling here; also forwarded to ``DynamicGraph``.
    """

    def __init__(self, n_nodes: int, K: int = 0,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):

        rs = np.random.RandomState(seed)
        # Diagonal = self-loops; off-diagonal 1s mark undirected edges.
        w = np.eye(n_nodes)
        # degree[k] = number of distinct neighbours of node k (self excluded).
        degree = np.zeros(n_nodes, dtype=int)
        for i in range(n_nodes):
            missing = K - degree[i]
            if missing <= 0:
                continue
            # Candidates: nodes with spare capacity that are not yet linked to i.
            # w[i, i] == 1, so this also excludes i itself. Without the
            # adjacency check, an earlier node that ended its own turn
            # under-full (and is therefore already linked to i) could be
            # chosen again. The edge would then be counted twice, and one
            # of i's slots would be spent on an edge that already exists.
            candidates = np.flatnonzero((degree < K) & (w[i] == 0))
            if candidates.size == 0:
                continue
            # Random permutation via argsort of uniform keys. This form is kept
            # so the seeded RNG stream (and thus the graph for a given seed)
            # matches the previous implementation whenever no duplicate arises.
            chosen = candidates[np.argsort(-rs.rand(candidates.size))][:missing]
            w[i, chosen] = 1
            w[chosen, i] = 1
            degree[chosen] += 1
            degree[i] += chosen.size

        # Row-normalise; every row sum is >= 1 thanks to the self-loop.
        w = w / w.sum(axis=1, keepdims=True)
        w_list = [torch.tensor(w)]

        super().__init__(w_list,
                         penalty=penalty, nrepeat=nrepeat, seed=seed)
