import numpy as np
import networkx as nx
from networkx.algorithms import bipartite
import math
from scipy.special import comb

def get_bipartition(G, perm=None):
    '''Split the nodes of G into its left and right sides.

    Nodes whose 'bipartite' attribute is 0 form the left side U, those with 1
    form the right side V. Each side is sorted by the integer suffix of the
    node name ('u12' -> 12).

    Args:
        G: graph whose nodes carry a 'bipartite' attribute.
        perm: optional sequence of indices; when given, V is rebuilt so that
            position j holds the node that was at sorted position perm[j].

    Returns:
        (U, V) as two lists of node names.
    '''
    def _suffix(name):
        return int(name[1:])

    left, right = [], []
    for node, attrs in G.nodes(data=True):
        side = attrs['bipartite']
        if side == 0:
            left.append(node)
        elif side == 1:
            right.append(node)
    left.sort(key=_suffix)
    right.sort(key=_suffix)
    if perm is None:
        return left, right
    return left, [right[idx] for idx in perm]

def get_leftnames(G):
    '''Return the 'name' attribute of every left-side (bipartite == 0) node.

    Names are returned in G's own node iteration order (not sorted by the
    node-name suffix as get_bipartition does).
    '''
    names = []
    for _node, attrs in G.nodes(data=True):
        if attrs['bipartite'] != 0:
            continue
        names.append(attrs['name'])
    return names

def empty_bipartite_graph(n_u, n_v):
    '''Return an edgeless nx.Graph with left nodes 'u0'..'u{n_u-1}' and right nodes 'v0'..'v{n_v-1}'.

    Left nodes get the attribute bipartite=0 and right nodes bipartite=1.
    '''
    G = nx.Graph()
    for prefix, count, side in (('u', n_u, 0), ('v', n_v, 1)):
        G.add_nodes_from([prefix + str(idx) for idx in range(count)], bipartite=side)
    return G

def random_graph_expected(n_u, n_v, degrees_u):
    '''Build a random bipartite graph with expected left degrees degrees_u.

    Left node 'u<i>' is joined to each right node 'v<j>' independently with
    probability degrees_u[i] / n_v, so its expected degree is degrees_u[i]
    whenever that ratio is at most 1 (otherwise it connects to every right node).
    Randomness comes from the global NumPy RNG (one np.random.rand(n_v) draw
    per left node, in index order).
    '''
    G = empty_bipartite_graph(n_u, n_v)
    for left_idx in range(n_u):
        draws = np.random.rand(n_v)
        hits = np.flatnonzero(draws < degrees_u[left_idx] / n_v)
        left = 'u' + str(left_idx)
        G.add_edges_from([(left, 'v' + str(right_idx)) for right_idx in hits])
    return G

def ranking(G, verbose=False, perm=None):
    '''Run the RANKING online bipartite matching algorithm on G.

    Left nodes receive a uniformly random rank (np.random.shuffle of 0..n_u-1).
    Right nodes arrive in the order given by get_bipartition(G, perm=perm), and
    each is matched to its lowest-ranked neighbour that is still unmatched.

    Args:
        G: bipartite graph with node names 'u<int>' / 'v<int>' and a
            'bipartite' node attribute (0 = left, 1 = right).
        verbose: print a message for every right node left unmatched.
        perm: optional arrival order of the right nodes (see get_bipartition).

    Returns:
        Integer array of length |V|. Entry i is the numeric index of the left
        node matched to the i-th arriving right node, or -1 if it stayed unmatched.
    '''
    U, V = get_bipartition(G, perm=perm)
    n_u = len(U)
    # np.int was removed in NumPy 1.24; it was an alias of the builtin int.
    matching_v = np.full(len(V), -1, dtype=int)
    rank_u = np.arange(n_u)
    np.random.shuffle(rank_u)
    # Live ranks are a permutation of 0..n_u-1, so n_u exceeds all of them and
    # marks a left node as taken. np.max(rank_u) + 1 gives the same value but
    # raises when U is empty.
    max_val = n_u
    for i, v in enumerate(V):
        neighbors_v = [int(x[1:]) for x in nx.neighbors(G, v)]
        if not neighbors_v:
            if verbose:
                print("{} failed, no neighbors".format(v))
            continue
        neighbor_ranks = rank_u[neighbors_v]
        best = int(np.argmin(neighbor_ranks))
        if neighbor_ranks[best] >= max_val:
            if verbose:
                print("{} failed, all neighbors picked".format(v))
            continue
        chosen = neighbors_v[best]
        matching_v[i] = chosen
        rank_u[chosen] = max_val

    return matching_v

def random(G, verbose=False, perm=None):
    '''Online matching: pair each arriving right node with a uniformly random free neighbour.

    Right nodes arrive in the order given by get_bipartition(G, perm=perm).
    The choice uses np.random.choice over the still-unmatched left neighbours.

    Args:
        G: bipartite graph with node names 'u<int>' / 'v<int>'.
        verbose: print a message for every right node left unmatched (the
            message says "no neighbors" also when all neighbours are taken).
        perm: optional arrival order of the right nodes.

    Returns:
        Integer array of length |V|. Entry i is the matched left index for the
        i-th arriving right node, or -1.
    '''
    _, V = get_bipartition(G, perm=perm)
    # np.int was removed in NumPy 1.24; it was an alias of the builtin int.
    matching_v = np.full(len(V), -1, dtype=int)
    matched = set()
    for i, v in enumerate(V):
        neighbor_ids = (int(x[1:]) for x in nx.neighbors(G, v))
        neighbors_v = [u for u in neighbor_ids if u not in matched]
        if not neighbors_v:
            if verbose:
                print("{} failed, no neighbors".format(v))
            continue
        nbr = int(np.random.choice(neighbors_v))
        matching_v[i] = nbr
        matched.add(nbr)

    return matching_v

def static_degrees(G, degrees, verbose=False, perm=None):
    '''RANKING variant that prefers left nodes with a small (predicted) degree.

    Left node u gets the rank shuffle_position[u] + n_u * degrees[u]. For
    integer degrees this means any smaller degree always outranks a larger one,
    and the random shuffle breaks ties. Each arriving right node is matched to
    its lowest-ranked unmatched neighbour.

    Args:
        G: bipartite graph with node names 'u<int>' / 'v<int>'.
        degrees: array-like of per-left-node degree values, indexed by the
            numeric suffix of the left node name (a scalar broadcasts).
        verbose: print a message for every right node left unmatched.
        perm: optional arrival order of the right nodes.

    Returns:
        Integer array of length |V| with the matched left index per arriving
        right node, or -1.
    '''
    U, V = get_bipartition(G, perm=perm)
    n_u = len(U)
    # With a plain list, n_u * degrees would repeat the list instead of scaling it.
    degrees = np.asarray(degrees)
    # np.int was removed in NumPy 1.24; it was an alias of the builtin int.
    matching_v = np.full(len(V), -1, dtype=int)
    rank_u = np.arange(n_u)
    np.random.shuffle(rank_u)
    rank_u = rank_u + (n_u * degrees)
    # Any value above every live rank marks a left node as taken. The guard
    # avoids np.max on an empty array when U is empty.
    max_val = np.max(rank_u) + 1 if rank_u.size else 0
    for i, v in enumerate(V):
        neighbors_v = [int(x[1:]) for x in nx.neighbors(G, v)]
        if not neighbors_v:
            if verbose:
                print("{} failed, no neighbors".format(v))
            continue
        neighbor_ranks = rank_u[neighbors_v]
        best = int(np.argmin(neighbor_ranks))
        if neighbor_ranks[best] >= max_val:
            if verbose:
                print("{} failed, all neighbors picked".format(v))
            continue
        chosen = neighbors_v[best]
        matching_v[i] = chosen
        rank_u[chosen] = max_val

    return matching_v

def static_with_noise(G, degrees, noise_type='gaussian', noise_param=100, perm=None):
    '''Run static_degrees with degrees perturbed by random noise.

    Args:
        G: bipartite graph passed through to static_degrees.
        degrees: array-like of true left degrees (one entry per left node).
        noise_type: kind of noise. Only 'gaussian' is implemented; anything
            else raises instead of silently falling back to Gaussian noise.
        noise_param: standard deviation of the zero-mean Gaussian noise.
        perm: optional arrival order of the right nodes.

    Returns:
        The matching array produced by static_degrees.

    Raises:
        ValueError: if noise_type is not 'gaussian'.
    '''
    if noise_type != 'gaussian':
        raise ValueError("Unsupported noise_type {!r}; only 'gaussian' is implemented".format(noise_type))
    degrees = np.asarray(degrees)
    noisy_degrees = degrees + np.random.normal(0, noise_param, size=degrees.size)
    return static_degrees(G, noisy_degrees, perm=perm)

def static_with_oracle_dict(G, oracle_dict, degrees_u, default=1, perm=None, verbose=False):
    '''Run static_degrees using degree predictions looked up by left-node name.

    Each left node's 'name' attribute is looked up in oracle_dict. Nodes whose
    name is missing get `default`.

    Args:
        G: bipartite graph whose left nodes carry 'name' and 'bipartite' attributes.
        oracle_dict: mapping from node name to predicted degree.
        degrees_u: true left degrees. Used only for the verbose error report.
        default: prediction used for names absent from oracle_dict.
        perm: optional arrival order of the right nodes.
        verbose: print the number of unseen names and the l1 / linf prediction error.

    Returns:
        The matching array produced by static_degrees.
    '''
    left = [(node, attrs['name']) for node, attrs in G.nodes(data=True) if attrs['bipartite'] == 0]
    n_u = len(left)
    pred_degrees = np.ones(n_u, dtype=int) * default
    ctr = 0
    for node, name in left:
        # static_degrees indexes left nodes by the numeric suffix of the node
        # name ('u7' -> 7), not by graph iteration order, so store each
        # prediction at that index to keep the two aligned.
        idx = int(node[1:])
        if name in oracle_dict:
            pred_degrees[idx] = oracle_dict[name]
        else:
            ctr += 1
    if verbose:
        print('New left nodes: {}'.format(ctr))
        print('l1 dist: {}'.format(np.sum(np.abs(pred_degrees - degrees_u))))
        print('linf dist: {}'.format(np.max(np.abs(pred_degrees - degrees_u))))
    return static_degrees(G, pred_degrees, perm=perm)

def max_matching(G, perm=None):
    '''Compute an offline maximum matching of G, reported from the right side.

    Each connected component is matched separately with
    networkx.algorithms.bipartite.maximum_matching.

    Args:
        G: bipartite graph with node names 'u<int>' / 'v<int>'.
        perm: optional ordering of the right nodes (see get_bipartition). It
            only affects the index positions of the returned array.

    Returns:
        Integer array of length |V|. Entry i is the numeric index of the left
        node matched to V[i], or -1.
    '''
    _, V = get_bipartition(G, perm=perm)
    # Position of each right node in V. A dict gives O(1) lookups, whereas
    # `v in V` / V.index(v) cost O(|V|) per matched node. setdefault keeps the
    # first occurrence, like V.index.
    v_position = {}
    for i, v in enumerate(V):
        v_position.setdefault(v, i)
    matching_v = np.full(len(V), -1, dtype=int)

    for component in nx.connected_components(G):
        if len(component) < 2:
            continue  # an isolated node cannot be matched
        partial_matching = bipartite.maximum_matching(G.subgraph(component))
        # The returned dict holds both directions (u -> v and v -> u); keep the
        # right-side entries only.
        for node, partner in partial_matching.items():
            idx = v_position.get(node)
            if idx is not None:
                matching_v[idx] = int(partner[1:])

    return matching_v

def halls_thm_greedy(G, multiround=False, verbose=False):
    '''Greedy Hall's-theorem bound on the maximum matching size of G.

    Builds a left-side set S greedily and returns |U| - (|S| - |N(S)|). By the
    deficiency form of Hall's theorem, this is an upper bound on the maximum
    matching for any S subset of U.

    S starts as every left node of degree <= 1. Then every remaining left node
    whose neighbourhood already lies inside N(S) is absorbed.

    Args:
        G: bipartite graph with node names 'u<int>' / 'v<int>'.
        multiround: repeat the absorption pass until nothing changes.
            NOTE(review): absorbed nodes never enlarge N(S), so a second pass
            cannot add anything; the flag currently only costs an extra pass.
        verbose: print a degree histogram of S (first 10 bins), |S|, |N(S)|
            and the bound.

    Returns:
        The bound |U| - (|S| - |N(S)|) as an int.
    '''
    U, _ = get_bipartition(G)
    degrees_u = np.array([G.degree[u] for u in U])  # true left degrees
    S = {u for u, deg in zip(U, degrees_u) if deg <= 1}
    remainingnodes = set(U) - S
    N_S = set()
    for s in S:
        N_S.update(nx.neighbors(G, s))

    changed = True
    while changed:
        changed = False
        newnodes = set()
        for u in remainingnodes:
            nbrs = set(nx.neighbors(G, u))
            if nbrs <= N_S:  # u brings no new neighbours, so the deficiency grows by 1
                S.add(u)
                N_S.update(nbrs)
                newnodes.add(u)
                changed = True
        remainingnodes -= newnodes
        if not multiround:
            changed = False

    bound = len(U) - (len(S) - len(N_S))
    if verbose:
        # Histogram indexed by degree: it needs max_degree + 1 bins, otherwise
        # a node of maximum degree in S would index past the end.
        max_degree = int(np.max(degrees_u)) if degrees_u.size else 0
        S_degrees = np.zeros(max_degree + 1)
        for u in S:
            S_degrees[int(G.degree[u])] += 1
        print('S degrees: {}'.format(S_degrees[:10]))
        print('|S|: {}'.format(len(S)))
        print('|N(S)|: {}'.format(len(N_S)))
        print('Actual Hall\'s Theorem Greedy: {}'.format(bound))
        print()

    return bound

def offline_bound_expected(n, m, degrees_u, dmax=10, verbose=False):
    '''Expected value of the greedy Hall bound (halls_thm_greedy) on a random graph.

    The graph model is the one of random_graph_expected: left node u is
    adjacent to each of the m right nodes independently with probability
    p_u = degrees_u[u] / m.

    The terms are:
      * S_0: expected number of isolated left nodes.
      * S_d (1 <= d <= dmax): expected number of degree-d left nodes that the
        greedy puts in S. For d = 1 that is every degree-1 node. For d >= 2
        every neighbour must have a degree-1 neighbour other than u, which is
        computed by inclusion-exclusion.
      * N(S): expected number of right nodes with a degree-1 left neighbour.

    Left degrees above dmax are ignored.

    Args:
        n: number of left nodes (presumably len(degrees_u) — confirm with callers).
        m: number of right nodes.
        degrees_u: array-like of expected left degrees.
        dmax: largest left degree considered; must be >= 1.
        verbose: print every S_d, N(S) and the bound.

    Returns:
        n - sum_d S_d + N(S).

    Raises:
        ValueError: if dmax < 1.
    '''
    if dmax < 1:
        raise ValueError('dmax must be at least 1, got {}'.format(dmax))
    degrees_u = np.asarray(degrees_u, dtype=float)  # lists would fail on '/'
    p = degrees_u / m  # per-left-node edge probability
    q = 1 - p
    prod = []  # prod[d-1]: P(none of d fixed right nodes has a degree-1 left neighbour)
    a = []     # a[d-1][u]: the same product with left node u excluded
    S = [np.sum(np.power(q, m))]  # S_0
    for d in range(1, dmax + 1):
        # P(u has degree 1 and its single neighbour is one of the d fixed right
        # nodes). Computed once and reused for both prod and a.
        hit = d * p * np.power(q, m - 1)
        prod.append(np.prod(1 - hit))
        a.append(prod[-1] / (1 - hit))
        incexc = 1  # inclusion-exclusion term
        if d > 1:
            for j in range(1, d + 1):
                sign = -1 if j % 2 == 1 else 1
                incexc = incexc + sign * comb(d, j) * a[j - 1]
        # comb(m, d) * p^d * q^(m-d) is P(deg(u) == d).
        S.append(comb(m, d) * np.sum(np.power(p, d) * np.power(q, m - d) * incexc))
    N_S = m * (1 - prod[0])
    S_total = np.sum(S)
    bound = n - S_total + N_S
    if verbose:
        for i, s in enumerate(S):
            print('S_{}: {:.2f}'.format(i, s))
        print('N(S): {:.2f}'.format(N_S))
        print('Expected Degree Hall\'s Theorem Greedy: {:.2f}'.format(bound))
    return bound

def offline_bound_expected_asymptotic(degrees, fractions, dmax=10, verbose=False):
    '''Normalised greedy-Hall bound for a mixture of left-degree classes.

    A fraction fractions[j] of left nodes has expected degree degrees[j]. The
    terms use Poisson-style factors exp(-degrees[j]) and degrees[j]**d / d!,
    which suggests this is the large-graph limit of offline_bound_expected.
    NOTE(review): confirm what fractions are normalised against.

    Args:
        degrees: expected degree of each class.
        fractions: weight of each class.
        dmax: largest degree term included; must be >= 1 (prod[0] is read).
        verbose: print every S_d, N(S) and the result.

    Returns:
        1 - sum_d S_d + N(S), with N(S) = 1 - prod[0].
    '''
    classes = range(len(degrees))
    # prod[d-1]: product over classes of exp(-d * degree * fraction * exp(-degree))
    prod = []
    # S_0: weight of left nodes expected to have no neighbours
    S = [sum(fractions[j] * np.exp(-degrees[j]) for j in classes)]
    for d in range(1, dmax + 1):
        running = 1
        for j in classes:
            running *= np.exp(-d * degrees[j] * fractions[j] * np.exp(-degrees[j]))
        prod.append(running)

        incexc = 1  # inclusion-exclusion term; only applied for d >= 2
        if d > 1:
            for k in range(1, d + 1):
                incexc += (-1) ** k * comb(d, k) * prod[k - 1]

        weight = 1 / math.factorial(d)
        S.append(sum(weight * fractions[j] * np.power(degrees[j], d) * np.exp(-degrees[j]) * incexc
                     for j in classes))

    N_S = 1 - prod[0]
    S_total = np.sum(S)
    bound = 1 - S_total + N_S
    if verbose:
        for i, s in enumerate(S):
            print('S_{}: {:.4f}'.format(i, s))
        print('N(S): {:.4f}'.format(N_S))
        print('Expected Degree Hall\'s Theorem Greedy: {:.4f}'.format(bound))
    return bound

def simulate_diffeqs(n, m, degrees_u, verbose=False):
    '''Evaluate a closed-form ODE solution, one value per positive left-degree class.

    Distinct positive degrees d_0 < d_1 < ... are processed in increasing
    order. For class i the rate is k_i = -log(1 - d_i / m), and the horizon is
    T = m arrivals. The constants C_i and alpha_i are chained from class i-1.
    Once a value overflows to inf, that class and all later ones get 0.

    Args:
        n: unused (kept for interface compatibility).
        m: number of right (online) nodes.
        degrees_u: array-like of left degrees. Non-positive degrees are ignored.
        verbose: print "degree: y / count" for the first 100 classes.

    Returns:
        Float array y with one entry per distinct positive degree (ascending).
        NOTE(review): y[i] is printed next to the class size counts[i];
        presumably it is the predicted number of unmatched nodes of that
        degree — confirm.
    '''
    degrees_u = np.asarray(degrees_u)
    degrees = np.unique(degrees_u[degrees_u > 0])  # np.unique returns sorted values
    counts = [np.count_nonzero(degrees_u == d) for d in degrees]
    D = degrees.size
    T = m
    C = np.zeros(D)
    alpha = np.zeros(D)
    alpha0 = np.zeros(D)
    z = np.zeros(D)
    y = np.zeros(D)
    if D == 0:
        return y  # no left node has positive degree

    # Infinite (with a divide warning) for a class whose degree equals m.
    k = -np.log(1 - degrees / m)

    C[0] = np.exp(k[0] * counts[0]) - 1

    if D > 1:  # the chained constants below need at least two classes
        alpha[1] = C[0] + np.exp(k[0] * T)
        alpha0[1] = C[0] + 1
        C[1] = np.power(alpha0[1], k[1]/k[0]) * (np.exp(k[1] * counts[1]) - 1)
    for i in range(2, D):  # alpha_i(t) and C_i
        if degrees[i] == m:
            continue
        frac = k[i-1]/k[i-2]
        alpha[i] = np.power(alpha[i-1], frac) + C[i - 1]
        alpha0[i] = np.power(alpha0[i-1], frac) + C[i - 1]
        C[i] = np.power(alpha0[i], k[i]/k[i-1]) * (np.exp(k[i] * counts[i]) - 1)
        if alpha[i] == np.inf or C[i] == np.inf:  # overflow: all later z[i] are zero
            break

    for i in range(D):
        if i == 0:
            z[0] = - np.log(C[0] * np.exp(- k[0] * T) + 1)
            y[0] = z[0] / (-k[0])
        elif degrees[i] == m:
            z[i] = 0  # assuming all such offline nodes are matched
            y[i] = 0
        else:
            if alpha[i] == np.inf or C[i] == np.inf:  # overflow: all later z[i] are zero
                z[i:] = 0
                y[i:] = 0
                break
            z[i] = - np.log(C[i] * np.power(alpha[i], -k[i]/k[i-1]) + 1)
            y[i] = z[i] / (-k[i])
        if verbose and i < 100:
            print('{}: {:.2f} / {}'.format(degrees[i], y[i], counts[i]))

    return y

def simulate_diffeqs_asymptotic(degrees, fractions, verbose=False):
    '''Large-graph counterpart of simulate_diffeqs with degree classes given directly.

    The rate -log(1 - d/m) and horizon T = m of simulate_diffeqs become
    degrees[i] here, and k_i * counts_i becomes degrees[i] * fractions[i].

    Args:
        degrees: distinct positive class degrees, assumed ascending (they are
            divided by the previous class degree).
        fractions: weight of each class. NOTE(review): presumably relative to
            the number of online nodes — confirm.
        verbose: print "degree: y / fraction" for the first 20 classes.

    Returns:
        Float array y with one entry per class. Classes from the first
        overflow onwards get 0.
    '''
    D = len(degrees)
    C = np.zeros(D)
    alpha = np.zeros(D)
    alpha0 = np.zeros(D)
    z = np.zeros(D)
    y = np.zeros(D)
    if D == 0:
        return y

    # Same form as C[i] for i >= 1 below. The previous exp(fractions[0]) - 1
    # dropped the degrees[0] factor and was only correct when degrees[0] == 1.
    C[0] = np.exp(degrees[0] * fractions[0]) - 1

    if D > 1:  # the chained constants below need at least two classes
        alpha[1] = C[0] + np.exp(degrees[0])
        alpha0[1] = C[0] + 1
        C[1] = np.power(alpha0[1], degrees[1]/degrees[0]) * (np.exp(degrees[1] * fractions[1]) - 1)
    for i in range(2, D):  # alpha_i(t) and C_i
        frac = degrees[i-1]/degrees[i-2]
        alpha[i] = np.power(alpha[i-1], frac) + C[i - 1]
        alpha0[i] = np.power(alpha0[i-1], frac) + C[i - 1]
        C[i] = np.power(alpha0[i], degrees[i]/degrees[i-1]) * (np.exp(degrees[i] * fractions[i]) - 1)
        if alpha[i] == np.inf or C[i] == np.inf:  # overflow: all later z[i] are zero
            break

    for i in range(D):
        if i == 0:
            z[0] = - np.log(C[0] * np.exp(-degrees[0]) + 1)
            y[0] = z[0] / (-degrees[0])
        else:
            if alpha[i] == np.inf or C[i] == np.inf:  # overflow: all later z[i] are zero
                z[i:] = 0
                y[i:] = 0
                break
            z[i] = - np.log(C[i] * np.power(alpha[i], -degrees[i]/degrees[i-1]) + 1)
            y[i] = z[i] / (-degrees[i])
        if verbose and i < 20:
            print('{}: {:.6f} / {:.6f}'.format(degrees[i], y[i], fractions[i]))

    return y

def build_graph(n_u, n_v, dist_param, verbose=True, dist_type='zipf', graph_type='fixed'):
    '''Build (or import) a bipartite graph together with its left degree vector.

    Supported combinations:
      * dist_type='zipf-linguist', graph_type='expected': left node i (1-based)
        gets expected degree (n_v / 2) / i**dist_param, and the graph is drawn
        with random_graph_expected.
      * dist_type='import': dist_param is an existing graph. Its actual left
        degrees are returned and graph_type is ignored.

    NOTE(review): the defaults dist_type='zipf' and graph_type='fixed' are not
    implemented and raise ValueError.

    Args:
        n_u, n_v: number of left / right nodes (unused for 'import').
        dist_param: Zipf exponent, or the graph itself for 'import'.
        verbose: print the sorted degree vector.
        dist_type: degree distribution (see above).
        graph_type: random graph model (see above).

    Returns:
        (G, degrees_u)

    Raises:
        ValueError: on an unsupported dist_type / graph_type, or if the
            resulting graph is not bipartite. Previously invalid options
            called exit(0), ending the process with a success status.
    '''
    if dist_type == 'zipf-linguist':
        max_deg = n_v / 2
        degrees_u = max_deg / np.arange(1, n_u+1)**dist_param
    elif dist_type == 'import':
        G = dist_param
        U, _ = get_bipartition(G)
        degrees_u = np.array([G.degree[u] for u in U])
    else:
        raise ValueError('Invalid distribution type: {}'.format(dist_type))
    if verbose:
        print(np.sort(degrees_u))
    if dist_type != 'import':
        if graph_type == 'expected':
            G = random_graph_expected(n_u, n_v, degrees_u)
        else:
            raise ValueError('Invalid graph type: {}'.format(graph_type))
    # A real check rather than assert, which is stripped under python -O.
    if not nx.is_bipartite(G):
        raise ValueError('Graph is not bipartite')
    return G, degrees_u

def single_run(G, degrees_u, perm=None, oracle_dict=None):
    '''Run every matching algorithm on G with the same right-node arrival order.

    Returns a dict mapping algorithm name to its matching array. The
    'static_oracle_dict_matching' entry is present only when oracle_dict is given.
    '''
    # Call order is fixed on purpose: ranking and static_degrees both draw from
    # the global NumPy RNG, so reordering would change results under a fixed seed.
    results = {
        'offline_matching': max_matching(G, perm=perm),
        'ranking_matching': ranking(G, perm=perm),
        'static_degrees_matching': static_degrees(G, degrees_u, perm=perm),
    }
    if oracle_dict is None:
        return results
    results['static_oracle_dict_matching'] = static_with_oracle_dict(G, oracle_dict, degrees_u, perm=perm)
    return results

def get_matching_size(matching):
    '''Return the number of matched entries (entries != -1) in a matching.

    Accepts any array-like. Previously a plain list compared `list != -1` as a
    single boolean and always reported 1.
    '''
    return int(np.count_nonzero(np.asarray(matching) != -1))