import networkx as nx
import numpy as np
import time
import networks
import random_walk



def exact_eff_res(G, u, v):
    """Return the exact effective resistance between nodes ``u`` and ``v``.

    The value is computed from the Moore-Penrose pseudoinverse ``L+`` of the
    graph Laplacian ``L = D - A`` as
    ``L+[u, u] - L+[u, v] - L+[v, u] + L+[v, v]``.

    Args:
        G: A networkx graph.
        u: Label of the first node (must be a node of ``G``).
        v: Label of the second node (must be a node of ``G``).

    Returns:
        The effective resistance between ``u`` and ``v`` as a float.

    Raises:
        KeyError: If ``u`` or ``v`` is not a node of ``G``.
    """
    # Fix the row/column order explicitly and translate node labels into
    # matrix positions; labels need not be the integers 0..n-1 nor appear in
    # sorted order in G.
    nodes = list(G.nodes)
    position = {node: ix for ix, node in enumerate(nodes)}
    iu = position[u]
    iv = position[v]

    A = nx.to_numpy_array(G, nodelist=nodes)
    L = np.diag(np.sum(A, axis=1)) - A
    # L is singular (rows sum to zero), so the pseudoinverse is used.
    Linv = np.linalg.pinv(L)
    true_eff_res = Linv[iu, iu] - Linv[iu, iv] - Linv[iv, iu] + Linv[iv, iv]
    return true_eff_res


def _advance_walks(walks):
    """Advance every walk by one step and group walk indices by yielded value.

    Returns a dict mapping each value produced by ``next(walk)`` to the list
    of positions (in ``walks``) of the walks that produced it.
    """
    distr = {}
    for ix, walk in enumerate(walks):
        distr.setdefault(next(walk), []).append(ix)
    return distr


def estimate_local_eff_res(G, u, v, delete=False, num_random_walks=1000,
        max_len=100, verbose=False):
    """Estimate the effective resistance between ``u`` and ``v`` by sampling.

    ``num_random_walks`` walks are started from ``u`` and as many from ``v``
    (via ``random_walk.random_walk``).  At every step each remaining walk is
    advanced once, and the running estimate is increased by

        (xuu / deg_u - xuv / deg_v - xvu / deg_u + xvv / deg_v) / num_random_walks

    where ``xab`` is the number of walks started at ``a`` that yielded ``b``
    at this step.

    NOTE(review): whether the first value yielded by a walk is its start node
    depends on ``random_walk.random_walk`` -- confirm against that module.

    Args:
        G: Graph providing ``G.degree(node)``; also passed to
            ``random_walk.random_walk``.
        u: First node.  Its degree must be non-zero (it is used as a divisor).
        v: Second node.  Its degree must be non-zero (it is used as a divisor).
        delete: If true, after each step walks from ``u`` and walks from ``v``
            that yielded the same node at that step are paired off one-to-one
            and both members of each pair are dropped.  The loop ends early
            once no walks remain.
        num_random_walks: Number of walks started from each endpoint.
        max_len: Maximum number of steps taken by each walk.
        verbose: If true, print the running estimate after every step.

    Returns:
        A tuple ``(eff_res, num_samples)``: the accumulated estimate and the
        total number of walk steps drawn (calls to ``next`` on a walk).

    Raises:
        ValueError: If ``num_random_walks`` is not positive.
    """
    # Checked first: the estimator divides by num_random_walks, so zero would
    # fail with an opaque ZeroDivisionError and a negative count would
    # silently return an estimate of 0.
    if num_random_walks <= 0:
        raise ValueError("num_random_walks must be a positive integer")

    deg_u = G.degree(u)
    deg_v = G.degree(v)

    u_rnd_walks = [random_walk.random_walk(G, u) for _ in range(num_random_walks)]
    v_rnd_walks = [random_walk.random_walk(G, v) for _ in range(num_random_walks)]

    eff_res = 0.0
    num_samples = 0
    for i in range(max_len):
        # All u-walks are advanced before any v-walk, so the order in which
        # the walks consume randomness is fixed.
        u_distr = _advance_walks(u_rnd_walks)
        v_distr = _advance_walks(v_rnd_walks)
        num_samples += len(u_rnd_walks) + len(v_rnd_walks)

        xuv = len(u_distr.get(v, ()))
        xuu = len(u_distr.get(u, ()))
        xvu = len(v_distr.get(u, ()))
        xvv = len(v_distr.get(v, ()))

        # Always normalised by the initial number of walks, even after some
        # walks have been removed by `delete`.
        y = (xuu / deg_u - xuv / deg_v - xvu / deg_u + xvv / deg_v) / num_random_walks
        eff_res += y
        if verbose:
            print(f"len={i: >4}: {eff_res:.5f} ({len(u_rnd_walks)})")

        if delete:
            u_rem_ix = set()
            v_rem_ix = set()
            for w, u_ixs in u_distr.items():
                # zip stops at the shorter list, so the same number of walks
                # is removed from each side.
                for u_ix, v_ix in zip(u_ixs, v_distr.get(w, ())):
                    u_rem_ix.add(u_ix)
                    v_rem_ix.add(v_ix)

            # Single filtering pass instead of repeated `del`, each of which
            # is linear in the list length.
            u_rnd_walks = [walk for ix, walk in enumerate(u_rnd_walks)
                           if ix not in u_rem_ix]
            v_rnd_walks = [walk for ix, walk in enumerate(v_rnd_walks)
                           if ix not in v_rem_ix]

        # Internal invariant: walks are only ever removed in pairs.
        assert len(u_rnd_walks) == len(v_rnd_walks)
        if not u_rnd_walks:
            break

    return eff_res, num_samples

