"""
Methods for sampling network graph topology and initializing network parameters.
"""

from functools import partial
import jax
import jax.numpy as jnp
import jax.random as random


def random_edge_topology(rng_key, num_nodes, sparsity=0.01):
    """Sample a random square 0/1 mask describing which node pairs are connected.

    Args:
      rng_key: Random number generator key.
      num_nodes: Number of nodes in the network; the mask has one row and one
        column per node.
      sparsity: Probability handed to the Bernoulli sampler, i.e. the chance
        that any given entry of the mask is 1. With the default of 0.01 about
        1% of the entries are expected to be nonzero.

    Returns:
      A (num_nodes, num_nodes) floating-point matrix whose entries are 0.0 or 1.0.
    """
    matrix_shape = (num_nodes, num_nodes)
    edge_mask = random.bernoulli(rng_key, p=sparsity, shape=matrix_shape)
    # Multiplying by 1.0 turns the boolean samples into floating-point 0.0 / 1.0.
    return 1.0 * edge_mask


@partial(jax.jit, backend="cpu", static_argnums=(1, 2, 3, 4, 5))
def random_uniform_dynamics(
    rng_key, num_nodes, scale=0.1, spectral_radius=1.0, sparsity=0.01, topology_init=random_edge_topology
):
    """Randomly sample a masked weight matrix rescaled to a target spectral radius.

    Weights are drawn uniformly, masked element-wise by a sampled topology, and
    then rescaled so that the largest eigenvalue magnitude equals
    `spectral_radius`.

    Args:
      rng_key: Random number generator key. It is split internally so that the
        topology and the weights are drawn from independent streams.
      num_nodes: Number of nodes; the result has shape (num_nodes, num_nodes).
      scale: Weights are drawn uniformly from [-scale, scale) before rescaling.
        Because the matrix is afterwards normalised to `spectral_radius`, this
        only affects the result when the rescaling is skipped (see Returns).
      spectral_radius: Target spectral radius of the returned matrix.
      sparsity: Forwarded unchanged to `topology_init` as its third argument.
      topology_init: Callable invoked as `topology_init(key, num_nodes, sparsity)`
        whose result is multiplied element-wise into the weight matrix.

    Returns:
      A (num_nodes, num_nodes) weight matrix. If the masked matrix has spectral
      radius 0 (for example when the mask is all zeros) it cannot be rescaled
      and is returned multiplied by `spectral_radius` only, instead of NaNs.
    """
    # Use independent keys: drawing the mask and the weights from the same key
    # reuses the same underlying random bits, which correlates which entries
    # survive the mask with the values those entries take.
    topology_key, weights_key = random.split(rng_key)
    adj = topology_init(topology_key, num_nodes, sparsity)
    weights = random.uniform(weights_key, (num_nodes, num_nodes), minval=-scale, maxval=scale) * adj

    # Keep the measured radius in its own name so the requested
    # `spectral_radius` argument is not overwritten before it is applied.
    current_radius = jnp.max(jnp.abs(jnp.linalg.eigvals(weights)))  # requires cpu backend

    # A zero radius would make the division produce NaN/inf; fall back to a
    # divisor of 1 in that case. jnp.where keeps this traceable under jit.
    safe_radius = jnp.where(current_radius > 0, current_radius, 1.0)
    return weights * (spectral_radius / safe_radius)
