import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from scipy.spatial.distance import cdist
import torch
# Label string; not referenced by any code visible in this file.
# NOTE(review): presumably a prefix for log/print output — confirm against callers.
rfp = "random_filter_params:"

# Module-level debug switch; no visible code in this file reads it.
DEBUG = False


def calc_sparsity(A):
    """Return the fraction of possible edges present in an undirected graph.

    "Sparsity" here is the edge density: (number of edges) divided by
    N*(N-1)/2, so 0 means no edges and 1 means a complete graph.

    Args:
        A: square 2D numpy adjacency matrix. Must be symmetric with a zero
            diagonal (undirected, no self loops). Any nonzero entry counts
            as an edge, so weights are ignored.

    Returns:
        float in [0, 1]. A single-vertex graph has no possible edges and is
        reported as 0.0.

    Raises:
        AssertionError: if A is not a numpy array, is not a square 2D matrix,
            has a nonzero diagonal entry, or is not symmetric.
    """
    assert isinstance(A, np.ndarray), 'calc_sparsity: input must be np array'
    assert A.ndim == 2, f'calc_sparsity: A must be a 2D array, is {A.ndim}D array'
    num_rows, num_cols = A.shape
    assert num_rows == num_cols and num_rows >= 1, f'calc_sparsity: must be square matrix, is {A.shape}'
    N = num_rows
    assert np.count_nonzero(np.diagonal(A)) == 0, 'calc_sparsity: A must have no self loops (zero diagonal)'
    assert np.allclose(A, A.T), 'calc_sparsity: A must be undirected (symmetric), is not'

    # A lone vertex has zero possible edges; avoid dividing 0 by 0.
    if N == 1:
        return 0.0

    # A is symmetric with zero diagonal, so every edge is counted exactly twice.
    num_edges = np.count_nonzero(A) / 2
    total_possible_edges = N * (N - 1) / 2

    return num_edges / total_possible_edges


def all_pair_shortest_path_lengths(G,N):
    """Return an (N, N) numpy array of shortest-path lengths between nodes 0..N-1.

    Entry [i, j] is the length reported by nx.shortest_path_length from node
    i to node j. Nodes are looked up by the integer labels 0..N-1; a missing
    label or an unreachable pair surfaces as a KeyError from the lookup.
    """
    # No source/target given: networkx yields (source, {target: length}) pairs.
    lengths_by_source = dict(nx.shortest_path_length(G))
    dist_matrix = np.zeros((N, N))
    for src in range(N):
        lengths_from_src = lengths_by_source[src]
        for dst in range(N):
            dist_matrix[src, dst] = lengths_from_src[dst]
    return dist_matrix


def sbm_constructor(num_vertices, num_communities, p_in, p_out):
    """Build community sizes and the edge-probability matrix of a stochastic block model.

    Vertices are split into `num_communities` contiguous communities whose
    sizes differ by at most one; when the split is uneven, the earliest
    communities receive the extra vertices.

    Args:
        num_vertices: total number of vertices.
        num_communities: number of communities (>= 1).
        p_in: edge probability between two vertices of the same community.
        p_out: edge probability between vertices of different communities.
            Must satisfy 0 <= p_out <= p_in <= 1.

    Returns:
        (sizes, prob_matrix): `sizes` is a list of community sizes summing to
        `num_vertices`; `prob_matrix` is a (num_vertices, num_vertices) torch
        tensor holding p_in inside the diagonal blocks and p_out elsewhere.

    Raises:
        AssertionError: on invalid probabilities or num_communities < 1.
    """
    assert 0 <= p_out <= p_in <= 1, f'invalid p_in, p_out = {p_in}, {p_out}'
    assert num_communities >= 1, f'num_communities must be >= 1, is {num_communities}'

    # Integer split; the first `leftover` communities take one extra vertex each.
    base_size, leftover = divmod(int(num_vertices), int(num_communities))
    sizes = [base_size + 1 if k < leftover else base_size for k in range(int(num_communities))]

    assert sum(sizes) == num_vertices
    assert max(sizes) - min(sizes) <= 1, f'difference bewtween community sizes can be at most one, {sizes}'

    # matrix of edge connection probs
    blocks = [torch.ones(size, size) * p_in for size in sizes]
    prob_matrix = torch.block_diag(*blocks)
    # Off-block entries are exactly 0 after block_diag. Filling every zero is
    # safe even when p_in == 0, because p_out <= p_in then forces p_out == 0.
    prob_matrix[prob_matrix == 0] = p_out
    return sizes, prob_matrix

def connected_sparse_gen(num_vertices, r, dim=2, sparse_thresh_low=0.0, sparse_thresh_high=1.0, graph_gen='geom', max_attempts=20):
    """Sample random graphs until one is connected and within a sparsity band.

    Args:
        num_vertices: number of vertices; must be > 2.
        r: generator parameter, interpreted according to `graph_gen`:
            'ER'          -> passed as p to nx.fast_gnp_random_graph
            'geom'        -> passed as radius to nx.random_geometric_graph
            'pref_attach' -> passed as m to nx.barabasi_albert_graph
            'sbm'         -> dict with keys 'num_communities', 'p_in', 'p_out'
                             (a 3-item (num_communities, p_in, p_out) sequence
                             is also accepted)
        dim: dimension used by the 'geom' generator only.
        sparse_thresh_low, sparse_thresh_high: inclusive bounds on the
            fraction of possible edges, edges / (N*(N-1)/2).
        graph_gen: one of 'ER', 'geom', 'pref_attach', 'sbm'.
        max_attempts: maximum number of graphs to sample.

    Returns:
        (G, attempts, sparsity): the accepted networkx graph, the number of
        graphs sampled (including the accepted one), and its sparsity.

    Raises:
        ValueError: if `graph_gen` is unknown, if `r` is malformed for 'sbm',
            if a 'pref_attach' sample is rejected, or if no acceptable graph
            is found within `max_attempts` samples.
        AssertionError: if num_vertices <= 2.
    """
    assert num_vertices > 2, f'number of nodes {num_vertices} must be >2 for sensical graph'
    if graph_gen not in ('ER', 'geom', 'pref_attach', 'sbm'):
        # Fail fast instead of blocking on input() and crashing later.
        raise ValueError(f'connected_sparse_gen: No valid graph generator given ({graph_gen}).')

    sizes, probs = None, None
    if graph_gen == 'sbm':
        try:
            if isinstance(r, dict):
                num_communities, p_in, p_out = r['num_communities'], r['p_in'], r['p_out']
            else:
                num_communities, p_in, p_out = r
        except (KeyError, TypeError, ValueError) as err:
            raise ValueError(
                "connected_sparse_gen: for graph_gen='sbm', r must be a dict with keys "
                f"'num_communities', 'p_in', 'p_out' (or a 3-item sequence), got {r!r}") from err
        num_communities = int(num_communities)
        # Block sizes and block probabilities do not change between attempts,
        # so build them once instead of on every retry.
        sizes, _ = sbm_constructor(num_vertices, num_communities, p_in, p_out)
        probs = [[p_in if i == j else p_out for j in range(num_communities)]
                 for i in range(num_communities)]

    poss_edges = (num_vertices * (num_vertices - 1)) / 2
    attempts = 1
    while True:
        if graph_gen == 'ER':
            G = nx.fast_gnp_random_graph(num_vertices, r)
        elif graph_gen == 'geom':
            G = nx.random_geometric_graph(num_vertices, r, dim=dim)
        elif graph_gen == 'pref_attach':
            # num_vertices = 68, m = 28 -> sparsity of ~1/2
            G = nx.barabasi_albert_graph(n=num_vertices, m=r)
        else:  # 'sbm'
            G = nx.stochastic_block_model(sizes=sizes, p=probs,
                                          nodelist=range(sum(sizes)),  # This should ensure consistent node labeling
                                          directed=False, selfloops=False)

        connected = nx.is_connected(G)
        sparsity = G.number_of_edges() / poss_edges
        sparse = sparse_thresh_low <= sparsity <= sparse_thresh_high
        if connected and sparse:
            return G, attempts, sparsity
        if graph_gen == 'pref_attach':
            # A rejected preferential-attachment sample is treated as final:
            # no retry is attempted for this generator.
            raise ValueError('Preferential Attachment model was not able to meet sparsity requirements. Either m is too low, or it is not possible.')
        if attempts >= max_attempts:
            print(f'\tfailed after {attempts} attempts, with connected? ({connected}) & sparsity of {sparsity}')
            raise ValueError(f'Sampled {attempts} graphs without connectivity/sparsity constraints satisfied. Adjust parameters.')
        attempts += 1


def search_appropriate_sampling_params(graph_gen, N, param_vals, sparsity_range = (0.5, 0.6), rand_seed=50):
    """Report, for each candidate generator parameter, how easily graphs can be sampled.

    For every value in `param_vals`, connected_sparse_gen is called 10 times
    (each call capped at 10 attempts) with the sparsity band `sparsity_range`.
    On success a summary line is printed with the average number of attempts
    per call and the mean/std of the achieved sparsities. If any call raises,
    the exception and a "NOT work" line are printed for that value instead.

    `rand_seed` is accepted but not used. Nothing is returned.
    """
    rounds_per_value = 10
    attempt_cap = 10

    for candidate in param_vals:
        try:
            achieved = []
            tries_per_round = []
            for _ in range(rounds_per_value):
                _graph, tries, density = connected_sparse_gen(
                    num_vertices=N,
                    r=candidate,
                    sparse_thresh_low=sparsity_range[0],
                    sparse_thresh_high=sparsity_range[1],
                    graph_gen=graph_gen,
                    max_attempts=attempt_cap,
                )
                achieved.append(density)
                tries_per_round.append(tries)

            ave_num_tries_needed = sum(tries_per_round) / rounds_per_value
            print(f'{candidate} works: num samples for success: {ave_num_tries_needed}, ave_sparsity: {np.mean(achieved):.3f}, std_sparsity: {np.std(achieved):.3f}')
        except Exception as err:
            print(err)
            print(f'{candidate} NOT work: success rate 0%')


def _rescale_edge_weights(W, N):
    """Scale W so its maximum becomes 8, then add 1 to every off-diagonal entry.

    When W has a maximum of exactly 0 (no edge with positive length), the
    scaling step is skipped so the result is finite instead of NaN.
    """
    peak = np.max(W)
    scaled = np.divide(W, peak) * 8 if peak != 0 else np.zeros((N, N))
    return scaled + (np.ones((N, N)) - np.eye(N))


def add_edge_weights(G,A,N, dim=2, applyTransform=False):
    """Weight the edges of a graph by the Euclidean distance between node positions.

    Positions are read from the 'pos' node attribute of G; nodes without it
    stay at the origin. Distances on existing edges are rescaled so that the
    longest edge gets weight 9 and the others fall in [1, 9]; non-edges stay 0.

    Args:
        G: networkx graph whose nodes carry a 'pos' attribute.
            NOTE(review): node labels are used directly as row indices, so
            they are assumed to be the integers 0..N-1 — confirm with callers.
        A: (N, N) numpy adjacency matrix of G; its zero pattern defines the
            support of the result.
        N: number of nodes.
        dim: dimension of the position vectors.
        applyTransform: if True, additionally damp the weights with
            w * exp(-(w - c)^2 / stdv) (c=0.76, stdv=0.1), rescale again,
            print the intermediate matrices, show a histogram and wait for
            the user to press Enter. Interactive debugging aid.

    Returns:
        (N, N) numpy array of edge weights with the same support as A.
    """
    pos = nx.get_node_attributes(G, 'pos')
    positions = np.zeros((N, dim))
    for node, coords in pos.items():
        positions[node] = coords

    # euclidean distance matrix: dist[i,j] = euclid distance btwn xi and xj
    dist = cdist(positions, positions)

    # restrict to the support of A so the max is taken over graph edges only
    A_w = np.multiply(dist, A)
    # shift and scale, then recover the support of the graph (the shift made
    # every off-diagonal entry nonzero)
    A_w = np.multiply(A, _rescale_edge_weights(A_w, N))

    if applyTransform:
        c, stdv = 0.76, 0.1
        # use gaussian kernel to smooth edge weights
        print(f'A:\n{A}')
        print(f'A_w:\n{A_w}')
        A_w_smooth = np.multiply(A_w, np.exp(-1 * np.square(A_w - c) / stdv))
        # put on same scale as brain data
        print(f'A_w_smooth:\n{A_w_smooth}')
        A_w = _rescale_edge_weights(A_w_smooth, N)
        print(f'A_w_smooth scaline and shift:\n{A_w}')
        # recover original support (0 edge weights -> nonzero by transformation)
        A_w = np.multiply((A > 0) + 0, A_w)
        print(f'A_w recover support:\n{A_w}\n\n')
        iu1 = np.triu_indices(N, k=1)
        sorted_distances = np.asarray(np.sort(A_w[iu1]))
        plt.hist(sorted_distances.flatten(), alpha=0.5)
        plt.show()
        input('...')

    # TODO: choose best transform. An earlier (now removed, unreachable)
    # exploration plotted w * exp(-(w - c)^2 / stdv) over a grid of centers in
    # [0.15, 0.8] and stdvs in {0.001, 0.01, 0.1, 0.5, 1}; its note recorded
    # c = stdv = 0.28 as doing the best job of smoothing.
    return A_w


def create_GSO(G, N=None):
    """Return the adjacency matrix and combinatorial Laplacian of a graph.

    Args:
        G: networkx graph.
        N: optional expected number of vertices. Accepted so that calls of the
            form create_GSO(G, N) work; when given, it is checked against the
            size of the adjacency matrix.

    Returns:
        (A, L): numpy arrays, where A is the adjacency matrix of G and
        L = D - A with D the diagonal matrix of column sums of A.

    Raises:
        ValueError: if N is given and does not match the number of vertices.
    """
    # to_numpy_array replaces to_numpy_matrix, which networkx 3.0 removed.
    A = np.asarray(nx.to_numpy_array(G))
    if N is not None and A.shape[0] != N:
        raise ValueError(f'create_GSO: expected {N} vertices, graph has {A.shape[0]}')
    D = np.diag(np.sum(A, axis=0))
    L = D - A

    return A, L


# Inputs: G :: networkx graph representation, signal :: [Float] graph signal
# Output: None (shows a matplotlib figure)
# networkx drawing api:
# https://networkx.github.io/documentation/stable/reference/drawing.html
def plot_graph(G, signal):
    """Draw G side by side under four networkx layouts, colouring nodes by `signal`."""
    layout_fns = (
        ("circular_layout", nx.circular_layout),
        ("shell_layout", nx.shell_layout),
        ("spring_layout", nx.spring_layout),
        ("spectral_layout", nx.spectral_layout),
    )
    # Compute every layout up front, in this order, before any drawing happens.
    placements = [layout(G) for _, layout in layout_fns]
    num_panels = len(placements)

    # subplot indices start at 1, not 0
    for panel, ((title, _), node_xy) in enumerate(zip(layout_fns, placements), start=1):
        axes = plt.subplot(1, num_panels, panel)
        axes.title.set_text(title)
        nx.draw(G, node_xy, with_labels=True, font_weight='bold', node_color=signal)

    plt.show()


def calc_sparsity_TESTS():
    """Sanity-check calc_sparsity on 3- and 4-node graphs built one edge at a time.

    Starting from an empty graph, edges are added in lexicographic order
    ((0,1), (0,2), ...) and after each addition the sparsity must equal
    (edges so far) / (total possible edges).
    """
    empty = np.zeros((3, 3))
    assert np.allclose(calc_sparsity(empty),0), f'matrix with all zeros should have zero sparsity'

    for num_nodes in (3, 4):
        adjacency = np.zeros((num_nodes, num_nodes))
        max_edges = num_nodes * (num_nodes - 1) // 2
        edges_added = 0
        for u in range(num_nodes):
            for v in range(u + 1, num_nodes):
                adjacency[u, v] = adjacency[v, u] = 1
                edges_added += 1
                assert np.allclose(calc_sparsity(adjacency), edges_added / max_edges)


if __name__ == "__main__":
    # Scratch driver for parameter-search experiments. Only the first
    # experiment runs: each exit(2) ends the script, so everything after the
    # first one is kept for reference and is reached only if the exits above
    # it are removed.

    np.random.seed(50)
    # Note: Pref-attachment sparsities will always be the same for given N/M!
    param_vals = [10, 15, 20, 25]
    search_appropriate_sampling_params(graph_gen='pref_attach', N=68, param_vals=param_vals, sparsity_range=(0, 1))

    exit(2)

    # connected_sparse_gen expects the sbm parameters as a dict.
    num_communities = 3
    p_in, p_out = 0.8, 0.2
    param_vals = [{'num_communities': num_communities, 'p_in': 0.8, 'p_out': 0.2},
                  {'num_communities': num_communities, 'p_in': 0.8, 'p_out': 0.3}]
    search_appropriate_sampling_params(graph_gen='sbm', N=68, param_vals=param_vals, sparsity_range=(0,1))
    # experiments to find the appropriate sampling parameter r for sampling geometric graphs of a particular size and sparsity.
    #
    # # N=68: .56--.58 -? 89% success rate
    # search_appropriate_sampling_params(graph_gen='geom', N=68, param_vals=(.51, 0.53, 0.54, 0.57, .58), sparsity_range=(.5, 0.6))
    #
    # # N=500: .53--.56 -? 90% success rate
    # search_appropriate_sampling_params(graph_gen='geom', N=500, param_vals=(.51, 0.53, 0.54, 0.57, .58), sparsity_range=(.5, 0.6))
    #
    # # N=1000 .54-.57 -> 90% success rate
    # search_appropriate_sampling_params(graph_gen='geom', N=1000, param_vals=(.51, 0.53, 0.54, 0.57, .58), sparsity_range=(.5, 0.6))
    #
    # # N=10000 .53-.57 -> 90% success rate
    # search_appropriate_sampling_params(graph_gen='geom', N=1000, param_vals=(.51, 0.53, 0.54, 0.57, .58), sparsity_range=(.5, 0.6))

    # find num_communities, p_in, p_out such that we can effectively sample SBM graphs with specified sparsity.
    graph_gen = 'sbm'
    N = 68
    num_communities = 3
    for num_communities in np.arange(2, 6, 2):
        print('\tnum_communities:', num_communities)
        for p_in in np.arange(.8, .95, .05):
            for p_out in np.arange(.2, .5, .05):
                try:
                    sbm_params = {'num_communities': int(num_communities),
                                  'p_in': float(p_in),
                                  'p_out': float(p_out)}
                    G, attempts, sparsity = connected_sparse_gen(num_vertices=N,
                                                       r=sbm_params,
                                                       sparse_thresh_low=0,
                                                       sparse_thresh_high=1,
                                                       max_attempts=20,
                                                       graph_gen=graph_gen
                                                       )  # 'pref_attach')
                    A, L = create_GSO(G)
                    sparsity = np.sum(A)/(N*(N-1))
                    # calc_sparsity only accepts numpy arrays.
                    sparsity_ = calc_sparsity(A)
                    print(f'num_communities/p_in/p_out = {num_communities}/{p_in:.3f}/{p_out:.3f} -> sparsity {sparsity:.3f} works!')
                except Exception:
                    print(f'num_communities/p_in/p_out = {num_communities}/{p_in:.3f}/{p_out:.3f} does NOT work! See above.')

    exit(2)

    # find m such that we can effectively sample Barb_Albert graphs with specified sparsity.
    # (graph_gen is currently 'geom', so m is used as the geometric radius.)
    m = 25
    graph_gen = 'geom' # 'pref_attach'
    for m in np.arange(.4, .6, .025):
        try:
            for i in range(10):
                G, attempts, sparsity = connected_sparse_gen(num_vertices=68, r=m, sparse_thresh_low=.50, sparse_thresh_high=.6,
                                                   graph_gen=graph_gen)
                A, L = create_GSO(G)
            print(f'm of {m} works!')
        except Exception:
            print(f'm of {m} not enough')

    exit(2)

    calc_sparsity_TESTS()



    #                   [dataloader]
    #                 //      ||        \\
    #               //        ||         \\
    #             //          ||          \\
    # [ER generator]   [Construct Filter]  [gen_white_signals]
    #                         ||
    #                [rand_filter_params]

    # N = # vertices, M = # edges

    #[random filter parameters]
    # Module Responsibility: produce a random order polynomial and random coefficients
    # Assumptions:
    #  Inputs:
    #       # N is an integer >=2
    #       # sigma is a float >0
    #  -Correct distribution over order and coefficients
    #     Testing: rely upon numpy to be correct





    #[ER generator]
    # Module Responsibility: produce ER random graphs
    # Assumptions:
    #  Inputs:
    #       # N is an integer >=2
    #       # p is a float in [0,1]
    #  -correct sampling technique for er

    # Using the nx implementation. Assuming this is correct


    if False:
        ## test new geom_connected_sparse function
        N , r, dim, sparse_thresh = 30, .25, 2, 0.15
        for j in range(100):
            G, attempts, sparisty = connected_sparse_gen(N, r=0.17, dim=dim, \
                    sparse_thresh_high=sparse_thresh,graph_gen='ER')
            G, attempts, sparsity = connected_sparse_gen(N, r=0.25, dim=dim, \
                    sparse_thresh_high=sparse_thresh,graph_gen='geom')
            connected = nx.is_connected(G)
            poss_edges = ((N**2)/2)
            sparse    = ((G.number_of_edges()/poss_edges) < sparse_thresh)
            assert connected and sparse, \
                    f'geom_connected_sparse: not returning connected and sparse graphs'


    #create weighted graphs from the geom graphs
    N, r, dim = 68, .56, 2
    for i in range(5):
        G = nx.random_geometric_graph(N, r, dim=dim)
        A, L = create_GSO(G)
        A_w = add_edge_weights(G,A,N)#,applyTransform=True)
        print(f'Adj Matrix: \n{A}')
        print(f'Weighted Adj Matrix: \n{A_w}')
        input('inspect')

    #Diffuse signals with weighted adj matrix