

import numpy as np
import scipy as sp 
from jax import random
import jax.numpy as jnp
import matplotlib.pyplot as plt

def adjacency_matrix(tri):
    """
    Build the 0/1 adjacency matrix induced by a Delaunay triangulation.

    Two vertices are adjacent when they appear together in at least one simplex
    of the triangulation. Self-adjacency is cleared, so the diagonal is zero.

    Args:
        tri (scipy.spatial.Delaunay): Triangulation providing ``points`` (one row per
            vertex) and ``simplices`` (vertex indices of each simplex).

    Returns:
        numpy.ndarray: Integer matrix of shape (len(tri.points), len(tri.points)) with
        1 at [i, j] when vertices i and j share a simplex, and 0 otherwise.

    Example:
        >>> import scipy.spatial
        >>> points = np.random.rand(10, 2)
        >>> tri = scipy.spatial.Delaunay(points)
        >>> adjacency_matrix(tri)
    """
    size = len(tri.points)
    adj = np.zeros((size, size), dtype=int)
    for vertex_ids in tri.simplices:
        # Mark every (row, col) pair drawn from this simplex in one block assignment.
        adj[np.ix_(vertex_ids, vertex_ids)] = 1
    # A vertex is not its own neighbour.
    np.fill_diagonal(adj, 0)
    return adj

def plot_graph(points, adj_matrix):
    """
    Plot a graph given vertex coordinates and an adjacency matrix.

    Vertices are drawn as black dots. Every pair (i, j) with a nonzero entry in
    either ``adj_matrix[i, j]`` or ``adj_matrix[j, i]`` is drawn once as a blue
    segment behind the dots. Diagonal entries (self-loops) are skipped because they
    would be zero-length segments.

    Args:
        points (array_like): Array of shape (n, 2) with the (x, y) position of each vertex.
        adj_matrix (array_like): Square (n, n) adjacency/weight matrix; any nonzero
            entry denotes an edge.

    Example:
        >>> points = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])
        >>> adj_matrix = np.array([[0, 1, 1, 0], [1, 0, 0, 1], [1, 0, 0, 1], [0, 1, 1, 0]])
        >>> plot_graph(points, adj_matrix)
    """
    points = np.asarray(points)
    has_edge = np.asarray(adj_matrix) != 0
    plt.scatter(points[:, 0], points[:, 1], s=50, color='k', zorder=1)
    # Symmetrise and keep the strict upper triangle so each undirected edge is
    # plotted exactly once (the full i/j double loop drew every edge twice).
    rows, cols = np.nonzero(np.triu(has_edge | has_edge.T, k=1))
    for i, j in zip(rows, cols):
        plt.plot([points[i, 0], points[j, 0]],
                 [points[i, 1], points[j, 1]], 'b-', alpha=1, zorder=0, linewidth=1)

def floyd_warshall(A):
    """
    Compute all-pairs shortest-path (geodesic) distances with edge length 1/A.

    Args:
        A (array_like): Square (n, n) weight matrix. Entry A[i, j] yields an edge of
            length 1 / (A[i, j] + 1e-5). A zero entry therefore gets length 1e5, which
            acts as "almost infinity" for non-connected vertices. Weights are assumed
            non-negative.

    Returns:
        numpy.ndarray: Matrix D of shape (n, n) where D[i, j] is the length of the
        shortest path from i to j. The diagonal is 0.

    Raises:
        ValueError: If ``A`` is not a square 2-D matrix.

    Notes:
        - Larger weights mean shorter edges; the edge length is 1/A, not A.
        - Pairs with no connecting path get a distance of roughly 1e5 or more, not inf.

    Example:
        >>> A = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]])
        >>> floyd_warshall(A)  # path graph: D[0, 3] is approximately 3
    """
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be a square matrix, got shape {A.shape}")
    n = A.shape[0]
    # Edge lengths; the small offset gives non-connected pairs a huge (finite) length.
    D = 1 / (A + 1e-5)
    np.fill_diagonal(D, 0)
    for k in range(n):
        # Relax all (i, j) through intermediate vertex k at once. Row k and column k
        # do not change during this step (D[k, k] == 0), so this matches the classic
        # in-place triple loop exactly.
        D = np.minimum(D, D[:, k:k + 1] + D[k:k + 1, :])
    return D

def generate_graph(graph_type, n, *, inverse_distance=False):
    """
    Generate a random graph and 2-D vertex coordinates for it.

    Args:
        graph_type (str): The type of graph to generate. One of:
            'planar': Delaunay triangulation of n uniform random points in [0, 1)^2.
            'erdos': Erdős–Rényi-style graph. Each ordered pair is kept with
                probability p = 0.1, then the matrix is symmetrised. The diagonal is
                not cleared, so self-loops can appear.
            'rbm': Two-class random block model. Within-class pairs are drawn with
                p = 0.1 and between-class pairs with q = 0.02, then the matrix is
                symmetrised. The first n // 2 vertices form class 0.
            'circular': Cycle graph; vertex i is joined to i - 1 and i + 1 (mod n).
        n (int): The number of vertices in the graph.
        inverse_distance (bool, keyword-only): Only used for 'planar'. If True, each
            edge is weighted by 1 / (Euclidean length + 1e-5) instead of 1, making the
            adjacency more "geometric". Defaults to False.

    Returns:
        tuple: A tuple containing:
            - A (numpy.ndarray or jax.Array): The (n, n) adjacency matrix. 'erdos' and
              'rbm' return a jax array.
            - X (numpy.ndarray): The (n, 2) array of vertex coordinates. 'planar' uses
              the sampled points, 'rbm' uses two Gaussian clouds centred at -3 and +3,
              and 'circular' uses the unit circle. 'erdos' uses a networkx spring
              layout.

    Raises:
        ValueError: If ``graph_type`` is not one of the values above.

    Example:
        >>> A, X = generate_graph('planar', 50)

    Note:
        The jax draws ('erdos', 'rbm' edges) use a fixed PRNG key and are reproducible.
        The numpy draws (point positions) use NumPy's global random state, and the
        'erdos' spring layout is unseeded. Seed np.random yourself if you need those
        parts to be reproducible.
    """
    key = random.PRNGKey(758493321)
    X = None  # None means "no natural embedding": fall back to a spring layout below
    if graph_type == 'planar':
        from scipy.spatial import Delaunay  # explicit: `scipy.spatial` is a submodule
        X = np.random.rand(n, 2)
        A = adjacency_matrix(Delaunay(X))
        if inverse_distance:
            D_eucl = np.sqrt(np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=2))
            A = A * 1 / (D_eucl + 1e-5)
    elif graph_type == 'erdos':
        p = .1  # probability of having an edge
        A = (random.uniform(key, shape=(n, n)) > 1 - p).astype(jnp.float32)
        A = jnp.maximum(A, A.T)
    elif graph_type == 'rbm':  # random block model
        p = .1   # probability of having an edge in the same class
        q = .02  # probability of having an edge between the two classes
        h = n // 2
        # Independent keys per block; reusing one key made both diagonal blocks identical.
        key_in0, key_in1, key_between = random.split(key, 3)
        A = np.zeros((n, n))
        # Block shapes follow the actual slice sizes so odd n works.
        A[:h, :h] = (random.uniform(key_in0, shape=(h, h)) > 1 - p).astype(jnp.float32)
        A[h:, h:] = (random.uniform(key_in1, shape=(n - h, n - h)) > 1 - p).astype(jnp.float32)
        A[:h, h:] = (random.uniform(key_between, shape=(h, n - h)) > 1 - q).astype(jnp.float32)
        A = jnp.maximum(A, A.T)
        X = np.random.randn(n, 2)
        X[:h] -= 3
        X[h:] += 3
    elif graph_type == 'circular':
        idx = np.arange(n)
        A = np.zeros((n, n), dtype=int)
        A[idx, (idx + 1) % n] = 1
        A[(idx + 1) % n, idx] = 1
        z = np.exp(2 * 1j * np.pi * idx / n)
        X = np.vstack((np.real(z), np.imag(z))).T
    else:
        raise ValueError(
            f"Unknown graph_type {graph_type!r}; expected 'planar', 'erdos', 'rbm' or 'circular'."
        )

    if X is None:
        import networkx as nx
        # from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array needs a numpy array.
        G = nx.from_numpy_array(np.asarray(A))
        pos = nx.spring_layout(G)
        X = np.array([pos[node] for node in sorted(pos)])

    return A, X