#! -*- coding: utf-8
import typing

import numpy as np
import torch

from .dynamic_graph import *

__all__ = ["KpartileRandomMatchGraph"]


class KpartileRandomMatchGraph(DynamicGraph):
    """Dynamic graph built from a fresh random partition of the nodes.

    Each sample shuffles the ``n_nodes`` node ids and cuts them into
    consecutive groups of ``K + 1`` nodes. Only the last group may be
    smaller. Every pair of nodes inside a group is connected, self-loops
    included. The resulting 0/1 matrix is row-normalised, so each node
    gives weight ``1 / len(group)`` to every member of its group. The
    matrix is symmetric and block-diagonal up to permutation, so it is
    doubly stochastic.

    Args:
        n_nodes: number of nodes in the graph.
        K: number of *other* nodes each node is grouped with (group size is
            ``K + 1``). Must be positive; the default of 0 is rejected.
        penalty: forwarded unchanged to ``DynamicGraph.__init__``.
        nrepeat: forwarded unchanged to ``DynamicGraph.__init__``.
        seed: seeds this class's own ``np.random.RandomState`` and is also
            forwarded to ``DynamicGraph.__init__``.

    Raises:
        ValueError: if ``K`` is not positive.
    """

    def __init__(self, n_nodes: int, K: int = 0,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        # Explicit check instead of ``assert`` so validation survives ``-O``.
        if K <= 0:
            raise ValueError(f"K must be a positive integer, got {K!r}")
        self.rs = np.random.RandomState(seed)
        self.n_nodes = n_nodes
        self.K = K
        # The initial matrix must exist before the base class is constructed,
        # because it is passed to ``DynamicGraph.__init__``.
        w_list = self.make_w_list()

        super().__init__(w_list,
                         penalty=penalty, nrepeat=nrepeat, seed=seed)

    def make_w_list(self) -> typing.List[torch.Tensor]:
        """Sample a new random grouping and return its mixing matrix.

        Consumes one ``permutation`` draw from ``self.rs``.

        Returns:
            A one-element list holding an ``(n_nodes, n_nodes)`` float64
            tensor whose rows each sum to 1.
        """
        n = self.n_nodes
        group_size = self.K + 1
        order = self.rs.permutation(np.arange(n))
        # Split points at multiples of K+1; the trailing group keeps the
        # remainder, so no group is ever empty.
        cut_points = np.arange(group_size, n, group_size)

        w = np.eye(n)
        for group in np.array_split(order, cut_points):
            # Connect every pair in the group (a full clique block) in one
            # vectorised assignment instead of a Python double loop.
            w[np.ix_(group, group)] = 1
        # Every row contains at least its own diagonal entry, so the row
        # sum is >= 1 and the division is always safe.
        w = w / w.sum(axis=-1, keepdims=True)
        return [torch.tensor(w)]

    def get_neighbors(self, i, idx: typing.Optional[int] = None
                      ) -> typing.Tuple[typing.Dict[int, float], typing.Dict[int, float]]:
        """Resample the graph, then delegate to ``DynamicGraph.get_neighbors``.

        ``self.w_list`` is regenerated on *every* call, so two consecutive
        calls (e.g. for different nodes ``i``) see independently sampled
        groupings.

        NOTE(review): if callers query each node separately within one
        communication round, node i and node j may disagree about whether
        they are neighbours. Confirm against how ``DynamicGraph`` and its
        callers use ``idx``.
        """
        self.w_list = self.make_w_list()
        return super().get_neighbors(i, idx=idx)
