from enum import unique
from sys import _xoptions
import numpy as np
import networkx as nx
import copy
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
from sklearn.cluster import KMeans, AgglomerativeClustering
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import pairwise_distances
import scipy as sp
import itertools
import seaborn as sns
import matplotlib.pyplot as plt
# import kmeans1d
import time
import os
import cvxpy
from karateclub import Role2Vec, Node2Vec

class HashFunction:
    """Relabel hashable values with consecutive integer ids, starting at 2.

    The first distinct value passed to ``apply`` receives id 2, the next new
    value id 3, and so on; a value seen before always gets back the id it was
    first assigned.  ``reset`` forgets every assignment.
    """

    def __init__(self):
        self.reset()

    def apply(self, value):
        """Return the id of ``value``, allocating the next free id if unseen."""
        try:
            return self.hash_dict[value]
        except KeyError:
            new_id = self.hash_counter
            self.hash_dict[value] = new_id
            self.hash_counter = new_id + 1
            return new_id

    def reset(self):
        """Forget all stored ids and restart numbering at 2."""
        self.hash_dict = {}
        self.hash_counter = 2


# Module-wide colour-id allocator. weisfeiler_lehman uses it as its default
# ``hash`` argument, so colour ids are shared (and keep growing) across calls
# until wl_hash.reset() is called.
wl_hash = HashFunction()

def make_symmetric(A):
    """Copy the strict upper triangle of ``A`` onto its lower triangle, in place.

    The diagonal and upper triangle are left untouched.  The same (mutated)
    array object is returned for convenience.
    """
    n_rows, n_cols = A.shape[0], A.shape[1]
    for row in range(n_rows):
        for col in range(row + 1, n_cols):
            A[col, row] = A[row, col]
    return A

def is_equiv_subroutine(c1, c2, return_cmap=False):
    """Check that position-wise labels ``c1 -> c2`` form a function.

    Returns True when every label of ``c1`` is always paired with the same label
    of ``c2``; stops at the first conflict.  With ``return_cmap`` the mapping
    built so far (up to, excluding, the conflicting pair) is returned too, as
    ``(bool, dict)``.
    """
    color_map = {}
    consistent = True
    for position, source in enumerate(c1):
        target = c2[position]
        if source in color_map and color_map[source] != target:
            consistent = False
            break
        color_map.setdefault(source, target)
    if return_cmap:
        return consistent, color_map
    return consistent

def is_equivalent(c1, c2):
    """Return True when ``c1`` and ``c2`` induce the same partition.

    That is, the label mapping is a function in both directions
    (checked with ``is_equiv_subroutine``; the reverse check is skipped when
    the forward one already fails).
    """
    if not is_equiv_subroutine(c1, c2):
        return False
    return is_equiv_subroutine(c2, c1)

def weisfeiler_lehman(graph1: nx.Graph, iterations=-1, early_stopping=True, hash=wl_hash):
    """Colour-refine ``graph1`` with 1-dimensional Weisfeiler-Lehman.

    Every node starts with colour 1.  In each round a node's new colour is
    ``hash.apply((old colour, sorted tuple of neighbour colours))``.  Nodes are
    addressed as ``0 .. len(graph1) - 1`` (assumes integer node labels in that
    range -- TODO confirm for callers).

    Args:
        graph1: graph where ``graph1[node]`` iterates the neighbours of ``node``.
        iterations: number of rounds; ``-1`` means ``len(graph1)``.
        early_stopping: stop as soon as a round no longer changes the partition.
        hash: object with an ``apply`` method mapping keys to colour ids;
            defaults to the module-wide ``wl_hash`` so ids are shared between calls.

    Returns:
        ``(colors, rounds, set_colors_by_iteration, colors_by_iteration)``.
        On an early stop, ``colors`` is the colouring before the final
        (non-refining) round and ``rounds`` is that round's 0-based index;
        otherwise ``colors`` is the final colouring and ``rounds == iterations``.
        The two lists hold, per round, the set of colours and a copy of the
        colour array at the start of the round (plus the final colouring when
        no early stop happened).
    """
    n_nodes = len(graph1)
    if iterations == -1:
        iterations = n_nodes

    colors = np.ones(n_nodes, dtype=int)
    color_sets = []
    color_history = []

    for step in range(iterations):
        previous = np.copy(colors)
        color_history.append(copy.deepcopy(colors))
        color_sets.append(set(colors))
        for node in range(n_nodes):
            neighbour_colors = tuple(sorted(previous[nb] for nb in graph1[node]))
            colors[node] = hash.apply((colors[node], neighbour_colors))
        if early_stopping and is_equivalent(colors, previous):
            return previous, step, color_sets, color_history

    color_history.append(copy.deepcopy(colors))
    color_sets.append(set(colors))
    return colors, iterations, color_sets, color_history

def indicator_matrix_from_colors(colors, normalized=False, padding=0):
    """Build the node-by-colour indicator matrix of a labelling.

    Column ``c`` corresponds to the ``c``-th distinct label in sorted order
    (``np.unique``).  Entry ``[v, c]`` is 1 when node ``v`` has that label
    (integer matrix), or ``1 / sqrt(size of the class)`` when ``normalized``
    (float matrix with orthonormal columns).

    Args:
        colors: one label per node (list or 1-D array; an (n, 1) column of
            labels is accepted as well).
        normalized: scale each column to unit Euclidean norm.
        padding: minimum number of columns; missing columns are appended as
            zeros (the result then becomes float).

    Returns:
        Array of shape (n, max(#labels, padding)).

    Raises:
        ValueError: if ``colors`` is not a flat sequence of labels.
    """
    labels = np.asarray(colors)
    if labels.ndim == 2 and labels.shape[1] == 1:
        # A column vector of labels behaves like a flat label list.
        labels = labels[:, 0]
    if labels.ndim != 1:
        raise ValueError("colors must be a 1-D sequence of labels")

    classes, counts = np.unique(labels, return_counts=True)
    # membership[v, c] is True when node v carries the c-th distinct label.
    membership = labels[:, np.newaxis] == classes[np.newaxis, :]
    if normalized:
        H = membership / np.sqrt(counts)[np.newaxis, :]
    else:
        H = membership.astype(int)

    if H.shape[1] >= padding:
        return H
    return np.concatenate([H, np.zeros((H.shape[0], padding - H.shape[1]))], axis=1)

    

def kmedians1d(x, k, verbose=0):
    """Split 1-D values into ``k`` contiguous runs of the sorted data (DP).

    NOTE(review): this definition is dead code -- a second ``def kmedians1d``
    further down this module rebinds the name at import time, so callers
    never reach this version. Consider deleting it.

    The cost of a run ``x[i..j]`` is ``sum |x_l - mean(x[i..j])|`` -- the
    deviation is taken from the run *mean*, not the median.

    Args:
        x: 1-D numpy array (it is indexed with ``x[sort_idx]``).
        k: number of groups.
        verbose: when >= 1, print the DP tables and the cut points.

    Returns:
        Numpy array of labels ``0..k-1`` in the original order of ``x``.
    """
    sort_idx = np.argsort(x)
    # Inverse permutation: maps sorted positions back to the input order.
    reverse = np.argsort(sort_idx)
    x = x[sort_idx]
    # CC[i, j] (j >= i): cost of placing the sorted run x[i..j] in one group
    # (built with O(n^3) Python-level work).
    CC = np.zeros((len(x),len(x)))
    for i in range(CC.shape[0]):
        for j in range(i,CC.shape[1]):
            mu = np.mean([x[l] for l in range(i,j+1)])
            CC[i,j] = np.sum([np.abs(x[l]-mu) for l in range(i,j+1)])
    # D[g, m]: best cost of splitting x[0..m] into g+1 groups.
    # T[g, m]: start index of the last group in that best split.
    D = np.zeros_like(CC)
    T = np.zeros_like(CC, dtype=int)
    for i in range(CC.shape[1]):
        D[0][i] = CC[0, i]
        T[i][i] = i

    for i in range(1, k):
        for m in range(i+1,CC.shape[1]):
            # Last group is x[j..m]; earlier groups cover x[0..j-1].
            idx = np.argmin([D[i-1][j-1] + CC[j,m] for j in range(i,m+1)]) + i
            #print(idx, D[i-1, idx-1], CC[idx, m])
            D[i][m] = D[i-1][idx-1] + CC[idx,m]
            T[i][m] = int(idx) 

    if verbose >= 1:
        print(D, T)
    # Backtrack through T to recover the k-1 cut positions.
    col = []
    k = k-1
    old_idx = len(x)-1
    cuts = [len(x)]
    for i in range(k):
        idx = T[k-i][old_idx]
        cuts.append(idx)
        old_idx = idx -1
    cuts.append(0)
    cuts.reverse()
    if verbose >= 1:
        print(cuts)
    # Label each run of the sorted data, then undo the sort.
    for i in range(len(cuts)-1):
        col.extend([i]*(cuts[i+1]-cuts[i]))
    return np.array(col)[reverse]

def initialize_centers(x, num_k):
    """Pick ``num_k`` distinct rows of ``x`` at random as initial cluster centres.

    Args:
        x: (N, D) data array.
        num_k: number of centres to draw.

    Returns:
        (num_k, D) float array of centres.

    Raises:
        ValueError: if ``num_k`` exceeds the number of rows ``N`` -- the
            rejection sampling below would otherwise never terminate.
    """
    N, D = x.shape
    if num_k > N:
        raise ValueError(f"cannot choose {num_k} distinct centres from {N} points")
    centers = np.zeros((num_k, D))
    used_idx = set()
    for k in range(num_k):
        # Redraw until an unused row comes up (same RNG call pattern as before).
        idx = np.random.choice(N)
        while idx in used_idx:
            idx = np.random.choice(N)
        used_idx.add(idx)
        centers[k] = x[idx]
    return centers

def update_centers(x, r, K):
    """Recompute each centre as the responsibility-weighted mean of the data.

    Centre ``j`` is ``sum_n r[n, j] * x[n] / sum_n r[n, j]``; a column of
    ``r`` summing to zero yields non-finite values (division by zero).
    """
    _, dim = x.shape
    centers = np.zeros((K, dim))
    for j in range(K):
        weights = r[:, j]
        centers[j] = weights.dot(x) / weights.sum()
    return centers

def square_dist(a, b):
    """Element-wise squared difference ``(a - b) ** 2``."""
    difference = a - b
    return difference ** 2

def cost_func(x, r, centers, K):
    """Soft k-means objective ``sum_n sum_k r[n, k] * ||x[n] - centers[k]||_2``.

    Uses plain (non-squared) Euclidean distances, matching the distances used
    by ``cluster_responsibilities``.

    Args:
        x: (N, D) data.
        r: (N, K) responsibilities.
        centers: (K, D) centres.
        K: number of centres.
    """
    cost = 0
    for k in range(K):
        # Per-point distances to centre k.  The previous
        # ``np.linalg.norm(matrix, 2)`` returned a single spectral norm of the
        # whole (N, D) block instead.
        dists = np.linalg.norm(x - centers[k], axis=1)
        cost += (dists * r[:, k]).sum()
    return cost


def cluster_responsibilities(centers, x, beta):
    """Soft assignments ``R[n, k] ∝ exp(-beta * ||x[n] - centers[k]||_2)``.

    Each row of the returned (N, K) array sums to 1.  The exponent is shifted
    by its per-row maximum before exponentiation (softmax trick).  This does
    not change the normalised result mathematically, but it stops all weights
    of a distant point from underflowing to 0, which previously produced 0/0 = NaN.

    Args:
        centers: (K, D) cluster centres.
        x: (N, D) data points.
        beta: inverse temperature; larger values give harder assignments.
    """
    K = centers.shape[0]
    # dist[n, k]: Euclidean distance between point n and centre k.
    dist = np.linalg.norm(x[:, np.newaxis, :] - centers[np.newaxis, :, :], axis=2)
    logits = -beta * dist
    if K > 0:
        logits = logits - logits.max(axis=1, keepdims=True)
    R = np.exp(logits)
    R /= R.sum(axis=1, keepdims=True)
    return R

def soft_k_means(x, K, max_iters=20, beta=1.):
    """Run soft k-means and return the final (N, K) responsibility matrix.

    Alternates responsibilities (``cluster_responsibilities``) and centre
    updates (``update_centers``) for at most ``max_iters`` rounds.  It stops
    early once the objective changes by less than 1e-5 between rounds.
    """
    centers = initialize_centers(x, K)
    last_cost = 0
    for _ in range(max_iters):
        resp = cluster_responsibilities(centers, x, beta)
        centers = update_centers(x, resp, K)
        current_cost = cost_func(x, resp, centers, K)
        converged = np.abs(current_cost - last_cost) < 1e-5
        if converged:
            break
        last_cost = current_cost
    return resp

def directedSBM(block_size, omega):
    """Sample a directed SBM adjacency matrix from block probabilities ``omega``.

    Two undirected networkx SBMs are drawn.  The first uses ``omega``'s lower
    triangle mirrored to a symmetric matrix and supplies the strictly lower
    triangle of the result.  The second uses the mirrored upper triangle and
    supplies the upper triangle (including the diagonal).  As a result, entry
    ``(u, v)`` is drawn with probability ``omega[block(u), block(v)]``.

    ``np.random.seed()`` reseeds numpy's global RNG from fresh entropy and
    returns None, so networkx is called with ``seed=None``.
    """
    lower_probs = make_symmetric(omega.transpose().copy())
    A = nx.to_numpy_array(nx.stochastic_block_model(block_size, lower_probs, seed=np.random.seed()))
    upper_probs = make_symmetric(omega.copy())
    B = nx.to_numpy_array(nx.stochastic_block_model(block_size, upper_probs, seed=np.random.seed()))
    # Overwrite the upper triangle (diagonal included) of A with B's.
    rows, cols = np.triu_indices(B.shape[0], m=B.shape[1])
    A[rows, cols] = B[rows, cols]
    return A

def sample_planted_role_model(c,n=[40]*3, omega_1=[], return_omega=False, verbose=0, directed=True, p_out=0.05):
    """Sample a graph from a planted role model built on a stochastic block model.

    The ``len(n) x len(n)`` role matrix ``omega_1`` is placed ``c`` times on the
    block diagonal of the full probability matrix ``omega``; every other block
    pair connects with probability ``p_out``.  Block sizes are ``n`` repeated
    ``c`` times.

    Args:
        c: number of copies of the role structure.
        n: sizes of the roles inside one copy (not mutated).
        omega_1: role-to-role edge probabilities; drawn uniformly at random
            when it has fewer than 2 rows.
        return_omega: also return the full probability matrix.
        verbose: 1 prints ``omega``; >= 2 additionally plots degrees/heatmaps.
        directed: sample a directed graph via ``directedSBM``; otherwise sample
            an undirected SBM (``omega`` is symmetrised in place first).
        p_out: probability for block pairs outside the planted copies.

    Returns:
        Adjacency matrix ``A``, or ``(A, omega)`` when ``return_omega``.
    """
    if len(omega_1) < 2:
        omega_1 = np.random.rand(len(n),len(n))
    n_roles = len(n)
    omega = p_out * np.ones((c * n_roles, c * n_roles))
    for i in range(c):
        lower = i * n_roles
        upper = (i + 1) * n_roles
        omega[lower:upper, lower:upper] = omega_1
    if verbose >= 1:
        print(omega)
    if directed:
        A = directedSBM(n*c, omega)
    else:
        # Previously this sample was discarded because directedSBM always ran
        # afterwards.  make_symmetric mirrors omega's upper triangle in place,
        # so the returned omega matches what was sampled.
        A = nx.to_numpy_array(nx.stochastic_block_model(n*c, make_symmetric(omega), seed=np.random.seed()))
    if verbose >= 2:
        plt.plot(A @ np.ones(A.shape[1]))
        plt.show()
        plt.plot(sorted(A @ np.ones(A.shape[1])))
        plt.show()
        sns.heatmap(omega)
        plt.show()
        sns.heatmap(A)
        plt.show()
        print(omega_1)
        print(omega)
    if return_omega:
        return A, omega
    return A

def frac_WL(A, k, min_iter=2, n_iter=10, verbose=0, return_hard_coloring=False, update='mean', keep_clustering=False, early_stopping=0.1, lr=0.5, fit_function='bgm'):
    """Fractional Weisfeiler-Lehman-style refinement into ``k`` soft colours.

    Starting from the uniform assignment ``H = 1/k``, each round clusters the
    node features ``X`` (built from ``A @ H``) into ``k`` groups.  The new
    colour columns are then aligned to the previous ones with the Hungarian
    algorithm.  Iteration stops after ``n_iter`` rounds, or once (after
    ``min_iter`` rounds) the fractional EP cost no longer improves by more
    than ``early_stopping``.

    Args:
        A: square weight matrix; it is rescaled so its largest |entry| is 1.
        k: number of colours.
        min_iter: early stopping is only considered when the round index > min_iter.
        n_iter: maximum number of rounds.
        verbose: 1 prints costs; 2-4 additionally draw heatmaps.
        return_hard_coloring: return ``argmax`` labels instead of the soft matrix.
        update: ``'mean'`` blends ``X = (1-lr) X + lr A H``; ``'append'``
            concatenates ``A H`` as new feature columns; anything else sets
            ``X = A H``.
        keep_clustering: unused; kept for interface compatibility.
        early_stopping: minimum cost improvement required to continue.
        lr: blending rate for ``update='mean'``.
        fit_function: ``'kmeans'``, ``'average_linkage'``, ``'soft_kmeans'``,
            ``'bgm'`` (BayesianGaussianMixture) or anything else for a
            GaussianMixture.

    Returns:
        (n, k) soft colour matrix, or a length-n label array when
        ``return_hard_coloring``.
    """
    scale = np.max(np.abs(A))
    if scale > 0:
        # Skipped for an all-zero matrix, which previously became all-NaN (0/0).
        A = A / scale
    H = 1/k * np.ones((A.shape[1], k))

    def best_by_cost(candidates):
        # Keep the candidate assignment with the lowest fractional EP cost.
        return min([(tmp_H, frac_ep_cost(A, tmp_H)) for tmp_H in candidates], key=lambda y: y[1])[0]

    # Every fit function uses its argument ``x``; the mixture-model variants
    # previously read the enclosing ``X`` by accident.
    if fit_function == 'kmeans':
        kmeans = KMeans(n_clusters=k, n_init=1)
        fit = lambda x: best_by_cost([indicator_matrix_from_colors(kmeans.fit_predict(x)) for _ in range(10)])
    elif fit_function == 'average_linkage':
        fit = lambda x: indicator_matrix_from_colors(AgglomerativeClustering(n_clusters=k, linkage='average').fit_predict(x))
    elif fit_function == 'soft_kmeans':
        fit = lambda x: best_by_cost([soft_k_means(x, k) for _ in range(10)])
    elif fit_function == 'bgm':
        g = BayesianGaussianMixture(n_components=k, n_init=1)
        fit = lambda x: best_by_cost([g.fit(x).predict_proba(x) for _ in range(10)])
    else:
        print('default')
        g = GaussianMixture(n_components=k, n_init=1)
        fit = lambda x: best_by_cost([g.fit(x).predict_proba(x) for _ in range(10)])

    if update == 'mean':
        X = (1-lr) * H + lr * A @ H
    else:
        X = A @ H

    for i in range(n_iter):
        if verbose >= 4:
            sns.heatmap(X)
            plt.show()

        tmp = fit(X)
        if tmp.shape[1] < k:
            # Fewer than k clusters found: pad with empty (all-zero) colours.
            tmp = np.concatenate([tmp, np.zeros((tmp.shape[0], k - tmp.shape[1]))], axis=1)

        # Align the new colour columns with the previous ones so colour
        # identities stay stable between rounds.
        d = pairwise_distances(tmp.transpose(), H.transpose(), metric='l2')
        if verbose >= 3:
            print("tmp")
            sns.heatmap(tmp)
            plt.show()
            print("H")
            sns.heatmap(H)
            plt.show()
        row_ind, col_ind = linear_sum_assignment(d)
        P = np.zeros((k, k))
        P[row_ind, col_ind] = 1

        if verbose >= 1:
            print(frac_ep_cost(A, H), frac_ep_cost(A, tmp))
        if i > min_iter and frac_ep_cost(A, H) - frac_ep_cost(A, tmp) <= early_stopping:
            break

        H = tmp @ P
        if verbose >= 3:
            print("P")
            sns.heatmap(P)
            plt.show()
            print("H")
            sns.heatmap(H)
            plt.show()
            print("X")
            sns.heatmap(X)
            plt.show()

        # 'append' used to be overwritten by the trailing else-branch.
        if update == 'append':
            X = np.concatenate([X, A @ H], axis=1)
        elif update == 'mean':
            X = (1-lr) * X + lr * A @ H
        else:
            X = A @ H
        if verbose >= 2:
            print("H")
            sns.heatmap(H)
            plt.show()

    if return_hard_coloring:
        return np.argmax(H, axis=1)
    if H.shape[1] >= k:
        return H
    return np.concatenate([H, np.zeros((H.shape[0], k - H.shape[1]))], axis=1)

def frac_kmeans_cost_function(V, H, norm="l2"):
    """Residual of projecting ``V`` onto the (soft) cluster means given by ``H``.

    Columns of ``H`` are weighted by the inverse of their sums (empty columns
    get weight 0), so ``H D H^T V`` replaces each row of ``V`` with its cluster
    mean.  The residual is reported as:
    ``"l1"`` sum of absolute values, ``"l2"`` Frobenius norm, otherwise
    ``np.linalg.norm(residual, ord=norm)``.
    """
    column_mass = np.ones(H.shape[0]) @ H
    inverse_mass = np.diag([0 if m == 0 else 1 / m for m in column_mass])
    residual = V - H @ inverse_mass @ H.transpose() @ V
    if norm == "l1":
        return np.sum(np.abs(residual))
    if norm == "l2":
        return np.sqrt(np.sum(residual ** 2))
    return np.linalg.norm(residual, ord=norm)

def frac_ep_cost(A, H, norm='l2'):
    """Fractional equitable-partition cost: k-means residual of ``A @ H`` under ``H``."""
    propagated = A @ H
    return frac_kmeans_cost_function(propagated, H, norm=norm)

def kmeans_cost_function(V, colors, norm="l2"):
    """Residual of replacing each row of ``V`` by the mean of its colour class.

    Uses the normalised indicator matrix of ``colors``.  Reported as
    ``"l1"`` sum of absolute values, ``"l2"`` sum of squares (no square root),
    otherwise ``np.linalg.norm(residual, ord=norm)``.
    """
    H = indicator_matrix_from_colors(colors, normalized=True)
    residual = V - H @ H.transpose() @ V
    if norm == "l1":
        return np.sum(np.abs(residual))
    if norm == "l2":
        return np.sum(np.power(residual, 2))
    return np.linalg.norm(residual, ord=norm)

def ep_cost_function(A, colors, norm='fro'):
    """Equitable-partition cost of a hard colouring given as per-node labels."""
    indicator = indicator_matrix_from_colors(colors)
    return kmeans_cost_function(A @ indicator, colors, norm=norm)

def deep_ep_cost_function(A, H, alpha=1.0, depth=20, norm='l1'):
    """Average fractional EP cost of ``H`` over matrix powers ``A^1 .. A^depth``.

    Each power is rescaled to unit Frobenius norm before it is scored.  The
    cost at power ``it`` is weighted by ``alpha**it``; previously every term
    used ``alpha**depth``, which made ``alpha`` a mere global factor.  With the
    default ``alpha=1.0`` all powers count equally, exactly as before.

    Returns:
        Weighted sum of the per-power costs divided by ``depth``.
    """
    # A is never mutated below (X / norm and X @ A build new arrays).
    X = A
    total = 0
    for it in range(1, depth+1):
        scale = np.linalg.norm(X)
        if scale > 0:
            # Guard: a nilpotent A reaches the zero matrix; dividing would give NaN.
            X = X / scale
        total += alpha**it * frac_ep_cost(X, H, norm=norm)
        X = X @ A

    return total / depth

def frac_WL_multiter(A, k, min_iter=2, n_iter=10, verbose=0, return_hard_coloring=False, update='mean', keep_clustering=False, early_stopping=0.1, lr=0.5, fit_function='bgm', eval_func=frac_ep_cost, trials=10):
    """Run ``frac_WL`` ``trials`` times and return the run with the lowest ``eval_func``.

    All runs are computed first and then scored; ties keep the earliest run.
    ``return_hard_coloring`` is not forwarded, so soft matrices are compared.
    """
    runs = [
        frac_WL(A, k, min_iter=min_iter, n_iter=n_iter, verbose=verbose, update=update,
                keep_clustering=keep_clustering, early_stopping=early_stopping, lr=lr,
                fit_function=fit_function)
        for _ in range(trials)
    ]
    scored = [(H, eval_func(A, H)) for H in runs]
    scored.sort(key=lambda pair: pair[1])
    return scored[0][0]

def kmedians1d(x, k, verbose=0):
    """Split 1-D data into ``k`` contiguous runs of its sorted values (optimal DP).

    The values are sorted and cut into ``k`` consecutive runs by dynamic
    programming, minimising the summed run cost.  The cost of a run
    ``x[i..j]`` is ``sum_l |x_l - mean(x[i..j])|`` -- absolute deviation from
    the run *mean* (despite the name, the centre is not the median).

    Args:
        x: 1-D numpy array (it is indexed with ``x[sort_idx]``).
        k: number of groups.
        verbose: when >= 1, print the sorted data, DP tables, cuts and labels.

    Returns:
        Integer numpy array of labels ``0..k-1`` in the original order of
        ``x``; labels are non-decreasing in value.
    """
    sort_idx = np.argsort(x)
    # Inverse permutation: maps sorted positions back to the input order.
    reverse = np.argsort(sort_idx)
    x = x[sort_idx]
    if verbose >= 1:
        print(x)
    n = len(x)

    # CC[i, j] (j >= i): cost of putting the sorted run x[i..j] in one group.
    # Vectorised per start index i instead of O(n^3) Python-level loops.
    CC = np.zeros((n, n))
    for i in range(n):
        seg = x[i:]
        # run_means[r] = mean(x[i..i+r]).
        run_means = np.cumsum(seg) / np.arange(1, len(seg) + 1)
        # deviations[r, c] = |x[i+c] - run_means[r]|; tril keeps only c <= r,
        # i.e. the points that belong to the run ending at i+r.
        deviations = np.abs(seg[np.newaxis, :] - run_means[:, np.newaxis])
        CC[i, i:] = np.tril(deviations).sum(axis=1)

    # D[g, m]: best cost of splitting x[0..m] into g+1 groups.
    # T[g, m]: start index of the last group in that best split.
    D = np.zeros_like(CC)
    T = np.zeros_like(CC, dtype=int)
    D[:1] = CC[:1]
    diag = np.arange(n)
    T[diag, diag] = diag

    for i in range(1, k):
        for m in range(i + 1, n):
            # Last group is x[j..m] for j in i..m; earlier groups cover x[0..j-1].
            candidates = D[i - 1, i - 1:m] + CC[i:m + 1, m]
            idx = int(np.argmin(candidates)) + i
            D[i, m] = D[i - 1, idx - 1] + CC[idx, m]
            T[i, m] = idx

    if verbose >= 1:
        print(D, T)
    # Backtrack through T to recover the k-1 cut positions.
    col = []
    k = k - 1
    old_idx = n - 1
    cuts = [n]
    for i in range(k):
        idx = T[k - i][old_idx]
        cuts.append(idx)
        old_idx = idx - 1
    cuts.append(0)
    cuts.reverse()
    if verbose >= 1:
        print(cuts)
    # Label each run of the sorted data, then undo the sort.
    for i in range(len(cuts) - 1):
        col.extend([i] * (cuts[i + 1] - cuts[i]))
        if verbose >= 1:
            print(col)
    if verbose >= 1:
        print(sort_idx)
        print(reverse)
        print(np.array(col)[reverse])
    return np.array(col)[reverse]


def colors_from_indicator_matrix(H):
    """Return, per row of ``H``, the column index of its largest entry (first on ties)."""
    labels = np.argmax(H, axis=1)
    return labels

def assignment_acc(col, target):
    """Accuracy of a labelling under the best one-to-one matching of labels.

    Predicted labels are matched to target labels by the Hungarian algorithm,
    maximising the number of nodes whose matched labels agree.  The fraction
    of agreeing nodes is returned.  The two labellings may use different
    numbers of distinct labels; nodes whose label stays unmatched count as
    errors.

    The previous version subtracted target counts along the predicted-label
    axis, which was only correct because every permutation sums them to n in
    the square case.  It also asserted equal label counts and failed on list
    inputs.

    Args:
        col: predicted label per node (list or 1-D array).
        target: ground-truth label per node, same length as ``col``.

    Returns:
        Float in [0, 1].

    Raises:
        ValueError: if ``col`` and ``target`` differ in shape.
    """
    col = np.asarray(col)
    target = np.asarray(target)
    if col.shape != target.shape:
        raise ValueError("col and target must have the same shape")
    unique_col, col_idx = np.unique(col, return_inverse=True)
    unique_target, target_idx = np.unique(target, return_inverse=True)
    # overlap[a, b]: number of nodes with predicted label a and target label b.
    overlap = np.zeros((len(unique_col), len(unique_target)))
    np.add.at(overlap, (col_idx.ravel(), target_idx.ravel()), 1)
    row_ind, col_ind = linear_sum_assignment(overlap, maximize=True)
    n = len(col)
    # 1 - (#mismatched nodes) / n, the same final formula as before.
    return 1 - (n - np.sum(overlap[row_ind, col_ind])) / n

def evaluate_algorithm(A, k, algorithm, iterations=3, eval_function=ep_cost_function, size_H=0):
    """Run ``algorithm(A, k)`` ``iterations`` times and keep the best-scoring result.

    Each run reseeds numpy's global RNG, times the call with
    ``time.process_time``, and scores the returned (n, k) matrix with
    ``eval_function(A, matrix)``.  On failure a fallback is recorded: a
    ``ValueError`` puts every node in the first colour (and the error is
    printed), while a cvxpy ``SolverError`` records an all-zero matrix.

    NOTE(review): the default ``ep_cost_function`` passes that (n, k) matrix to
    ``indicator_matrix_from_colors``, which expects per-node labels --
    callers presumably pass a matrix-based cost such as ``frac_ep_cost``;
    TODO confirm.

    Returns:
        ``(best score, mean run time, best matrix)``, where the matrix is
        zero-padded to ``size_H`` columns when ``size_H > k``; or
        ``(0, 0, zeros((n, size_H)))`` when ``iterations == 0``.
    """
    n_nodes = A.shape[0]
    colors = np.zeros((iterations, n_nodes, k))
    scores = np.zeros((iterations))
    times = np.zeros((iterations))

    for run in range(iterations):
        np.random.seed()
        t0 = time.process_time()
        try:
            colors[run] = np.array(algorithm(A, k))
        except ValueError as err:
            colors[run] = np.concatenate([np.ones((n_nodes, 1)), np.zeros((n_nodes, k - 1))], axis=1)
            print(err)
        except cvxpy.error.SolverError:
            colors[run] = np.zeros((n_nodes, k))
        times[run] = time.process_time() - t0
        scores[run] = eval_function(A, colors[run])

    if iterations == 0:  # Debugging
        return 0, 0, np.zeros((n_nodes, size_H))

    best = np.argmin(scores)
    best_colors = np.array(colors[best])
    if size_H > k:
        best_colors = np.concatenate([best_colors, np.zeros((n_nodes, size_H - k))], axis=1)
    return scores[best], np.mean(times), best_colors



def role2vec(A, k, precomputed_embedding=None, return_embedding=False):
    """Colour nodes by k-means on a karateclub Role2Vec embedding of ``A``.

    The embedding is fitted on ``nx.from_numpy_array(A)`` unless
    ``precomputed_embedding`` is given.  With ``return_embedding`` the embedding
    itself is returned; otherwise the result is the one-hot (n, k) indicator
    matrix of ``KMeans(k, n_init=5)`` labels, padded to ``k`` columns.
    """
    if precomputed_embedding is None:
        model = Role2Vec()
        model.fit(nx.from_numpy_array(A))
        embedding = model.get_embedding()
    else:
        embedding = precomputed_embedding
    if return_embedding:
        return embedding
    labels = KMeans(k, n_init=5).fit_predict(embedding)
    return indicator_matrix_from_colors(labels, padding=k)

def node2vec(A, k, precomputed_embedding=None, return_embedding=False):
    """Colour nodes by k-means on a karateclub Node2Vec embedding of ``A``.

    The embedding is fitted on ``nx.from_numpy_array(A)`` unless
    ``precomputed_embedding`` is given.  With ``return_embedding`` the embedding
    itself is returned; otherwise the result is the one-hot (n, k) indicator
    matrix of ``KMeans(k, n_init=5)`` labels, padded to ``k`` columns.
    """
    if precomputed_embedding is None:
        model = Node2Vec()
        model.fit(nx.from_numpy_array(A))
        embedding = model.get_embedding()
    else:
        embedding = precomputed_embedding
    if return_embedding:
        return embedding
    labels = KMeans(k, n_init=5).fit_predict(embedding)
    return indicator_matrix_from_colors(labels, padding=k)

def appr_EP_by_dom_EV(A, k, its=20, return_X=False, precomputed_embedding=None):
    """Approximate an equitable partition by 1-D clustering of a power-iteration vector.

    Starting from the all-ones vector, ``X = A @ X`` followed by unit-norm
    scaling is repeated ``its`` times.  This presumably approximates the
    dominant eigenvector when one exists -- not checked here.  The entries are
    then split into ``k`` groups with ``kmedians1d``.  ``precomputed_embedding``
    (a 1-D array) replaces the power iteration.

    Returns:
        One-hot (n, k) indicator matrix; with ``return_X`` the pair
        ``(indicator matrix without padding, X)``.
    """
    if precomputed_embedding is None:
        X = np.ones((A.shape[1], 1))
        for _ in range(its):
            X = A @ X
            X = X / np.linalg.norm(X)
        X = X.reshape((A.shape[1],))
    else:
        X = precomputed_embedding

    labels = kmedians1d(X, k)
    if return_X:
        return indicator_matrix_from_colors(labels), X
    return indicator_matrix_from_colors(labels, padding=k)


def save_experiment(root: str, evals, times, colors, graphs=None):
    """Save experiment arrays as ``.npy`` files inside directory ``root``.

    Writes ``evals.npy``, ``times.npy``, ``colors.npy`` and, when provided,
    ``graphs.npy`` (``np.save`` appends the suffix).  ``root`` and any missing
    parents are created.  ``exist_ok=True`` avoids the race between the former
    existence check and ``os.makedirs``.
    """
    os.makedirs(root, exist_ok=True)
    np.save(os.path.join(root, 'evals'), evals)
    np.save(os.path.join(root, 'times'), times)
    np.save(os.path.join(root, 'colors'), colors)
    if graphs is not None:
        np.save(os.path.join(root, 'graphs'), graphs)




def sample_csEP(nodes_per_group, links_to_other_group, multi_graph_okay=False):
    """Sample a simple graph in which the given grouping is an exact equitable partition.

    Group ``g`` has ``nodes_per_group[g]`` nodes.  Every node of group ``g``
    ends up with exactly ``links_to_other_group[g][h]`` neighbours in group
    ``h``; the assertions below check this both before and after rewiring.
    Links are first drawn at random among partners that still need a link
    back.  When no fresh partner is left, repeated partners and self-loops are
    allowed.  ``make_simple_graph`` then rewires the result into a graph with
    no self-loops and no parallel edges (also asserted below).

    NOTE(review): ``multi_graph_okay`` is never read.
    NOTE(review): presumably the counts must be consistent
    (``nodes_per_group[g] * links[g][h] == nodes_per_group[h] * links[h][g]``);
    otherwise the fallback loop can run out of candidates (IndexError on
    ``chosen_links[0]``) or an assertion fails -- TODO confirm.

    Returns:
        A copy of the symmetric 0/1 integer adjacency matrix.
    """
    # group_offset[g]: index of the first node of group g (nodes are numbered
    # group by group).  links_needed[v, h]: links node v still needs into group h.
    group_offset = []
    num_nodes = 0
    links_needed = np.zeros((np.sum(nodes_per_group), len(nodes_per_group)), dtype=int)
    for i in range(len(nodes_per_group)):
        for node in range(nodes_per_group[i]):
            links_needed[node + num_nodes, :] = links_to_other_group[i]
        group_offset.append(num_nodes)
        num_nodes += nodes_per_group[i]

    adjacency_matrix = np.zeros((np.sum(nodes_per_group), np.sum(nodes_per_group)), dtype=int)

    for group in range(len(nodes_per_group)):
        for node in range(group_offset[group], group_offset[group] + nodes_per_group[group]):
            for other_group in range(len(links_to_other_group[group])):
                # First pass: distinct partners in other_group (excluding node
                # itself) that still need a link back into `group`, in random order.
                chosen_links = np.random.permutation([(node, i) for i in range(group_offset[other_group], group_offset[other_group] + nodes_per_group[other_group]) if links_needed[i, group] > 0 and node != i])
                for link in list(chosen_links[:links_needed[node, other_group]]):
                    adjacency_matrix[link[0], link[1]] = 1
                    adjacency_matrix[link[1], link[0]] = 1
                    links_needed[node, other_group] -= 1
                    links_needed[link[1], group] -= 1
                # Fallback: still short of links -- allow repeated partners and
                # self-loops (multigraph); make_simple_graph removes them later.
                while links_needed[node, other_group] > 0:
                    chosen_links = np.random.permutation([(node, i) for i in range(group_offset[other_group], group_offset[other_group] + nodes_per_group[other_group]) if links_needed[i, group] > 0])
                    link = chosen_links[0]
                    adjacency_matrix[link[0], link[1]] += 1
                    # A self-loop is stored once on the diagonal and consumes
                    # only one unit of the node's own demand.
                    adjacency_matrix[link[1], link[0]] += 1 if link[0] != link[1] else 0
                    links_needed[node, other_group] -= 1
                    links_needed[link[1], group] -= 1 if node != link[1] else 0


    # Sentinel end offset so group g spans group_offset[g] .. group_offset[g+1]-1.
    group_offset.append(np.sum(nodes_per_group))
    #print(links_needed)
    #print([[np.sum([adjacency_matrix[i,j] for j in range(group_offset[group], group_offset[group+1])]) for group in range(len(nodes_per_group))] for i in range(len(adjacency_matrix))])
    # Expected per-node link counts and the group label of every node.
    links_to_other_group_per_node = []
    colors = []
    for i,g in enumerate(nodes_per_group):
        links_to_other_group_per_node.extend([links_to_other_group[i]] * g)
        colors.extend([i]*g)
    #print(links_to_other_group_per_node)
    assert np.all(adjacency_matrix @ indicator_matrix_from_colors(colors) == links_to_other_group_per_node)
    adjacency_matrix = make_simple_graph(adjacency_matrix, nodes_per_group, links_to_other_group, group_offset)
    #print([[np.sum([adjacency_matrix[i,j] for j in range(group_offset[group], group_offset[group+1])]) for group in range(len(nodes_per_group))] for i in range(len(adjacency_matrix))])
    #print(links_to_other_group_per_node)
    assert np.all(adjacency_matrix @ indicator_matrix_from_colors(colors) == links_to_other_group_per_node)
    # No parallel edges and no self-loops remain.
    assert len([(i,j) for i in range(len(adjacency_matrix)) for j in range(len(adjacency_matrix)) if adjacency_matrix[i,j] > 1 or (i==j and adjacency_matrix[i,j] > 0)]) == 0

    return adjacency_matrix.copy()

def make_simple_graph(adjacency_matrix, nodes_per_group, links_to_other_group, group_offset):
    """Rewire a multigraph into a simple graph by edge swaps within groups.

    Self-loops and parallel edges are removed by swapping endpoints with other
    edges whose endpoints lie in the same groups.  Every node therefore keeps
    its number of links into each group.  The matrix is modified in place and
    returned.

    Args:
        adjacency_matrix: symmetric integer matrix; diagonal entries count
            self-loops, entries > 1 count parallel edges.
        nodes_per_group: number of nodes per group.
        links_to_other_group: per-group link counts (not read here; only
            passed along on recursion).
        group_offset: start index of every group plus a final end sentinel,
            so group g spans ``group_offset[g] .. group_offset[g+1]-1``.

    Returns:
        The rewired adjacency matrix.

    Note: the recursive restart previously omitted ``group_offset`` and always
    raised TypeError when reached.  If no swap partner exists,
    ``np.random.permutation([])[0]`` raises IndexError.
    """
    diagonal_entries_to_fix = [(i, i) for i in range(len(adjacency_matrix)) if adjacency_matrix[i, i] > 0]
    for edge in diagonal_entries_to_fix:
        for i in range(adjacency_matrix[edge[0], edge[1]]):
            if adjacency_matrix[edge[0], edge[1]] >= 2:
                # Two self-loop units at v: take an edge (x, y) of v's group and
                # replace it plus the self-loops by v-y and x-v.
                group_x = [g for g in range(len(nodes_per_group)) if edge[0] >= group_offset[g] and edge[0] < group_offset[g+1]][0]
                group_y = [g for g in range(len(nodes_per_group)) if edge[1] >= group_offset[g] and edge[1] < group_offset[g+1]][0]

                candidates_x = [n for n in range(group_offset[group_x], group_offset[group_x + 1]) if n != edge[1] and adjacency_matrix[edge[1], n] < 1]
                candidates_y = [n for n in range(group_offset[group_y], group_offset[group_y + 1]) if n != edge[0] and adjacency_matrix[n, edge[0]] < 1]

                choice = np.random.permutation([(x, y) for x in candidates_x for y in candidates_y if adjacency_matrix[x, y] >= 1 and x != y])[0]
                adjacency_matrix[edge[0], choice[1]] = 1
                adjacency_matrix[choice[1], edge[0]] = 1
                adjacency_matrix[choice[0], edge[1]] = 1
                adjacency_matrix[edge[1], choice[0]] = 1

                # edge is (v, v): both lines hit the diagonal, removing two units.
                adjacency_matrix[edge[0], edge[1]] -= 1
                adjacency_matrix[edge[1], edge[0]] -= 1
                adjacency_matrix[choice[0], choice[1]] -= 1
                adjacency_matrix[choice[1], choice[0]] -= 1
            elif adjacency_matrix[edge[0], edge[1]] == 1:
                # Single self-loop at v: pair it with another self-loop at x in
                # the same group and replace both by the edge v-x.
                group = [g for g in range(len(nodes_per_group)) if edge[0] >= group_offset[g] and edge[0] < group_offset[g+1]][0]
                candidates = [n for n in range(group_offset[group], group_offset[group + 1]) if n != edge[0]]
                choice = np.random.permutation([(x, x) for x in candidates if adjacency_matrix[x, x] >= 1])[0]
                adjacency_matrix[edge[0], choice[1]] += 1
                adjacency_matrix[choice[1], edge[0]] += 1
                adjacency_matrix[edge[1], edge[0]] -= 1
                adjacency_matrix[choice[0], choice[1]] -= 1
            else:
                continue

    edges_to_fix = [(i, j) for i in range(len(adjacency_matrix)) for j in range(i+1, len(adjacency_matrix)) if adjacency_matrix[i, j] > 1 and i != j]

    for edge in edges_to_fix:
        for i in range(adjacency_matrix[edge[0], edge[1]] - 1):
            # Parallel edge u-v: take an edge x-y (x in u's group, y in v's
            # group, with u-y and x-v absent) and swap it to u-y and x-v.
            group_x = [g for g in range(len(nodes_per_group)) if edge[0] >= group_offset[g] and edge[0] < group_offset[g+1]][0]
            group_y = [g for g in range(len(nodes_per_group)) if edge[1] >= group_offset[g] and edge[1] < group_offset[g+1]][0]

            candidates_x = [n for n in range(group_offset[group_x], group_offset[group_x + 1]) if n != edge[1] and adjacency_matrix[edge[1], n] < 1]
            candidates_y = [n for n in range(group_offset[group_y], group_offset[group_y + 1]) if n != edge[0] and adjacency_matrix[n, edge[0]] < 1]

            choice = np.random.permutation([(x, y) for x in candidates_x for y in candidates_y if adjacency_matrix[x, y] >= 1 and x != y])[0]
            adjacency_matrix[edge[0], choice[1]] = 1
            adjacency_matrix[choice[1], edge[0]] = 1
            adjacency_matrix[choice[0], edge[1]] = 1
            adjacency_matrix[edge[1], choice[0]] = 1

            adjacency_matrix[edge[0], edge[1]] -= 1
            adjacency_matrix[edge[1], edge[0]] -= 1
            adjacency_matrix[choice[0], choice[1]] -= 1
            adjacency_matrix[choice[1], choice[0]] -= 1
            if adjacency_matrix[choice[0], choice[1]] > 0:
                # The donor pair still carries an edge: restart on the current
                # matrix so the list of pairs to fix is recomputed.
                return make_simple_graph(adjacency_matrix, nodes_per_group, links_to_other_group, group_offset)

    return adjacency_matrix


def sample_uniform_csEP(size_classes, num_classes, max_edges=0):
    """Sample a graph with ``num_classes`` equal-size classes and random link counts.

    A symmetric ``num_classes x num_classes`` matrix ``A_pi`` is filled with
    integers drawn uniformly from ``[0, min(size_classes - 1, max_edges))``.
    ``max_edges == 0`` means ``size_classes``.  For each ordered pair, one draw
    is written to both ``(i, j)`` and ``(j, i)``, so later draws overwrite
    earlier ones.  The graph is then sampled with ``sample_csEP``.
    """
    cap = size_classes if max_edges == 0 else max_edges
    high = min(size_classes - 1, cap)
    A_pi = np.zeros(shape=(num_classes, num_classes), dtype=int)
    for row, col in itertools.product(range(num_classes), repeat=2):
        value = np.random.randint(0, high)
        A_pi[row, col] = value
        A_pi[col, row] = value

    return sample_csEP([size_classes] * num_classes, A_pi)