#! -*- coding: utf-8
import typing

import numpy as np
import sympy
import sympy.ntheory
import torch

from .dynamic_graph import DynamicGraph
from .simple_base_graph import SimpleBaseGraph

__all__ = ["ExpBaseGraph"]


class ExpBaseGraph(DynamicGraph):
    """Dynamic topology assembled from a factorization of ``n_nodes``.

    ``n_nodes`` is split into an ``out_factor`` (the product of every prime
    power ``p**e`` of ``n_nodes`` with ``p > K + 1``) and a list of
    ``in_factors``. Each in-factor is at most ``K + 1``, and together their
    product is the remaining cofactor.

    The ``out_factor`` part is built with :class:`SimpleBaseGraph` and
    replicated across ``n_nodes // out_factor`` interleaved node groups.
    Every in-factor contributes one circulant, row-stochastic matrix. The
    resulting list of matrices is passed to :class:`DynamicGraph`.
    """

    def __init__(self, n_nodes: int, K: int = 0,
                 inner_edges: bool = True,
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        """Build the graph.

        Args:
            n_nodes: Number of nodes.
            K: Positive integer. Each in-factor matrix has at most ``K + 1``
                non-zeros per row. The default of 0 is rejected, so callers
                must supply a value.
            inner_edges: Forwarded to :class:`SimpleBaseGraph`.
            penalty: Forwarded to :class:`DynamicGraph`.
            nrepeat: Forwarded to :class:`DynamicGraph`.
            seed: Forwarded to both :class:`SimpleBaseGraph` and
                :class:`DynamicGraph`.

        Raises:
            ValueError: If ``K`` is not positive.
        """
        if K <= 0:
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            raise ValueError(f"K must be a positive integer, got {K!r}")
        self.seed = seed
        self.K = K
        self.n_nodes = n_nodes
        self.inner_edges = inner_edges

        super().__init__(self.construct(),
                         penalty=penalty, nrepeat=nrepeat, seed=seed)

    def construct(self) -> typing.List[torch.Tensor]:
        """Return the list of ``(n_nodes, n_nodes)`` mixing matrices.

        Matrices derived from ``out_factor`` (if it is > 1) come first. They
        are followed by one circulant matrix per in-factor, in factorization
        order.
        """
        N = self.n_nodes
        max_factor = self.K + 1
        out_factor, in_factors = self._factorize(N, max_factor)

        w_list: typing.List[torch.Tensor] = []
        with torch.no_grad():
            if out_factor > 1:
                # Nodes i, i + ngraph, i + 2*ngraph, ... form one group of size
                # out_factor. Every group receives the same SimpleBaseGraph matrix.
                ngraph = N // out_factor
                # NOTE(review): max_degree receives self.K + 1 here. Confirm
                # this matches SimpleBaseGraph's max_degree convention.
                base = SimpleBaseGraph(out_factor, max_degree=max_factor,
                                       seed=self.seed,
                                       inner_edges=self.inner_edges)
                for w0 in base.w_list:
                    w = torch.zeros((N, N))
                    for i in range(ngraph):
                        # Slice assignment copies w0's values into w.
                        w[i:N:ngraph, i:N:ngraph] = w0
                    w_list.append(w)

            # `block` is the product of the factors consumed so far. In-factor c
            # links node j to j + k*stride (mod N) for k < c, each with weight 1/c.
            block = out_factor
            for c in in_factors:
                stride = N // (c * block)
                block *= c
                first_row = np.zeros(N)
                first_row[np.arange(c) * stride] = 1.0 / c
                w_list.append(self._circulant(first_row))

        return w_list

    @staticmethod
    def _factorize(n_nodes: int,
                   max_factor: int) -> typing.Tuple[int, typing.List[int]]:
        """Split ``n_nodes`` into ``(out_factor, in_factors)``.

        ``out_factor`` is the product of all prime powers ``p**e`` of
        ``n_nodes`` with ``p > max_factor``. ``in_factors`` greedily
        decomposes the remaining cofactor. At each step it takes the largest
        divisor that is ``<= max_factor``.
        """
        out_factor = 1
        for prime, exponent in sympy.ntheory.factorint(n_nodes).items():
            if prime > max_factor:
                out_factor *= prime ** exponent

        rest = n_nodes // out_factor
        in_factors: typing.List[int] = []
        while rest > 1:
            # Every prime left in `rest` is <= max_factor, so the chosen
            # divisor is >= 2 and the loop terminates.
            f = max(d for d in sympy.divisors(rest) if d <= max_factor)
            rest //= f
            in_factors.append(f)
        return out_factor, in_factors

    @staticmethod
    def _circulant(first_row: np.ndarray) -> torch.Tensor:
        """Return the float32 matrix whose row ``i`` is ``first_row`` rolled right by ``i``."""
        n = first_row.shape[0]
        # idx[i, j] = (j - i) mod n, so out[i, j] = first_row[(j - i) % n].
        idx = (np.arange(n)[None, :] - np.arange(n)[:, None]) % n
        return torch.from_numpy(first_row[idx].astype(np.float32))
