#! -*- coding: utf-8
import copy
import math

import numpy as np
import sympy
import torch

__all__ = ["SimpleExpBaseGraph"]

from .dynamic_graph import DynamicGraph
from .k_peer_exponential_graph import KPeerExponentialGraph


class SimpleExpBaseGraph(DynamicGraph):
    """Dynamic graph assembled from k-peer exponential graphs for any ``n_nodes``.

    If every prime factor of ``n_nodes`` is at most ``K + 1``, the schedule of a
    single ``KPeerExponentialGraph(n_nodes, K)`` is used directly.  Otherwise the
    nodes are split into components whose sizes have the form
    ``a_l * (K + 1) ** p_l``.  The per-round mixing matrices are then built from
    exponential graphs on those components, plus bridging edges between
    components (see :meth:`construct`).
    """

    def __init__(self, n_nodes: int, K: int = 1, seed: int = 0, inner_edges: bool = True,
                 penalty: str = "no", nrepeat: int = 5):
        """
        Args:
            n_nodes: number of nodes in the graph.
            K: peer parameter forwarded to ``KPeerExponentialGraph``; node counts
                are decomposed in base ``K + 1``.
            seed: seeds ``self.state`` (a ``np.random.RandomState``) and is
                forwarded to ``DynamicGraph``.
            inner_edges: if True, nodes left unmatched by a bridging round are
                grouped into cliques of at most ``K + 1`` nodes; otherwise they
                only receive a self-loop in that round.
            penalty: forwarded unchanged to ``DynamicGraph`` (semantics defined there).
            nrepeat: forwarded unchanged to ``DynamicGraph`` (semantics defined there).
        """
        # NOTE(review): self.state is not read in this class; presumably used by
        # DynamicGraph or external callers -- confirm before removing.
        self.state = np.random.RandomState(seed)
        self.inner_edges = inner_edges
        self.K = K
        self.n_nodes = n_nodes

        super().__init__(self.construct(),
                         penalty=penalty, nrepeat=nrepeat, seed=seed)

    @staticmethod
    def _int_log(n: int, base: int) -> int:
        """Return the largest integer ``p`` such that ``base ** p <= n``.

        Exact integer replacement for ``int(math.log(n, base))``.  The float
        version can come out one too small for exact powers because of rounding
        (e.g. ``math.log(1000, 10) == 2.9999999999999996``).

        Raises:
            ValueError: if ``n < 1`` or ``base < 2``.
        """
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n}")
        if base < 2:
            raise ValueError(f"base must be at least 2, got {base}")
        power = 0
        value = base
        while value <= n:
            value *= base
            power += 1
        return power

    def construct(self):
        """Build the list of ``n_nodes x n_nodes`` mixing matrices, one per round.

        Special cases:

        * ``n_nodes == 1``: returns ``[torch.eye(1)]``.
        * All prime factors of ``n_nodes`` are at most ``K + 1``: returns
          ``KPeerExponentialGraph(n_nodes, K).w_list``.

        Otherwise, let ``C_0, ..., C_{L-1}`` be the components from
        :meth:`split_nodes` (strictly decreasing sizes), with their equal-size
        groups from :meth:`split_nodes2`.  Let ``M`` be the schedule length of
        ``C_0``'s exponential graph.  Round ``m`` is built as follows:

        * ``m < M``: every component follows its own exponential-graph
          schedule (shorter schedules are cycled).
        * ``m = M + t``:

          * Each node of a later component ``C_l`` (``l > t``) is linked to
            one still-unmatched node in every group of ``C_t``.
          * If ``t < L - 1``, the nodes of ``C_t`` left unmatched are grouped
            into cliques when ``inner_edges`` is set.
          * Every other component advances its intra-group exponential-graph
            schedule.

        Rounds stop once ``C_0`` has performed as many intra-group rounds as
        its per-group schedule has matrices.  Any node whose diagonal entry is
        still zero at the end of a round gets a self-loop of weight 1.

        Raises:
            ValueError: if ``n_nodes < 1``.
        """
        if self.n_nodes < 1:
            raise ValueError(f"n_nodes must be a positive integer, got {self.n_nodes}")
        if self.n_nodes == 1:
            return [torch.eye(1)]
        if max(sympy.factorint(self.n_nodes)) <= self.K + 1:
            return KPeerExponentialGraph(self.n_nodes, K=self.K).w_list

        node_list_list, n_nodes_list = self.split_nodes()
        node_list_list_list = self.split_nodes2(node_list_list)
        L = len(node_list_list)

        # Exponential graph over each whole component, and over a single group
        # of each component (all groups of one component have equal size).
        hyperhyper_cubes = [KPeerExponentialGraph(len(nodes), K=self.K)
                            for nodes in node_list_list]
        hyperhyper_cubes2 = [KPeerExponentialGraph(len(groups[0]), K=self.K)
                             for groups in node_list_list_list]
        max_length_of_hyper = len(hyperhyper_cubes[0].w_list)

        # b[l]: number of intra-group rounds component l has performed so far.
        b = [0] * L

        w_list = []
        m = 0
        while True:
            w = torch.zeros((self.n_nodes, self.n_nodes))
            isolated_nodes = None
            all_isolated_nodes = None

            # Iterate from the smallest component to the largest, so the
            # bridging branch (l > t) runs before the clique branch (l == t)
            # that consumes the nodes it left unmatched.
            for l in reversed(range(L)):

                if m < max_length_of_hyper:
                    # Phase 1: each component follows its own schedule.
                    schedule = hyperhyper_cubes[l].w_list
                    w += self.extend(schedule[m % len(schedule)], node_list_list[l])

                elif m < max_length_of_hyper + l:
                    # Bridge every node of the smaller component l to one
                    # unmatched node in each group of component t.
                    t = m - max_length_of_hyper
                    if isolated_nodes is None:
                        isolated_nodes = copy.deepcopy(node_list_list_list[t])
                        all_isolated_nodes = [
                            node for nodes in isolated_nodes for node in nodes]

                    ratio = n_nodes_list[t] / sum(n_nodes_list[t:])
                    a_l = len(isolated_nodes)
                    for i in node_list_list[l]:
                        for group in isolated_nodes:
                            j = group.pop(-1)
                            all_isolated_nodes.remove(j)
                            w[i, j] = ratio / a_l
                            w[j, i] = ratio / a_l
                            w[j, j] = 1 - w[i, j]
                        w[i, i] = 1 - ratio

                elif m == max_length_of_hyper + l and l != L - 1:
                    # Group the nodes of component l that the bridging step
                    # left unmatched into uniform cliques of <= K + 1 nodes.
                    while len(all_isolated_nodes) > 1 and self.inner_edges:
                        clique_size = min(self.K + 1, len(all_isolated_nodes))
                        sampled_nodes = all_isolated_nodes[:clique_size]
                        del all_isolated_nodes[:clique_size]

                        weight = 1 / clique_size
                        for i in sampled_nodes:
                            for j in sampled_nodes:
                                w[i, j] = weight

                else:
                    # Intra-group phase: advance this component's schedule.
                    if n_nodes_list[l] < self.K + 1:
                        schedule = hyperhyper_cubes[l].w_list
                        w += self.extend(schedule[b[l] % len(schedule)],
                                         node_list_list[l])
                    else:
                        schedule = hyperhyper_cubes2[l].w_list
                        w_group = schedule[b[l] % len(schedule)]
                        for group in node_list_list_list[l]:
                            w += self.extend(w_group, group)

                    b[l] += 1

            # Nodes untouched in this round keep a self-loop.
            for i in range(self.n_nodes):
                if w[i, i] == 0:
                    w[i, i] = 1.0
            w_list.append(w)

            if b[0] >= len(hyperhyper_cubes2[0].w_list):
                break
            m += 1

        return w_list

    def extend(self, w, node_list):
        """Embed a small matrix into an ``n_nodes x n_nodes`` zero matrix.

        Entry ``w[i, j]`` is written to ``(node_list[i], node_list[j])`` for
        ``0 <= i, j < len(node_list)``; every other entry of the result is zero.

        Args:
            w: square matrix (tensor or array-like) of side at least
                ``len(node_list)``; only its top-left block is read.
            node_list: global node indices the rows/columns of ``w`` map to.

        Returns:
            A new float ``torch`` tensor of shape ``(n_nodes, n_nodes)``.
        """
        new_w = torch.zeros((self.n_nodes, self.n_nodes))
        size = len(node_list)
        if size == 0:
            return new_w

        index = torch.as_tensor(node_list, dtype=torch.long)
        block = torch.as_tensor(w, dtype=new_w.dtype, device=new_w.device)[:size, :size]
        # Broadcast (size, 1) x (1, size) indices to scatter the whole block at once.
        new_w[index.unsqueeze(1), index.unsqueeze(0)] = block
        return new_w

    def split_nodes(self):
        """Split ``range(n_nodes)`` into consecutive components of decreasing size.

        Sizes are produced greedily in base ``K + 1``: each component has
        ``a * (K + 1) ** p`` nodes with ``1 <= a <= K``, and ``p`` strictly
        decreases from one component to the next.

        Returns:
            ``(node_list_list, n_nodes_list)``: the node index lists of each
            component, and their sizes.
        """
        base = self.K + 1
        factor = base ** self._int_log(self.n_nodes, base)

        n_nodes_list = []
        rest = self.n_nodes
        while rest > 0:
            if rest >= factor:
                chunk = (rest // factor) * factor
                n_nodes_list.append(chunk)
                rest -= chunk
            factor //= base

        node_list_list = []
        start = 0
        for size in n_nodes_list:
            node_list_list.append(list(range(start, start + size)))
            start += size

        return node_list_list, n_nodes_list

    def split_nodes2(self, node_list_list):
        """Split each component into equal-size groups of ``(K + 1) ** p_l`` nodes.

        ``len(node_list)`` can be written as ``a_l * (K + 1) ** p_l`` with
        ``1 <= a_l <= K``.  The component is cut into ``a_l`` consecutive
        groups of ``(K + 1) ** p_l`` nodes each.

        Returns:
            One list of groups (lists of node indices) per input component.
        """
        base = self.K + 1
        node_list_list_list = []

        for node_list in node_list_list:
            n_nodes = len(node_list)
            power = math.gcd(n_nodes, base ** self._int_log(n_nodes, base))
            n_groups = n_nodes // power
            node_list_list_list.append(
                [node_list[i * power:(i + 1) * power] for i in range(n_groups)])

        return node_list_list_list
