import networkx as nx
import numpy as np
import time
import networks
import numba
import random
from numba import types, typed
from typing import List
import scipy as sp
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings

import random_walk



# Install process-wide "ignore" filters for numba's deprecation and
# pending-deprecation warnings (warnings.simplefilter prepends a global filter).
# NOTE(review): presumably these fire because Python lists are passed into the
# jitted kernels below ("reflected list" deprecation) -- confirm before relying
# on this suppression.
for _numba_warning_cls in (NumbaDeprecationWarning,
                           NumbaPendingDeprecationWarning):
    warnings.simplefilter('ignore', category=_numba_warning_cls)
del _numba_warning_cls


def exact_ht(G, u, v, sparse=False):
    """Return the exact hitting time H(u, v) of the simple random walk on G.

    From node x the walk moves to a uniformly random neighbour. An edge counts
    iff its adjacency entry is positive, so weights only matter through their
    sign. The hitting times h = (h(x))_x solve

        h(v) = 0,    h(x) = 1 + sum_y P[x, y] * h(y)   for x != v,

    with P the row-normalised 0/1 adjacency matrix. Equivalently,
    (I - P) h = 1 with the row of v replaced by the boundary condition h(v) = 0.

    Args:
        G: networkx graph; node labels may be any hashable values.
        u: start node.
        v: target node.
        sparse: if True, assemble scipy sparse matrices and use ``spsolve``;
            otherwise assemble dense numpy arrays and use
            ``numpy.linalg.solve``.

    Returns:
        The expected number of steps for a walk started at u to first reach
        v. This is only meaningful when v is reachable from u. Nodes without
        neighbours get an all-zero row in P rather than NaNs, so they do not
        poison the solve for the rest of the graph.

    Raises:
        KeyError: if u or v is not a node of G.
    """
    nodelist = list(G)
    n = len(nodelist)
    # Matrix rows follow ``nodelist`` (G's iteration order). Translate labels
    # to row indices explicitly: using the label itself as an index is only
    # right when the nodes were inserted in the order 0, 1, ..., n-1.
    index = {node: i for i, node in enumerate(nodelist)}
    iu, iv = index[u], index[v]

    b = np.ones(n)
    b[iv] = 0

    if sparse:
        # Import the submodules explicitly; ``import scipy`` alone does not
        # guarantee that ``scipy.sparse.linalg`` is loaded.
        import scipy.sparse
        import scipy.sparse.linalg

        adj = nx.to_scipy_sparse_array(G, nodelist=nodelist, format="csr")
        P = (adj > 0).astype(float).tocsr()
        # After dropping explicit zeros every stored entry is 1.0, so the
        # number of stored entries in a row (np.diff(indptr)) is its degree.
        P.eliminate_zeros()
        deg = np.diff(P.indptr)
        # CSR stores entries row by row, so repeating each row's degree once
        # per stored entry aligns the divisors with ``P.data``. Empty rows
        # contribute no entries and hence no division by zero.
        P.data /= np.repeat(deg, deg)

        A = (scipy.sparse.eye(n) - P).tolil()
        A[iv, :] = 0
        A[iv, iv] = 1
        res = scipy.sparse.linalg.spsolve(A.tocsr(), b)
    else:
        adj = nx.to_numpy_array(G, nodelist=nodelist) > 0
        deg = adj.sum(axis=1)[:, None]
        # Row-normalise; rows of nodes without neighbours stay all-zero
        # (a plain division would fill them with NaN).
        P = np.divide(adj, deg, out=np.zeros(adj.shape), where=deg > 0)

        A = np.eye(n) - P
        A[iv, :] = 0
        A[iv, iv] = 1
        res = np.linalg.solve(A, b)

    return res[iu]


"""
def ht_via_cutoff(G, u, v, ell=50, r=100):
    ht = 0
    for i in range(ell):
        # Execute r random walks of length i from v and count how many end at u and how many end at v

        count_u = 0
        count_v = 0

        for _ in range(r):
            current = v
            for _ in range(i):
                # Pick a random neighbor of current
                neighbors = list(G.neighbors(current))
                if len(neighbors) == 0:
                    assert False, "Assuming all degrees are at least 1"
                current = random.choice(neighbors)

            if current == u:
                count_u += 1
            if current == v:
                count_v += 1

        ht += ((count_v - count_u) * (2 * len(G.edges))) / (r * G.degree(v))

    return ht
"""


def ht_via_cutoff(G, u, v, num_random_walks=1000, max_len=1000):
    """Estimate the hitting time H(u, v) from truncated visit statistics.

    Implements H(u, v) = (1 / pi(v)) * sum_{t >= 0} (P^t(v, v) - P^t(u, v)),
    with pi(v) = deg(v) / sum of degrees. The sum is truncated after
    ``max_len`` steps, and each term is estimated from ``num_random_walks``
    independent walks started at u and at v (see ``count_no_deletions``).

    Args:
        G: networkx graph whose nodes are labelled 0..n-1.
        u: start node.
        v: target node.
        num_random_walks: number of walks started from each of u and v.
        max_len: number of steps simulated per walk.

    Returns:
        (h, num_samples, walltime): the estimate, the number of walk steps
        simulated, and the wall-clock seconds spent simulating. Graph
        conversion and JIT compilation are excluded from ``walltime``.
    """
    cum_deg, edges = to_numba_graph(G)
    # Hand numba contiguous int64 arrays instead of Python lists. Reflected
    # lists are deprecated in numba, cannot be typed when empty (edgeless
    # graph), and are unboxed in O(|E|) on every call, which would leak into
    # the timed region below.
    cum_deg = np.asarray(cum_deg, dtype=np.int64)
    edges = np.asarray(edges, dtype=np.int64)

    # Zero-walk call: triggers JIT compilation outside the timer.
    count_no_deletions(u, v, 0, 0, cum_deg, edges)

    start_time = time.perf_counter()
    count, num_samples = count_no_deletions(u, v, num_random_walks, max_len,
            cum_deg, edges)
    # Normalise with the degrees of the walk that was actually simulated
    # (uniform over each node's neighbour list). These agree with
    # 2*|E| / G.degree(v) on simple graphs, and stay correct when self-loops
    # or parallel edges make networkx's edge/degree counts differ.
    deg_v = int(cum_deg[v + 1] - cum_deg[v])
    h = count * len(edges) / deg_v / num_random_walks
    walltime = time.perf_counter() - start_time

    return h, num_samples, walltime


@numba.jit(nopython=True)
def count_no_deletions(u:int, v:int, num_random_walks:int, max_len:int,
        cum_deg:List[int], edges:List[int]):
    """Signed number of visits to ``v`` by independent walks from u and v.

    Simulates ``num_random_walks`` walks from ``u`` (each visit to ``v``
    counts -1), then as many from ``v`` (each visit counts +1). Every walk
    runs for ``max_len`` steps. The position is checked before each step, so
    times 0..max_len-1 are counted. The neighbours of node ``x`` are
    ``edges[cum_deg[x]:cum_deg[x + 1]]``, and the next node is drawn
    uniformly from them.

    Returns:
        (count, num_samples): the signed visit count, and the total number of
        steps simulated.
    """
    signed_visits = 0
    steps_taken = 0
    for origin, weight in ((u, -1), (v, 1)):
        for _walk in range(num_random_walks):
            node = origin
            for _step in range(max_len):
                steps_taken += 1
                if node == v:
                    signed_visits += weight
                first = cum_deg[node]
                last = cum_deg[node + 1] - 1
                node = edges[random.randint(first, last)]

    return signed_visits, steps_taken


"""
def sampling_ht(G, u, v, num_random_walks=10000):
    ts = []
    num_samples = 0

    for _ in range(num_random_walks):
        rnd_walk = random_walk.random_walk(G, u)
        t = 0

        while True:
            if next(rnd_walk) == v:
                ts.append(t)
                break
            else:
                t += 1
                num_samples += 1

    return np.mean(ts), num_samples
"""


def sampling_ht(G, u, v, num_random_walks=10000):
    """Estimate H(u, v) as the mean first-arrival time of walks from u to v.

    Args:
        G: networkx graph whose nodes are labelled 0..n-1.
        u: start node.
        v: target node.
        num_random_walks: number of independent walks to average over.

    Returns:
        (h, num_samples, walltime): the mean arrival time, the total number of
        steps simulated, and the wall-clock seconds spent simulating. Graph
        conversion and JIT compilation are excluded from ``walltime``.

    Note:
        Each walk runs until it reaches v (``mean_arrival_time`` has no step
        cap), so this does not terminate if v is unreachable from u.
    """
    cum_deg, edges = to_numba_graph(G)
    # Pass int64 arrays rather than Python lists. Reflected lists are
    # deprecated in numba, cannot be typed when empty, and are unboxed in
    # O(|E|) on each call (which would be included in the timing).
    cum_deg = np.asarray(cum_deg, dtype=np.int64)
    edges = np.asarray(edges, dtype=np.int64)

    # Zero-walk call: triggers JIT compilation outside the timer.
    mean_arrival_time(u, v, 0, cum_deg, edges)

    start_time = time.perf_counter()
    h, num_samples = mean_arrival_time(u, v, num_random_walks, cum_deg, edges)
    walltime = time.perf_counter() - start_time

    return h, num_samples, walltime


@numba.jit(nopython=True)
def mean_arrival_time(u:int, v:int, num_random_walks:int, cum_deg:List[int],
        edges:List[int]):
    """Average number of steps for a walk from ``u`` to first reach ``v``.

    Runs ``num_random_walks`` independent walks, each until it reaches ``v``.
    There is no step limit, so this loops forever if ``v`` is unreachable.
    The neighbours of node ``x`` are ``edges[cum_deg[x]:cum_deg[x + 1]]``,
    and the next node is drawn uniformly from them.

    Returns:
        (mean, num_samples): the average arrival time (accumulated as
        steps / num_random_walks per walk), and the total number of steps
        over all walks.
    """
    mean_steps = 0
    total_steps = 0

    for _walk in range(num_random_walks):
        node = u
        steps = 0
        while node != v:
            first = cum_deg[node]
            last = cum_deg[node + 1] - 1
            node = edges[random.randint(first, last)]
            steps += 1

        total_steps += steps
        mean_steps += steps / num_random_walks

    return mean_steps, total_steps


def to_numba_graph(G):
    """Flatten G into CSR-style adjacency lists for the jitted kernels.

    The neighbours of node ``x`` are ``edges[cum_deg[x]:cum_deg[x + 1]]``, in
    the order produced by ``G.neighbors(x)``. The kernels index ``cum_deg``
    by node label, so nodes are visited in label order 0, 1, ..., n-1 rather
    than in G's insertion order. With insertion order, a graph built as e.g.
    ``add_edge(0, 2); add_edge(2, 1)`` would silently get node 2's offsets
    in node 1's slot.

    Args:
        G: graph exposing ``len(G)``, ``x in G`` and ``G.neighbors(x)``
            (e.g. a networkx graph) whose nodes are exactly 0..n-1.

    Returns:
        (cum_deg, edges) as Python lists: ``cum_deg`` has n + 1 offsets
        starting at 0, and ``edges`` holds the concatenated neighbour lists.

    Raises:
        ValueError: if the nodes are not labelled 0..n-1.
    """
    n = len(G)
    cum_deg = [0]
    edges = []
    for x in range(n):
        # With n nodes, all of 0..n-1 being present means the labels are
        # exactly 0..n-1.
        if x not in G:
            raise ValueError(
                "to_numba_graph expects nodes labelled 0..n-1; "
                f"node {x!r} is missing")
        edges.extend(G.neighbors(x))
        cum_deg.append(len(edges))
    return cum_deg, edges


def estimate_local_ht(G, u, v, num_random_walks=10000, max_len=100000,
        verbose=False):
    """Estimate H(u, v) with coupled walks that are deleted on collision.

    Uses the same identity as ``ht_via_cutoff``:
    H(u, v) = (1 / pi(v)) * sum_t (P^t(v, v) - P^t(u, v)).
    The signed visit count comes from ``count_with_deletions``, which removes
    a walk from u and a walk from v once they land on the same node in the
    same step.

    Args:
        G: networkx graph whose nodes are labelled 0..n-1.
        u: start node.
        v: target node.
        num_random_walks: number of walks started from each of u and v.
        max_len: maximum number of steps per walk.
        verbose: forwarded to ``count_with_deletions``.

    Returns:
        (h, num_samples, walltime): the estimate, the number of walk steps
        simulated, and the wall-clock seconds spent simulating. Graph
        conversion and JIT compilation are excluded from ``walltime``.
    """
    # preparing graph representation and compiling once: we do not count this
    # towards the running time
    cum_deg, edges = to_numba_graph(G)
    # int64 arrays instead of Python lists. Reflected lists are deprecated in
    # numba, cannot be typed when empty, and are unboxed in O(|E|) per call.
    cum_deg = np.asarray(cum_deg, dtype=np.int64)
    edges = np.asarray(edges, dtype=np.int64)
    # The warm-up call below compiles for a boolean ``verbose``. Coerce it so
    # a truthy non-bool argument does not trigger a recompile inside the
    # timed call.
    verbose = bool(verbose)
    count_with_deletions(u, v, 0, 0, cum_deg, edges, False)

    start_time = time.perf_counter()
    count, num_samples = count_with_deletions(u, v, num_random_walks, max_len,
            cum_deg, edges, verbose=verbose)

    # Normalise with the degrees of the walk that was actually simulated
    # (uniform over each node's neighbour list). These match
    # 2*|E| / G.degree(v) on simple graphs, and stay correct with self-loops
    # or parallel edges.
    deg_v = int(cum_deg[v + 1] - cum_deg[v])
    h = count * len(edges) / deg_v / num_random_walks
    walltime = time.perf_counter() - start_time

    return h, num_samples, walltime


@numba.jit(nopython=True)
def count_with_deletions(u:int, v:int, num_random_walks:int, max_len:int,
        cum_deg:List[int], edges:List[int], verbose):
    """Signed visit count to ``v`` from coupled walks deleted on collision.

    Runs ``num_random_walks`` walks from ``u`` and as many from ``v`` in
    lockstep for at most ``max_len`` steps. The neighbours of node ``x`` are
    ``edges[cum_deg[x]:cum_deg[x + 1]]``, and each step picks one uniformly.

    Before each step, every live walk currently sitting at ``v`` contributes
    -1 (walks from ``u``) or +1 (walks from ``v``). After each step, a walk
    from ``v`` that lands on a node where a still-unpaired walk from ``u``
    landed in the same step is paired with it, and both are removed. From
    that point the two walks would continue from the same node with the same
    remaining horizon, so their future contributions cancel in expectation.
    Removing them leaves the expected count unchanged while simulating fewer
    steps.

    Args:
        u: start node of the negatively counted walks.
        v: target node, and start node of the positively counted walks.
        num_random_walks: number of walks started from each of u and v.
        max_len: maximum number of steps per walk.
        cum_deg: adjacency offsets (as built by ``to_numba_graph``).
        edges: concatenated neighbour lists (as built by ``to_numba_graph``).
        verbose: if true and all ``max_len`` steps elapse without every walk
            being paired, print how many walks from u remained unpaired.

    Returns:
        (count, num_samples): the signed number of visits to v, and the total
        number of walk steps simulated.
    """

    # Current node of every walk, indexed by walk id.
    u_rnd_walks = [u for _ in range(num_random_walks)]
    v_rnd_walks = [v for _ in range(num_random_walks)]

    # Ids of walks that have not been paired (deleted) yet.
    u_alive = set(range(num_random_walks))
    v_alive = set(range(num_random_walks))

    count = 0
    num_samples = 0
    for t in range(max_len):
        # Per-step buckets of u-walk ids keyed by the node they landed on:
        # u_pos maps node -> slot in u_pos_list. Both start with a dummy
        # entry (key -1 -> slot 0), presumably so numba can infer the
        # element types; -1 is presumably never a node index.
        u_pos_list = [typed.List.empty_list(types.int64)]
        u_pos = {-1: 0}

        for u_ix in u_alive:
            w = u_rnd_walks[u_ix]
            # u-walk sitting at v before this step: negative contribution.
            if w == v:
                count -= 1

            num_samples += 1
            ix = random.randint(cum_deg[w], cum_deg[w + 1] - 1)
            w = edges[ix]
            u_rnd_walks[u_ix] = w

            # Record where this u-walk landed in this step.
            if w not in u_pos:
                u_pos[w] = len(u_pos_list)
                u_pos_list.append(typed.List.empty_list(types.int64))
            u_pos_list[u_pos[w]].append(u_ix)

        # Iterate over a snapshot because v_alive is modified in the loop.
        for v_ix in list(v_alive):
            w = v_rnd_walks[v_ix]
            # v-walk sitting at v before this step: positive contribution.
            if w == v:
                count += 1

            num_samples += 1
            ix = random.randint(cum_deg[w], cum_deg[w + 1] - 1)
            w = edges[ix]
            v_rnd_walks[v_ix] = w

            if w in u_pos:
                coll = u_pos_list[u_pos[w]]
                if len(coll) > 0:
                    # Pair with an unpaired u-walk that landed on the same
                    # node in this step (pop: each is paired at most once)
                    # and drop both walks.
                    u_ix = coll.pop()
                    u_alive.remove(u_ix)
                    v_alive.remove(v_ix)

        # Walks are removed in pairs, so no u-walks left implies no v-walks
        # left either.
        if len(u_alive) == 0:
            break

    else:
        # for-else: runs only if the loop finished without the break above.
        if verbose:
            print("done (but " + str(len(u_alive)) + " random walks did not collide)")

    return count, num_samples

