import numpy as np
import networkx as nx

from scipy.linalg import block_diag


def stochastic_block_model(
    n: int, 
    K: int, 
    p: float, 
    q: float,
    seed: int | None = None
) -> np.ndarray:
    """
    Obtain the adjacency matrix of a graph from a 
    stochastic block model.

    Parameters
    ----------
    n: int
        The number of nodes in the graph.
    
    K: int
        The number of communities in the graph.

    p: float
        The intracluster edge probability.

    q: float
        The intercluster edge probability.

    seed: int
        The seed to the random number generator.

    Returns
    -------
    A: ndarray
        The adjacency matrix of the graph.

    """
    np.random.seed(seed)

    # Sizes of each community (evenly split)
    sizes = [n // K] * K
    for i in range(n % K):
        sizes[i] += 1

    # Define the block probability matrix
    prob_matrix = [[p if i == j else q for j in range(K)] for i in range(K)]

    # Generate the SBM graph
    G = nx.stochastic_block_model(sizes, prob_matrix, seed=seed)
    #return G.adjacency_list()

    return nx.to_numpy_array(G)


def expected_sbm(
    n: int,
    K: int,
    p: float,
    q: float,
) -> np.ndarray:
    """
    Build the expected adjacency (edge-probability) matrix of a stochastic
    block model.

    Communities are sized exactly as in ``stochastic_block_model``: every
    community gets ``n // K`` nodes and the first ``n % K`` communities get
    one extra node. Nodes are ordered community by community. Entry
    ``(i, j)`` is ``p`` when nodes ``i`` and ``j`` share a community and
    ``q`` otherwise.

    Parameters
    ----------
    n: int
        The number of nodes in the graph. Must be non-negative.

    K: int
        The number of communities in the graph. Must be at least 1.

    p: float
        The intracluster edge probability.

    q: float
        The intercluster edge probability.

    Returns
    -------
    A: ndarray
        An ``n`` x ``n`` float array of edge probabilities.

    Raises
    ------
    ValueError
        If ``n`` is negative or ``K`` is less than 1.

    Notes
    -----
    The diagonal holds ``p``. NOTE(review): ``stochastic_block_model`` calls
    networkx with its default ``selfloops`` setting; if that excludes self
    loops, the sampled adjacency has a zero diagonal. Confirm which
    convention downstream comparisons expect.
    """
    if K < 1:
        raise ValueError(f"K must be at least 1, got {K}")
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    # Same even split as stochastic_block_model, so the matrix is always
    # n x n (the remainder is not dropped when K does not divide n).
    base, extra = divmod(n, K)
    sizes = [base + 1 if i < extra else base for i in range(K)]

    # Start from q everywhere and write p into each diagonal block. This is
    # done explicitly rather than by replacing zeros with q, because zeros
    # would also appear inside the diagonal blocks when p == 0.
    A = np.full((n, n), q, dtype=float)
    start = 0
    for size in sizes:
        stop = start + size
        A[start:stop, start:stop] = p
        start = stop

    return A
    

    
