from typing import Iterable
from utils.topology import *
from utils.speical_topology import *


def adjacency_from_edge_list(n: int,
                             edges: Iterable[Tuple[int, int]],
                             undirected: bool = True,
                             self_loops: bool = False) -> np.ndarray:
    """Return an ``(n, n)`` int64 adjacency matrix built from an edge list.

    Args:
        n: number of nodes; edge endpoints are used directly as row/column
            indices into the matrix.
        edges: iterable of ``(u, v)`` pairs.
        undirected: when true, each edge also sets the mirrored entry
            ``[v, u]``.
        self_loops: when true, EVERY diagonal entry is set to 1 (not only the
            nodes named by ``(u, u)`` pairs).  When false, ``(u, u)`` pairs
            are skipped and the whole diagonal is forced to 0.

    Returns:
        A new ``np.int64`` matrix whose entries are 0 or 1.
    """
    adj = np.zeros((n, n), dtype=np.int64)

    # Drop (u, u) pairs up front unless self-loops were requested.
    kept_pairs = ((src, dst) for src, dst in edges if self_loops or src != dst)
    for src, dst in kept_pairs:
        adj[src, dst] = 1
        if undirected:
            adj[dst, src] = 1

    # Still required when self_loops is false: an index pair such as
    # (-1, n - 1) passes the src != dst test yet lands on the diagonal.
    np.fill_diagonal(adj, 1 if self_loops else 0)
    return adj


def build_weight_matrix(adjacency: np.ndarray,
                        lambdas: Optional[Iterable[float]] = None, epsilon: float = 0.5) -> np.ndarray:
    """Build a row-stochastic weight (mixing) matrix from an adjacency matrix.

    Let ``d_k`` be the k-th row sum of the (int64-converted) adjacency.  For
    every ordered pair ``(i, j)`` with ``i != j`` and ``adjacency[i, j] != 0``::

        W[i, j] = (1 - epsilon) / d_i * clip(lam_j * d_i / (lam_i * d_j), 0, 1)

    with these special cases:

    * ``d_j == 0`` gives weight 0.
    * ``lam_i == 0`` makes the clipped ratio 1.
    * A NaN ratio counts as 0.

    The diagonal receives ``1 - sum_{j != i} W[i, j]``, and a row with
    ``d_i == 0`` becomes the unit row ``e_i``.  Finally all entries are
    clipped to [0, 1] and every non-zero row is renormalised to sum to 1.
    For a 0/1 adjacency each self-weight is therefore at least ``epsilon``.

    Args:
        adjacency: square matrix; any non-zero entry (after conversion to
            int64, which truncates fractional values) is an edge.  It is
            copied, never modified.  Diagonal entries are not treated as
            edges but do count towards the row sum ``d_i``.
        lambdas: optional per-node scalars used in the acceptance ratio;
            defaults to all ones.
        epsilon: laziness parameter in [0, 1).

    Returns:
        ``(n, n)`` float64 row-stochastic matrix.

    Raises:
        ValueError: if ``adjacency`` is not a square 2-D matrix, if
            ``lambdas`` does not have length n, or if ``epsilon`` lies
            outside [0, 1).
    """
    A = np.array(adjacency, dtype=np.int64, copy=True)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("adjacency must be a square matrix")
    n = A.shape[0]

    # epsilon == 1 divides by zero below; epsilon > 1 (or NaN) yields garbage.
    if not 0.0 <= epsilon < 1.0:
        raise ValueError(f"epsilon must lie in [0, 1), got {epsilon!r}")

    if lambdas is None:
        lam = np.ones(n, dtype=np.float64)
    else:
        lam = np.asarray(list(lambdas), dtype=np.float64)
        if lam.shape != (n,):
            raise ValueError("lambdas length error")

    deg = A.sum(axis=1)  # int64 row sums d_i
    has_deg = deg != 0

    # Off-diagonal edges that can carry weight. A source row with d_i == 0
    # becomes a unit row. A target with d_j == 0 contributes weight 0.
    edge = (A != 0) & has_deg[:, None] & has_deg[None, :]
    np.fill_diagonal(edge, False)

    # Entries masked out by `edge` may divide by zero; silence those warnings.
    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        # ratio[i, j] = (lam_j * d_i) / (lam_i * d_j)
        ratio = (lam[None, :] * deg[:, None]) / (lam[:, None] * deg[None, :])
        ratio_term = np.where(np.isnan(ratio), 0.0, np.clip(ratio, 0.0, 1.0))
        # lam_i == 0 means "always accept": the clipped ratio is taken as 1.
        ratio_term = np.where(lam[:, None] == 0.0, 1.0, ratio_term)
        # 1 / (d_i / (1 - epsilon)); zero-degree rows get 0 (fully masked anyway).
        inv_denom = np.where(has_deg, 1.0 / (deg / (1 - epsilon)), 0.0)

    W = np.where(edge, inv_denom[:, None] * ratio_term, 0.0)
    # Each row's leftover mass goes to the self-weight (1.0 for isolated rows).
    np.fill_diagonal(W, 1.0 - W.sum(axis=1))

    W = np.clip(W, 0.0, 1.0)
    rowsum = W.sum(axis=1, keepdims=True)
    rowsum[rowsum == 0.0] = 1.0  # avoid 0/0 for an all-zero row
    return W / rowsum

def build_W_by_topology(n: int,
                        topology: str = "complete",
                        lambdas: Optional[Iterable[float]] = None,
                        **kwargs) -> np.ndarray:
    """Build a weight matrix for a named topology via ``build_weight_matrix``.

    Args:
        n: number of nodes.  Ignored by "custom1"/"custom2", which are fixed
            16-node graphs.  For "grid" it only seeds the default dimensions.
        topology: case-insensitive name. One of "complete", "static",
            "ring", "star", "er"/"erdos", "rgg"/"geo", "grid", "custom1",
            "custom2".
        lambdas: optional per-node scalars forwarded to ``build_weight_matrix``.
        **kwargs: topology options (all optional):

            * ring: ``k`` (default 1).
            * star: ``center`` (default 0).
            * er/erdos: ``p`` (default 0.2), ``seed`` (default 42).
            * rgg/geo: ``radius`` (default 0.3), ``seed`` (default 42).
            * grid: ``m`` rows (default ``int(sqrt(n))``), ``cols``
              (default ``ceil(n / m)``), ``connectivity`` (default 4).
            * any topology: ``epsilon`` forwarded to
              ``build_weight_matrix`` (default 0.3).

    Returns:
        The weight matrix produced by ``build_weight_matrix``.

    Raises:
        ValueError: for an unknown topology name.
    """
    topo = topology.lower()
    # Default 42 keeps previous results reproducible; previously any
    # caller-supplied seed was silently ignored.
    seed = kwargs.get("seed", 42)
    epsilon = float(kwargs.get("epsilon", 0.3))

    if topo == "complete":
        A = complete_graph(n)
    elif topo == "static":
        A = static_exponential_graph(n)
    elif topo == "ring":
        k = int(kwargs.get("k", 1))
        A = ring_graph(n, k=k)
    elif topo == "star":
        center = int(kwargs.get("center", 0))
        A = star_graph(n, center=center)
    elif topo in ("er", "erdos"):
        p = float(kwargs.get("p", 0.2))
        A = erdos_renyi_graph(n, p=p, seed=seed)
    elif topo in ("rgg", "geo"):
        radius = float(kwargs.get("radius", 0.3))
        A = random_geometric_graph(n, radius=radius, seed=seed)
    elif topo == "grid":
        m = int(kwargs.get("m", int(np.sqrt(n))))
        # The column count cannot be passed as "n": that keyword collides
        # with the positional parameter n, so it is accepted as "cols".
        cols = int(kwargs.get("cols", int(np.ceil(n / max(1, m)))))
        connectivity = int(kwargs.get("connectivity", 4))
        # NOTE(review): node count presumably m * cols (depends on
        # grid_2d_graph), which may differ from n — confirm.
        A = grid_2d_graph(m, cols, connectivity=connectivity)
    elif topo == "custom1":
        # Fixed 16-node graph; n is not used.
        A = np.array([
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
            [0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0],
            [0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0],
            [0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0],
            [0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1],
            [0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0],
            [0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1],
            [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0]
        ], dtype=int)
    elif topo == "custom2":
        # Fixed 16-node graph; n is not used.
        A = np.array([
            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0],
            [0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            [0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0],
            [0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0],
            [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
            [0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1],
            [0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0],
            [0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0],
            [0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            [1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            [0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1],
            [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0]
        ], dtype=int)
    else:
        raise ValueError(f"Unknown topology: {topology}")

    return build_weight_matrix(A, lambdas=lambdas, epsilon=epsilon)


