import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg
import scipy.sparse

def imshow(x, ticks=None, title=''):
    '''
    Display x as an image with one labelled tick per row/column.

    x     -- array-like to draw; len(x) determines the number of ticks
    ticks -- labels used on both axes; defaults to 0..len(x)-1
    title -- figure title

    Blocks on plt.show() until the figure is dismissed (backend permitting).
    '''
    positions = range(len(x))
    plt.imshow(x)
    labels = positions if ticks is None else ticks
    # Same positions and labels on both axes (y first, then x)
    for set_ticks in (plt.yticks, plt.xticks):
        set_ticks(positions, labels)
    plt.title(title)
    plt.show()

def allpairsdistance(X, Y):
    '''
    Compute Euclidean distance between each pair of elements (x, y)

    X, Y -- numpy arrays with the same number of dimensions and the same
            length along the first axis (enforced by the assert below).
            1-D inputs are treated as scalars, 2-D inputs as one vector
            per row.

    Returns an array whose entry [i, j] is the distance between X[i] and Y[j].
    '''
    assert X.ndim == Y.ndim and len(X) == len(Y)
    if X.ndim < 2:
        # X and Y are 1-D; use absolute differences
        dist = np.abs(X[:, np.newaxis] - Y)
    else:
        # Minor speedup if we use the fact that (x-y)^2 = x^2 + y^2 - 2xy
        y_2 = np.sum(Y**2, axis=1)
        if X is Y:
            # Don't bother computing both x^2 AND y^2, since they're the same
            x_2 = y_2[:, np.newaxis]
        else:
            x_2 = np.sum(X**2, axis=1)[:, np.newaxis]
        xy2 = 2 * np.dot(X, Y.T)
        # Round-off can leave the expanded form slightly below zero for
        # (near-)identical rows; clamp so sqrt never produces NaN.
        sq_dist = np.maximum(x_2 + y_2 - xy2, 0)
        dist = np.sqrt(sq_dist)
    return dist

def center_matrix(x):
    '''
    Double-center x so that every row and every column has zero mean.

    Subtracts the per-column means and the per-row means, then adds the
    overall mean back once (it was removed twice by the two subtractions).
    '''
    overall = np.mean(x)
    per_row = np.mean(x, axis=1)      # one mean per row
    per_column = np.mean(x, axis=0)   # one mean per column
    without_columns = x - per_column
    without_rows = without_columns - per_row[:, None]
    return without_rows + overall

# Shuffle the row/column indices symmetrically
def arrange_vars(A, order):
    '''
    Return a copy of A with rows and columns permuted by the same order.

    Entry [i, j] of the result is A[order[i], order[j]]. The result has the
    same shape and dtype as A (it is allocated with np.zeros_like).
    '''
    permuted = np.zeros_like(A)
    for dst in range(len(A)):
        src = order[dst]
        # Pick the source row, then reorder its columns in one step
        permuted[dst, :] = A[src, order]
    return permuted

def shuffle_vars(A, seed=None):
    '''
    Apply a random symmetric row/column permutation to A.

    seed -- passed to np.random.seed before shuffling. Note that this
            reseeds numpy's global random state, including when it is None.

    Returns the permuted matrix produced by arrange_vars.
    '''
    permutation = np.arange(len(A))
    np.random.seed(seed)
    np.random.shuffle(permutation)
    return arrange_vars(A, permutation)

# Define Cuthill-McKee algorithm
def compute_degree(A, var):
    '''Return the degree of node var: the sum of column var of A.'''
    column = A[:, var]
    return column.sum()

def adjacency_set(A, var):
    '''
    List the neighbours of node var in ascending index order.

    A neighbour is any index i != var whose entry in column var of A is
    strictly positive. Despite the name, the result is a list.
    '''
    neighbours = []
    for idx, weight in enumerate(A[:, var]):
        if idx == var:
            # A node is never its own neighbour
            continue
        if weight > 0:
            neighbours.append(idx)
    return neighbours

def cuthill_mckee(A, reverse=False):
    '''
    Order the nodes of A by a breadth-first Cuthill-McKee traversal.

    A       -- square matrix indexed as A[:, var]; positive entries are
               edges and column sums are node degrees (see compute_degree
               and adjacency_set).
    reverse -- if False, start from the lowest-degree node and visit each
               node's new neighbours in ascending degree; if True, start
               from the highest-degree node and visit in descending degree.

    Returns a list of node indices (plain ints) containing every node once.
    An empty matrix yields an empty list.
    '''
    N = len(A)
    if N == 0:
        return []
    degrees = [compute_degree(A, v) for v in range(N)]
    start = int(np.argmax(degrees) if reverse else np.argmin(degrees))
    R = [start]
    # Mirror of R for O(1) membership tests
    visited = {start}
    i = 0
    while len(R) < N:
        if i < len(R):
            # Same connected component; append unvisited neighbours of R[i]
            adjacents = sorted(
                (a for a in adjacency_set(A, R[i]) if a not in visited),
                key=lambda x: degrees[x], reverse=reverse)
            R += adjacents
            visited.update(adjacents)
            i += 1
        else:
            # Disconnected graph; restart from the unvisited node with the
            # lowest degree (lowest index on ties). Searching only the
            # unvisited nodes avoids a sentinel that large weighted degrees
            # could exceed, which would re-select a visited node.
            unvisited = (v for v in range(N) if v not in visited)
            restart = min(unvisited, key=lambda v: degrees[v])
            R.append(restart)
            visited.add(restart)
    return R

def modified_cuthill_mckee(A, reverse=False):
    '''
    Cuthill-McKee ordering that also takes edge weights into account.

    Works like a breadth-first Cuthill-McKee traversal, except that the
    unvisited neighbours of the current node are ordered primarily by their
    edge value A[current, neighbour] (largest first) and only secondarily
    by node degree.

    A       -- square matrix indexed as A[:, var]; positive entries are
               edges and column sums are node degrees (see compute_degree
               and adjacency_set).
    reverse -- if False, start from the lowest-degree node and break
               edge-value ties by ascending degree; if True, start from the
               highest-degree node and break ties by descending degree.

    Returns a list of node indices (plain ints) containing every node once.
    An empty matrix yields an empty list.
    '''
    N = len(A)
    if N == 0:
        return []
    degrees = [compute_degree(A, v) for v in range(N)]
    start = int(np.argmax(degrees) if reverse else np.argmin(degrees))
    R = [start]
    # Mirror of R for O(1) membership tests
    visited = {start}
    i = 0
    while len(R) < N:
        if i < len(R):
            current = R[i]
            # Same connected component; add adjacent nodes
            frontier = [a for a in adjacency_set(A, current) if a not in visited]
            # Sort by node degree
            frontier = sorted(frontier, key=lambda x: degrees[x], reverse=reverse)
            # Sort by node value; the sort is stable, so the degree order
            # above survives among neighbours with equal edge value
            frontier = sorted(frontier, key=lambda x: A[current, x], reverse=True)
            R += frontier
            visited.update(frontier)
            i += 1
        else:
            # Disconnected graph; pick new starting node with lowest degree
            # (lowest index on ties). Searching only the unvisited nodes
            # avoids a sentinel that large weighted degrees could exceed,
            # which would re-select a visited node.
            unvisited = (v for v in range(N) if v not in visited)
            restart = min(unvisited, key=lambda v: degrees[v])
            R.append(restart)
            visited.add(restart)
    return R

def test_block_diag_easy():
    '''
    Demo: scramble a block-diagonal-with-overlap matrix, then try to
    recover the structure with modified_cuthill_mckee.

    Shows three figures: the original matrix, the shuffled matrix and the
    re-ordered matrix.
    '''
    # Create block diag matrix with overlap: five 3x3 blocks OR-ed with
    # three 5x5 blocks, both laid out on the same 15x15 grid
    s1, s2 = 3, 5
    blockS1 = np.kron(np.eye(s2), np.ones((s1,s1)))
    blockS2 = np.kron(np.eye(s1), np.ones((s2,s2)))
    A = np.logical_or(blockS1, blockS2).astype(float)
    imshow(A, title='A block-diagonal-with-overlap (BDO) matrix')

    #%%
    mi = shuffle_vars(A, seed=0)
    # A symmetric permutation only moves entries around, so the total is unchanged
    assert np.sum(mi) == np.sum(A)
    imshow(mi, title='Pairwise Mutual Information')

    #%%
    # Use Cuthill-McKee to permute back to BDO matrix
    # order = [0, 6, 12, 5, 11, 13, 1, 2, 9, 3, 8, 10, 4, 7, 14]
    order = modified_cuthill_mckee(mi)
    fixed = arrange_vars(mi, order)
    imshow(fixed, ticks=order, title='Permutation to nearly block-diagonal')

def test_random_symmetric():
    '''
    Demo: compare both orderings on a random symmetric 0/1 matrix.

    Shows three figures: the random matrix, its cuthill_mckee ordering and
    its modified_cuthill_mckee ordering.
    '''
    # Fixed seed so the random sparse matrix is reproducible
    np.random.seed(0)
    n_vars = 15
    density = 0.3
    raw = scipy.sparse.random(n_vars, n_vars, density=density, dtype=float).todense()
    # Symmetrize, add the diagonal, then threshold to a 0/1 matrix
    symmetric = (raw + raw.transpose())/2 + np.eye(n_vars)
    A = (symmetric > 0.25).astype(bool).astype(float)
    imshow(A, title='Pairwise Mutual Information (#2)')

    #%%
    # Plain Cuthill-McKee, intended to reduce the matrix bandwidth
    cm_order = cuthill_mckee(A)
    imshow(arrange_vars(A, cm_order), title='Cuthill-McKee Ordering', ticks=cm_order)

    #%%
    # Same matrix through the edge-value-aware variant
    mcm_order = modified_cuthill_mckee(A)
    imshow(arrange_vars(A, mcm_order), title='Modified Cuthill-McKee Ordering', ticks=mcm_order)

def test_disconnected():
    '''
    Demo: run cuthill_mckee on a shuffled matrix made of disconnected blocks.

    Shows two figures: the shuffled matrix and the re-ordered matrix.
    '''
    # All-ones blocks of size 2x2 up to 6x6 placed along the diagonal
    ones_blocks = [np.ones((size, size)) for size in range(2, 7)]
    D = scipy.linalg.block_diag(*ones_blocks)
    # Optional experiment: link consecutive nodes so the graph is connected
    # for i in range(len(D)-1):
    #     D[i,i+1] = D[i+1,i] = 1
    scrambled = shuffle_vars(D, seed=0)
    imshow(scrambled)

    recovered_order = cuthill_mckee(scrambled, reverse=False)
    imshow(arrange_vars(scrambled, recovered_order), ticks=recovered_order, title='Cuthill-McKee')

def test_modified_CM():
    '''
    Demo: compare both orderings on a random symmetric weighted matrix.

    Shows three figures: the weighted matrix, its cuthill_mckee ordering
    and its modified_cuthill_mckee ordering.
    '''
    # Fixed seed so the random sparse matrix is reproducible
    np.random.seed(9)
    density = .36/1.7
    n_vars = 15
    raw = 2*scipy.sparse.random(n_vars, n_vars, density=density, dtype=float).todense()
    # Symmetrize, add the diagonal and cap every entry at 1; the weights are
    # kept (not binarized) so the modified variant has edge values to use
    A = np.minimum(1, (raw + raw.transpose())/2 + np.eye(n_vars))
    imshow(A, title='Pairwise Mutual Information')

    plain_order = cuthill_mckee(A, reverse=False)
    imshow(arrange_vars(A, plain_order), ticks=plain_order, title='Cuthill-McKee')

    weighted_order = modified_cuthill_mckee(A, reverse=False)
    imshow(arrange_vars(A, weighted_order), ticks=weighted_order, title='Modified Cuthill-McKee')

def main():
    '''Run every demo in order; each one displays its figures via imshow.'''
    demos = (
        test_block_diag_easy,
        test_random_symmetric,
        test_disconnected,
        test_modified_CM,
    )
    for demo in demos:
        demo()

# Run the demos only when this file is executed as a script, not on import
if __name__ == '__main__':
    main()
