# -*- coding: utf-8 -*-
import typing

import numpy as np
import torch

from .dynamic_graph import *

# Names exported by a star-import of this module.
__all__ = ["BipartileRandomMatchGraph"]


class BipartileRandomMatchGraph(DynamicGraph):
    """Dynamic graph whose weight matrix is re-drawn at random on each query.

    Every draw shuffles the node ids with the instance's own
    ``np.random.RandomState`` and links each pair of consecutive ids in
    the shuffled order, i.e. it builds a random path through all nodes.
    Each node also keeps a self-loop, and every row of the resulting
    matrix is normalised to sum to one.
    """

    def __init__(self, n_nodes: int, K: int = 0,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        """Draw the initial weight matrix and initialise the base class.

        Args:
            n_nodes: Number of nodes in the graph.
            K: Accepted for signature compatibility; not used by this class.
            penalty: Forwarded unchanged to ``DynamicGraph.__init__``.
            nrepeat: Forwarded unchanged to ``DynamicGraph.__init__``.
            seed: Seeds the local random state and is also forwarded to
                ``DynamicGraph.__init__``.
        """
        # make_w_list() reads both attributes, so they have to exist
        # before the base initialiser is given its result.
        self.n_nodes = n_nodes
        self.rs = np.random.RandomState(seed)
        initial_weights = self.make_w_list()

        super().__init__(initial_weights,
                         penalty=penalty, nrepeat=nrepeat, seed=seed)

    def make_w_list(self):
        """Sample a new random-path weight matrix.

        Advances ``self.rs`` by one permutation draw.

        Returns:
            A one-element list holding a ``(n_nodes, n_nodes)`` torch
            tensor. Row ``i`` puts equal weight on node ``i`` itself and
            on its neighbours along the sampled path; all other entries
            are zero.
        """
        size = self.n_nodes
        order = self.rs.permutation(np.arange(size))

        # Consecutive entries of the shuffled order are the path's edges.
        heads, tails = order[:-1], order[1:]

        # Starting from the identity gives every node a self-loop, so no
        # row can sum to zero in the normalisation below.
        adjacency = np.eye(size)
        adjacency[heads, tails] = 1
        adjacency[tails, heads] = 1

        row_totals = adjacency.sum(axis=-1, keepdims=True)
        return [torch.tensor(adjacency / row_totals)]

    def get_neighbors(self, i, idx: int = None) -> typing.Tuple[typing.Dict[int, float], typing.Dict[int, float]]:
        """Re-sample the graph, then delegate to ``DynamicGraph.get_neighbors``.

        ``self.w_list`` is replaced by a fresh draw on every call, so
        repeated queries generally see different graphs.

        Args:
            i: Forwarded to the base implementation (presumably the id of
                the queried node; confirm against ``DynamicGraph``).
            idx: Forwarded to the base implementation as ``idx``.

        Returns:
            Whatever the base implementation returns; annotated as a pair
            of ``{node id: weight}`` dictionaries.
        """
        self.w_list = self.make_w_list()
        return super().get_neighbors(i, idx=idx)
