import numba
import numpy as np
import random
import copy

"""
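Computes node embeddings from p-norm flow diffusion on an undirected, unweighted
graph G. Extensions to weighted graphs are straightforward.
Note: `pnormdiffusion` below takes the graph as `adjmat`/`degree`/`nv` and the
seeds as `nodes` (initial seed mass `delta[0] * degree[seed]`) rather than as
`G`/`seedset`.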
Inputs:
            G - Adjacency list representation of graph.
                Node indices must start from 1 and end with n, where n is the
                number of nodes in G.
      seedset - Dictionary specifying seed node(s) and seed mass as
                (key, value) pairs.
            p - Specifies the p-norm in primal p-norm flow objective.
           mu - Smoothing parameter that smooths the p-norm objective.
                Setting mu=0 (default) suffices for almost all practical
                purposes.
    max_iters - Maximum number of passes for Random Permutation Coordinate
                Minimization. A single pass goes over all nodes that violate
                KKT conditions.
      epsilon - Tolerance on the maximum excess mass on nodes (which is
                equivalent to the maximum primal infeasibility).
                Diffusion process terminates whenever the excess mass is no
                greater than epsilon on all nodes.
       cm_tol - Tolerance in the approximate coordinate minimization step
                (required only for p > 2).
                Approximate coordinate minimization in this diffusion setting
                is equivalent to performing inexact line-search on coordinate
                descent stepsizes.
Returns:
            x - Node embeddings.
                Applying sweepcut on x produces the final output cluster.
"""

@numba.njit(cache=True, locals={'best_cut': numba.int64[::1]})
def cut(num_all_nodes, adjmat, degrees, x_indices, x_vals):
    """Sweep-cut every diffusion embedding and keep the lowest-conductance prefix.

    For each seed position ``i`` and level ``h``, the nodes in
    ``x_indices[h][i]`` are ordered by decreasing ``x_vals[h][i]`` and added to
    a growing set one at a time. The prefix with the smallest conductance
    (strictly below 1.0) is kept.

    Args:
        num_all_nodes: number of nodes in the graph. At most
            ``num_all_nodes - 1`` nodes are swept, presumably so the swept set
            never covers the whole graph.
        adjmat: 2-D array; row ``v`` lists the neighbours of ``v`` in its first
            ``degrees[v]`` entries.
        degrees: node degrees, one entry per node.
        x_indices: ``x_indices[h][i]`` - node ids with a stored embedding value.
        x_vals: ``x_vals[h][i]`` - the matching embedding values.

    Returns:
        ``best_cuts[h][i]``: int64 array with the node ids of the best prefix.
        It is empty when no prefix reaches conductance below 1.0.
    """
    vol_graph = np.sum(degrees)
    hierarchies = len(x_indices)
    assert hierarchies > 0

    num_nodes = len(x_indices[0])
    best_cuts = [[np.zeros(0, dtype=np.int64) for _ in range(num_nodes)] for _ in range(hierarchies)]

    # Membership mask for the current sweep set. It gives O(1) lookups instead
    # of scanning a list. After every sweep only the entries that were set are
    # cleared, so the reset cost follows the embedding's support, not the graph size.
    in_cut = np.zeros(len(degrees), dtype=np.bool_)

    for i in range(num_nodes):
        for h in range(hierarchies):
            assert x_indices[h][i].shape == x_vals[h][i].shape

            num_nonzero = len(x_vals[h][i])
            order = np.argsort(x_vals[h][i])[::-1]
            indices = x_indices[h][i][order][0: min(num_nonzero, num_all_nodes - 1)]

            # The best cut is always a prefix of `indices`, so only its length
            # needs to be remembered.
            best_len = 0
            best_conductance = 1.
            cur_vol = 0
            cut_size = 0

            for k in range(indices.shape[0]):
                v = indices[k]
                in_cut[v] = True
                cur_vol += degrees[v]

                # Edges to nodes already in the set stop being cut edges.
                # Edges to outside nodes become new cut edges.
                for n in adjmat[v, 0:degrees[v]]:
                    if in_cut[n]:
                        cut_size -= 1
                    else:
                        cut_size += 1

                denom = min(cur_vol, vol_graph - cur_vol)
                if denom <= 0:
                    # Conductance is undefined for a zero-volume side
                    # (e.g. only isolated nodes swept so far); skip it.
                    continue
                cur_conductance = cut_size / denom
                if cur_conductance < best_conductance:
                    best_len = k + 1
                    best_conductance = cur_conductance

            for v in indices:
                in_cut[v] = False

            best_cut = indices[0:best_len].astype(np.int64)
            assert best_cut.shape[0] <= x_indices[h][i].shape[0]

            best_cuts[h][i] = best_cut

    return best_cuts



@numba.njit(cache=True, locals={'mass': numba.float64[::1]})
def pnormdiffusion(adjmat,
                   degree,
                   nv,
                   nodes,
                   delta,
                   p,
                   mu: float = 0, max_iters: int = 50,
                   epsilon=1e-2, cm_tol=1e-3, top_k: int = None, steps: int = 1,
                   limit_type: str = 'degree'):
    """Run hierarchical p-norm flow diffusion from every seed node.

    For each seed in ``nodes``, the diffusion runs once per level
    ``h = 0 .. len(delta) - 1``. Level 0 starts with ``delta[0] * degree[seed]``
    units of mass on the seed. Each later level reuses the mass array left by
    the previous level. It passes ``delta[h] / delta[h - 1]`` to the solver,
    which scales ``mass`` in place by that factor.

    Args:
        adjmat: 2-D array; row ``v`` lists the neighbours of ``v`` in its first
            ``degree[v]`` entries.
        degree: node degrees.
        nv: number of nodes (length of the mass and limit arrays).
        nodes: array of seed node ids.
        delta: per-level seed-mass multipliers; its length is the number of levels.
        p: p-norm. ``p == 2`` uses ``l2opt`` and ``p > 2`` uses ``lpopt``.
        mu: smoothing parameter forwarded to ``lpopt``.
        max_iters: maximum number of passes per solve.
        epsilon: a node is active while ``mass > limit + epsilon``.
        cm_tol: line-search tolerance forwarded to ``lpopt``.
        top_k: if given, keep only the ``top_k`` nodes with the largest
            residual mass in ``mass_js`` / ``mass_vals``.
        steps: unused; kept for interface compatibility.
        limit_type: per-node mass capacity. ``'degree'`` uses the degree,
            ``'flat'`` uses the mean degree everywhere, and ``'log_degree'``
            uses ``log(degree) + 1`` with isolated nodes treated as degree 1.

    Returns:
        ``(x_js, x_vals, mass_js, mass_vals)``, each indexed ``[h][i]`` for
        level ``h`` and seed position ``i``. They hold the support and values
        of the embedding ``x`` and of the residual mass.

    Raises:
        NotImplementedError: unknown ``limit_type``.
        Exception: ``p < 2``.
    """
    hierarchies = len(delta)
    x_js = [[np.zeros(0, dtype=np.int64) for _ in range(len(nodes))] for _ in range(hierarchies)]
    x_vals = [[np.zeros(0, dtype=np.float64) for _ in range(len(nodes))] for _ in range(hierarchies)]
    mass_js = [[np.zeros(0, dtype=np.int64) for _ in range(len(nodes))] for _ in range(hierarchies)]
    mass_vals = [[np.zeros(0, dtype=np.float64) for _ in range(len(nodes))] for _ in range(hierarchies)]

    if limit_type == 'degree':
        limits = degree.astype(np.float64)
    elif limit_type == 'flat':
        limits = np.mean(degree) * np.ones(nv, dtype=np.float64)
    elif limit_type == 'log_degree':
        # Clamp degrees to >= 1. Otherwise log(0) = -inf would mark every
        # isolated node as permanently over its limit, even though such a node
        # has no neighbours to push mass to.
        limits = np.log(np.maximum(degree, 1).astype(np.float64)) + 1.
    else:
        raise NotImplementedError

    for i, seed_node in enumerate(nodes):
        for h in range(hierarchies):

            if h == 0:
                mass = np.zeros(nv, dtype=np.float64)
                mass[seed_node] = delta[0] * degree[seed_node]

            # Level 0 keeps the initial mass as-is. Later levels rescale the
            # residual mass left by the previous level.
            update = 1.0 if h == 0 else (delta[h] / delta[h - 1])

            if p == 2:
                x_j, x_val, mass_j, mass_val = l2opt(mass, update, adjmat, degree, max_iters, epsilon, limits)
            elif p > 2:
                x_j, x_val, mass_j, mass_val = lpopt(mass, update, adjmat, degree, max_iters, epsilon, p, mu,
                                                     cm_tol, limits)
            else:
                raise Exception("p should be >= 2.")

            x_js[h][i] = x_j
            x_vals[h][i] = x_val
            mass_js[h][i] = mass_j
            mass_vals[h][i] = mass_val

            if top_k is not None:
                # Keep the top_k largest masses. Slice from len - top_k
                # (clamped at 0) rather than [-top_k:], so that top_k == 0
                # keeps nothing instead of everything.
                order = np.argsort(mass_vals[h][i])
                start = max(order.shape[0] - top_k, 0)
                top_k_ind = order[start:]
                mass_js[h][i] = mass_js[h][i][top_k_ind]
                mass_vals[h][i] = mass_vals[h][i][top_k_ind]

        # Progress report, printed when i % 100 == 1.
        if i % 100 == 1:
            print(int(i / nodes.shape[0] * 100), "% done")

    return x_js, x_vals, mass_js, mass_vals


@numba.njit(cache=True, locals={'push': numba.float64, 'd': numba.float64})
def l2opt(mass, mass_update, adjmat, degree, max_iters, epsilon, limits):
    """Closed-form coordinate steps for p = 2 flow diffusion.

    ``mass`` is scaled IN PLACE by ``mass_update`` and then updated in place.
    Callers reuse the final array as the starting mass of the next level.
    Each pass takes every node whose mass exceeds ``limits + epsilon`` at the
    start of the pass and visits them in random order. Each visited node pushes
    its excess above ``limits[v]`` equally to its neighbours and adds the
    per-edge push to ``x[v]``.

    Args:
        mass: float64 node-mass array (modified in place).
        mass_update: scalar factor applied to ``mass`` before diffusing.
        adjmat: row ``v`` lists ``v``'s neighbours in its first ``degree[v]`` entries.
        degree: node degrees.
        max_iters: maximum number of passes.
        epsilon: activity tolerance above ``limits``.
        limits: per-node mass capacity.

    Returns:
        ``(x_idx, x_val, mass_idx, mass_val)``: the indices and values where
        ``x > 0`` and where ``mass > 0``.
    """
    x = np.zeros_like(mass, dtype=np.float64)
    mass *= mass_update
    # Passes are sequentially dependent; plain range (prange without
    # parallel=True was already executed serially).
    for _ in range(max_iters):
        T = np.where(mass > limits + epsilon)[0]

        if len(T) == 0:
            break

        np.random.shuffle(T)

        for v in T:
            deg_v = degree[v]
            if deg_v == 0:
                # An isolated node cannot send its excess anywhere, and the
                # per-edge push below would divide by zero.
                continue

            push = (mass[v] - limits[v]) / deg_v
            x[v] += push

            mass[v] = limits[v]

            for u in adjmat[v, 0:deg_v]:
                mass[u] += push

    x_support = np.where(x > 0)[0].astype(np.int64)
    mass_support = np.where(mass > 0)[0].astype(np.int64)
    return x_support, x[x_support], mass_support, mass[mass_support]


@numba.njit(cache=True, locals={'x': numba.float64[::1], 'x_v_prev': numba.float64,
                                'd': numba.float64, 'q': numba.float64})
def lpopt(mass, mass_update, adjmat, degree, max_iters, epsilon, p, mu, cm_tol, limits):
    """Randomized coordinate minimization for p > 2 flow diffusion.

    ``mass`` is scaled IN PLACE by ``mass_update``; the scaled copy is kept as
    the source mass ``mass0``. Each pass takes every node whose mass exceeds
    ``limits + epsilon`` at the start of the pass and visits them in random
    order. For each visited node ``v``, ``x[v]`` is raised by an inexact line
    search (``push_node_v``), then the masses of ``v``'s neighbours are
    corrected for the change in flow (``update_mass_u``).

    Args:
        mass: float64 node-mass array (modified in place).
        mass_update: scalar factor applied to ``mass`` before diffusing.
        adjmat: row ``v`` lists ``v``'s neighbours in its first ``degree[v]`` entries.
        degree: node degrees.
        max_iters: maximum number of passes.
        epsilon: activity tolerance above ``limits``.
        p: p-norm (> 2); the dual exponent ``q = p / (p - 1)`` drives the flows.
        mu: smoothing parameter for ``flow_uv``.
        cm_tol: line-search tolerance for ``push_node_v``.
        limits: per-node mass capacity.

    Returns:
        ``(x_idx, x_val, mass_idx, mass_val)``: the indices and values where
        ``x > 0`` and where ``mass > 0``.
    """
    x = np.zeros_like(mass, dtype=np.float64)
    q = p / (p - 1)
    mass *= mass_update
    mass0 = np.copy(mass)
    # Passes are sequentially dependent; plain range (prange without
    # parallel=True was already executed serially).
    for _ in range(max_iters):
        T = np.where(mass > limits + epsilon)[0]
        if T.size == 0:
            break
        np.random.shuffle(T)

        for v in T:
            if degree[v] == 0:
                # An isolated node's mass does not depend on x[v], so the
                # line search in push_node_v would never terminate.
                continue

            x_v_prev = x[v]

            x, mass = push_node_v(x, mass, v, adjmat, degree, mass0, q, mu, cm_tol, limits)

            for u in adjmat[v, 0:degree[v]]:
                mass = update_mass_u(x, mass, u, v, x_v_prev, q, mu)

    x_support = np.where(x > 0)[0].astype(np.int64)
    mass_support = np.where(mass > 0)[0].astype(np.int64)
    return x_support, x[x_support], mass_support, mass[mass_support]


"""
Pushes out the excess mass on node v to its neighbors.
This is done by simply increasing x[v], the incumbent node embedding for node v.
"""


@numba.njit(cache=True, locals={'lo': numba.float64, 'hi': numba.float64, 'mid': numba.float64,
                                'tol': numba.float64})
def push_node_v(x, mass, v, adjmat, degree, mass0, q, mu, cm_tol, limits):
    """Raise x[v] until the net mass on v is (just above) limits[v].

    The net mass on ``v`` is non-increasing in ``x[v]``, because ``flow_uv``
    shrinks as ``x_v`` grows. The step is therefore found by exponential search
    for an upper bracket, followed by bisection to an absolute tolerance of
    ``max(cm_tol, 2 * eps)``. ``x[v]`` is set to the lower bracket, whose mass
    is still above the limit, and ``mass[v]`` is recomputed. Both arrays are
    modified in place and returned.
    """
    target = limits[v]
    lo = x[v]
    hi = lo + 1

    # Grow the bracket until x_v = hi no longer leaves excess on v.
    while compute_mass_v(x, v, hi, adjmat, degree, mass0, q, mu) > target:
        lo, hi = hi, hi * 2

    tol = max(cm_tol, 2 * np.finfo(hi).eps)

    # Bisect; `lo` always keeps mass above the target.
    while abs(hi - lo) > tol:
        mid = (lo + hi) / 2
        over_limit = compute_mass_v(x, v, mid, adjmat, degree, mass0, q, mu) > target
        if over_limit:
            lo = mid
        else:
            hi = mid

    x[v] = lo
    mass[v] = compute_mass_v(x, v, lo, adjmat, degree, mass0, q, mu)
    return x, mass


@numba.njit(cache=True)
def update_mass_u(x, mass, u, v, x_v_prev, q, mu):
    """Correct neighbour u's mass after x[v] moved away from x_v_prev.

    The u->v flow under the old value minus the u->v flow under the new value
    is the extra mass that u now receives from v. ``mass`` is updated in place
    and returned.
    """
    x_u = x[u]
    gained = flow_uv(x_u, x_v_prev, q, mu) - flow_uv(x_u, x[v], q, mu)
    mass[u] = mass[u] + gained
    return mass


"""
Computes the net mass on node v (i.e. initial mass plus all incoming flows).
"""


@numba.njit(cache=True, locals={'total': numba.float64})
def compute_mass_v(x, v, x_v, adjmat, degree, mass0, q, mu):
    """Source mass of v plus the flow from each neighbour, evaluated at x[v] = x_v."""
    total = mass0[v]
    neighbours = adjmat[v, 0:degree[v]]
    for k in range(neighbours.shape[0]):
        total += flow_uv(x[neighbours[k]], x_v, q, mu)
    return total


"""
Computes the amount of flow from node u to node v, given node embeddings x_u for
node u and x_v for node v.
"""


@numba.njit(cache=True)
def flow_uv(x_u, x_v, q, mu):
    """Flow from u to v for embeddings x_u, x_v under dual exponent q.

    With ``mu > 0`` the smoothed form ``(d^2 + mu^2)^(q/2 - 1) * d`` is used;
    otherwise it is ``|d|^(q - 1) * sign(d)``, where ``d = x_u - x_v``.
    """
    diff = x_u - x_v
    if mu > 0:
        return ((diff ** 2 + mu ** 2) ** (q / 2 - 1)) * diff
    return (np.abs(diff) ** (q - 1)) * np.sign(diff)
