#! -*- coding: utf-8
import numpy as np
import torch

from .dynamic_graph import *

__all__ = ["HalfRandomGraph"]


class HalfRandomGraph(DynamicGraph):
    """Static random graph where each node pair is linked with probability 1/2.

    The constructor samples one symmetric adjacency matrix, always keeps the
    self-loops, row-normalizes it, and hands it to ``DynamicGraph`` as a
    single-element list of weight matrices.

    NOTE(review): rows sum to 1, but columns generally do not (node degrees
    differ), so the matrix is row-stochastic rather than doubly stochastic.
    Confirm that is what ``DynamicGraph`` expects.
    """

    def __init__(self, n_nodes: int, K: int = 0,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        """Sample the weight matrix and initialize the base graph.

        Args:
            n_nodes: number of nodes; the weight matrix is
                ``n_nodes x n_nodes``.
            K: not used by this constructor (neither stored nor forwarded).
            penalty: forwarded unchanged to ``DynamicGraph``.
            nrepeat: forwarded unchanged to ``DynamicGraph``.
            seed: seeds the edge sampling here and is also forwarded to
                ``DynamicGraph``.
        """
        rng = np.random.RandomState(seed)

        # One uniform [0, 1) draw per ordered pair. Only the strict upper
        # triangle is consulted, so pair {i, j} (i < j) is linked iff
        # draw[i, j] >= 0.5.
        upper = np.triu(rng.rand(n_nodes, n_nodes) >= 0.5, k=1)

        # Mirror the upper triangle to make the graph undirected. Then force
        # self-loops on, which also guarantees every row has a nonzero count.
        adjacency = upper | upper.T
        np.fill_diagonal(adjacency, True)

        # Dividing the boolean matrix by its integer row counts yields
        # float64 weights in which each row sums to 1.
        degree = adjacency.sum(axis=1, keepdims=True)
        weights = adjacency / degree

        super().__init__([torch.tensor(weights)],
                         penalty=penalty, nrepeat=nrepeat, seed=seed)
