#! -*- coding: utf-8
import typing

# import matplotlib.pyplot as plt
# import networkx as nx
import numpy as np
import torch

# Names exported by `from <this module> import *`.
__all__ = ["DynamicGraph"]


class DynamicGraph:
    """Time-varying graph backed by a list of mixing matrices.

    Each call to :meth:`get_neighbors` without an explicit ``idx`` advances an
    internal step counter (``self.itr``). The matrix used at each step is taken
    from a schedule (``self.indices``), rebuilt whenever the step is a multiple
    of the schedule length. The schedule depends on ``penalty``:

    * ``"no"``            -- cycle through ``w_list`` in order.
    * ``"random"``        -- visit ``w_list`` in a fresh random permutation on
      every pass.
    * ``"repeat"``        -- use each matrix ``nrepeat`` consecutive times,
      cycling through ``w_list`` in order.
    * ``"random_repeat"`` -- like ``"repeat"``, but the order of the matrices
      is a fresh random permutation on every pass through ``w_list``.
    """

    _PENALTIES = ("no", "random", "repeat", "random_repeat")

    def __init__(self, w_list: typing.List[torch.Tensor],
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        """
        Parameter
        --------
        w_list (list of torch.tensor):
            list of mixing matrices, presumably square (n_nodes x n_nodes);
            a nonzero ``w[j, i]`` makes ``j`` an in-neighbor of ``i``.
        penalty (str):
            schedule mode; one of "no", "random", "repeat", "random_repeat".
        nrepeat (int):
            number of consecutive steps each matrix is reused in the
            "repeat" / "random_repeat" modes (must be >= 1 there).
        seed (int):
            seed of the numpy RandomState used for the random permutations.

        Raises
        --------
        ValueError:
            if ``w_list`` is empty, ``penalty`` is unknown, or ``nrepeat < 1``
            in a repeat mode.
        """
        if len(w_list) == 0:
            raise ValueError("w_list must contain at least one mixing matrix")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if penalty not in self._PENALTIES:
            raise ValueError(f"Unknown penalty: {penalty}")
        if penalty in ("repeat", "random_repeat") and nrepeat < 1:
            raise ValueError(f"nrepeat must be >= 1, got {nrepeat}")

        self.w_list: typing.List[torch.Tensor] = w_list
        self.n_nodes: int = w_list[0].size()[0]
        # Length of one schedule pass (== len(self.indices)).
        self.length: int = len(w_list)
        self.itr: int = 0

        # graph penalty settings
        self.penalty = penalty
        self.nrepeat = nrepeat

        self.rs = np.random.RandomState(seed)

        if self.penalty in ["repeat", "random_repeat"]:
            self.length = nrepeat

        # Build an initial schedule so `self.indices` exists, then reset the
        # cursor: the first get_neighbors() call (step 0) rebuilds the schedule
        # and must start again from graph position 0.
        self.curr_graph_idx = -1
        self.make_indices()
        self.curr_graph_idx = -1

    def __len__(self) -> int:
        """Return the length of one schedule pass."""
        return self.length

    def make_indices(self):
        """(Re)build ``self.indices``, the graph schedule for the next pass.

        In the repeat modes this also advances ``self.curr_graph_idx`` by one
        graph (modulo ``len(self.w_list)``).
        """
        n_graphs = len(self.w_list)
        if self.penalty == "no":
            self.indices: typing.List[int] = list(range(n_graphs))
        elif self.penalty == "random":
            self.indices = self.rs.permutation(
                list(range(n_graphs))).tolist()
        elif self.penalty == "repeat":
            self.curr_graph_idx = (self.curr_graph_idx + 1) % n_graphs
            self.indices = [self.curr_graph_idx] * self.nrepeat
        elif self.penalty == "random_repeat":
            self.curr_graph_idx = (self.curr_graph_idx + 1) % n_graphs
            # Draw a new visiting order only at the start of each pass over
            # w_list, so every graph is used exactly once per pass (the old
            # `% nrepeat` test reshuffled mid-pass when nrepeat < n_graphs).
            if self.curr_graph_idx == 0:
                self.graph_order = self.rs.permutation(
                    list(range(n_graphs)))
            self.indices = [
                int(self.graph_order[self.curr_graph_idx])] * self.nrepeat
        else:
            raise ValueError(f"Unknown penalty: {self.penalty}")

    def get_weights(self, idx: typing.Optional[int] = None) -> torch.Tensor:
        """Return the mixing matrix scheduled for a step.

        Parameter
        ----------
        idx (int, optional):
            step index; defaults to the current step ``self.itr``. It is
            taken modulo the schedule length.
        """
        step = self.itr if idx is None else idx
        graph_idx = self.indices[step % len(self.indices)]
        return self.w_list[graph_idx]

    def get_in_neighbors(self, i, idx: typing.Optional[int] = None
                         ) -> typing.Dict[int, float]:
        """
        Parameter
        ----------
        i (int):
            a node index
        idx (int, optional):
            step index used to select the mixing matrix; defaults to the
            current step ``self.itr``.
        Return
        ----------
            dict mapping each node j with nonzero ``w[j, i]`` (column i) to
            ``w[j, i]``; includes ``i`` itself when ``w[i, i] != 0``.
        """
        w = self.get_weights(idx=idx)
        return {j.item(): w[j, i].item() for j in torch.nonzero(w[:, i])}

    def get_out_neighbors(self, i, idx: typing.Optional[int] = None
                          ) -> typing.Dict[int, float]:
        """
        Parameter
        ----------
        i (int):
            a node index
        idx (int, optional):
            step index used to select the mixing matrix; defaults to the
            current step ``self.itr``.
        Return
        ----------
            dict mapping each node j with nonzero ``w[i, j]`` (row i) to
            ``w[i, j]``; includes ``i`` itself when ``w[i, i] != 0``.
        """
        w = self.get_weights(idx=idx)
        return {j.item(): w[i, j].item() for j in torch.nonzero(w[i])}

    def get_neighbors(self, i, idx: typing.Optional[int] = None
                      ) -> typing.Tuple[typing.Dict[int, float],
                                        typing.Dict[int, float]]:
        """Return ``(in_neighbors, out_neighbors)`` of node ``i`` for a step.

        When ``idx`` is None the current step ``self.itr`` is used and then
        advanced by one. The schedule is rebuilt whenever the step is a
        multiple of the schedule length.

        NOTE(review): that rebuild also happens for an explicit ``idx`` and
        advances the graph cursor in the repeat modes.
        """
        # Use the same step rule as get_weights() (`idx is None`), so that
        # integer-like idx values (e.g. numpy ints) are handled consistently.
        step = self.itr if idx is None else idx
        if step % len(self.indices) == 0:
            self.make_indices()

        in_neighbors = self.get_in_neighbors(i, idx=idx)
        out_neighbors = self.get_out_neighbors(i, idx=idx)

        if idx is None:
            self.itr += 1

        return in_neighbors, out_neighbors
