import random
import numpy as np
from collections import deque


def random_walk(G, source_node, walk_length):
    """
    simple (uniform) random walk starting at 'source_node'
    args:
        G (nx.Graph): input graph; only 'G.neighbors(node)' is used
        source_node (int): starting node id
        walk_length (int): maximum number of nodes in the walk
    return:
        walk (list): sequence of node id's. It always starts with
            'source_node' and is shorter than 'walk_length' only when the
            walk reaches a node that has no neighbors.
    """
    walk = [source_node]

    while len(walk) < walk_length:
        neighbors = list(G.neighbors(walk[-1]))
        if not neighbors:
            # Dead end (e.g. an isolated source node): random.choice raises
            # IndexError on an empty sequence, so stop the walk here instead.
            break
        walk.append(random.choice(neighbors))
    return walk

def random_walk_plus(G, source_node, walk_length, similarity_matrix, degree=1, prob=0.6):
    """
    random walk with the bias of similarity matrix (signals)
    args:
        G (nx.Graph): input graph
        source_node (int): starting node id
        walk_length (int): maximum number of nodes in the walk
        similarity_matrix (np.array): matrix keeping node signal relations
        degree (int): not used by this function; kept so the signature stays
            parallel to 'biased_random_walk_plus'
        prob (float [0,1]): probability of using signal bias at a step
    return:
        walk (list): sequence of node id's. It is shorter than 'walk_length'
            only when the walk reaches a node that has no neighbors.
    """
    walk = [source_node]

    while len(walk) < walk_length:
        current_node = walk[-1]
        neighbors = list(G.neighbors(current_node))
        if not neighbors:
            # Dead end: indexing/choosing from an empty list would raise
            # IndexError, so stop the walk here instead.
            break

        if len(neighbors) == 1:
            # Only one possible move; no coin flip needed.
            next_node = neighbors[0]
        elif np.random.rand() < prob:
            # Signal-biased step: neighbors weighted by similarity to the
            # current node.
            weights = get_signal_bias_random(current_node, neighbors, similarity_matrix)
            next_node = random.choices(neighbors, weights=weights, k=1)[0]
        else:
            # Plain uniform step.
            next_node = random.choice(neighbors)

        walk.append(next_node)
    return walk

def biased_random_walk(G, source_node, walk_length, p, q):
    """
    A random walk variant, used in 'node2vec: Scalable Feature Learning for Networks' paper
    args:
        G (nx.Graph): input graph
        source_node (int): starting node id
        walk_length (int): maximum number of nodes in the walk
        p (float): return parameter; stepping back to the previous node has
            unnormalized weight 1/p
        q (float): in-out parameter; stepping to a node that is not adjacent
            to the previous node has unnormalized weight 1/q
    return:
        walk (list): sequence of nodes. It is shorter than 'walk_length' only
            when the walk reaches a node that has no neighbors.
    """
    walk = [source_node]

    while len(walk) < walk_length:
        current_node = walk[-1]
        neighbors = list(G.neighbors(current_node))
        if not neighbors:
            # Dead end: random.choice / random.choices cannot pick from an
            # empty list, so stop the walk here instead of raising.
            break

        if len(walk) == 1:
            # No previous node yet, so the first step is uniform.
            next_node = random.choice(neighbors)
        else:
            prev_node = walk[-2]
            weights = []
            for neighbor in neighbors:
                if neighbor == prev_node:
                    weights.append(1 / p)      # return to where we came from
                elif G.has_edge(prev_node, neighbor):
                    weights.append(1)          # stays adjacent to prev_node
                else:
                    weights.append(1 / q)      # moves away from prev_node

            # Normalize once; the total is loop-invariant.
            total = sum(weights)
            probabilities = [w / total for w in weights]
            next_node = random.choices(neighbors, weights=probabilities, k=1)[0]
        walk.append(next_node)
    return walk

def biased_random_walk_plus(G, source_node, walk_length, p, q, similarity_matrix, prob=0.5, degree=1):
    """
    biased random walk with the bias of similarity matrix (signals)
    args:
        G (nx.Graph): input graph
        source_node (int): starting node id
        walk_length (int): maximum number of nodes in the walk
        p (float): return parameter; stepping back to the previous node has
            unnormalized weight 1/p
        q (float): in-out parameter; stepping to a node that is not adjacent
            to the previous node has unnormalized weight 1/q
        similarity_matrix (np.array): matrix keeping node signal relations
        prob (float [0,1]): probability of using signal bias at a step
        degree (1 or 2): degree of temporal bias (1hop or 2hop); any value
            other than 1 selects the 2hop bias
    return:
        walk (list): sequence of node id's. It is shorter than 'walk_length'
            only when the walk reaches a node that has no neighbors.
    """
    walk = [source_node]

    while len(walk) < walk_length:
        current_node = walk[-1]
        neighbors = list(G.neighbors(current_node))
        if not neighbors:
            # Dead end: random.choice / random.choices cannot pick from an
            # empty list, so stop the walk here instead of raising.
            break

        if len(neighbors) == 1 or len(walk) == 1:
            # Single option, or no previous node yet: uniform step.
            next_node = random.choice(neighbors)

        else:
            prev_node = walk[-2]
            probabilities = []
            # Positions (in 'neighbors') of nodes that move away from
            # prev_node; only these receive the signal bias.
            out_node_idxs = []
            for i, neighbor in enumerate(neighbors):
                if neighbor == prev_node:
                    probabilities.append(1 / p)
                elif G.has_edge(prev_node, neighbor):
                    probabilities.append(1)
                else:
                    probabilities.append(1 / q)
                    out_node_idxs.append(i)

            if np.random.rand() < prob:
                if degree == 1:
                    weights = get_signal_bias_brn(current_node, neighbors, similarity_matrix, out_node_idxs)
                else:
                    weights = get_signal_bias_brn2(G, current_node, neighbors, similarity_matrix, out_node_idxs)
                probabilities = np.array(probabilities) * weights
                probabilities = probabilities / probabilities.sum()
            else:
                # Normalize once; the total is loop-invariant.
                total = sum(probabilities)
                probabilities = [w / total for w in probabilities]

            next_node = random.choices(neighbors, weights=probabilities, k=1)[0]

        walk.append(next_node)
    return walk
        

def get_signal_bias_random(current_node, candidate_nodes, similarity_matrix):
    """
    builds the normalized bias weights used by the signal-biased random walk

    Each candidate's weight is its similarity to 'current_node' (negative
    similarities count as 0) plus a small uniform noise term drawn from
    [0, 0.05); the resulting vector is scaled to sum to 1.
    """
    n_candidates = len(candidate_nodes)
    jitter = np.random.uniform(0, 0.05, size=n_candidates)

    similarities = similarity_matrix[current_node][candidate_nodes]
    raw_bias = np.clip(similarities, 0, None) + jitter
    return raw_bias / raw_bias.sum()


def get_signal_bias_brn(current_node, candidate_nodes, similarity_matrix, out_node_idxs):
    """
    builds the 1 hop bias weights used by the biased random walk

    Every candidate starts with weight 1. Only the candidates at positions
    'out_node_idxs' are re-weighted: their similarity to 'current_node' is
    floored at 0.05 and rescaled so that those weights sum to the number of
    out nodes.
    """
    targets = [candidate_nodes[idx] for idx in out_node_idxs]
    floored = np.clip(similarity_matrix[current_node][targets], 0.05, None)
    # Rescale so the re-weighted entries sum to their count.
    scaled = (floored / floored.sum()) * len(floored)

    bias = np.ones(len(candidate_nodes))
    bias[out_node_idxs] *= scaled
    return bias

def get_signal_bias_brn2(G, current_node, candidate_nodes, similarity_matrix, out_node_idxs):
    """
    builds the 2 hop bias weights used by the biased random walk

    Every candidate starts with weight 1. Candidates at positions
    'out_node_idxs' are scored with 'extract_info', perturbed with uniform
    noise from [0, 0.5), and rescaled so that those weights sum to the number
    of out nodes. If the scores do not sum to a positive value, the out nodes
    keep weight 1.
    """
    bias = np.ones(len(candidate_nodes))
    n_out = len(out_node_idxs)
    if n_out == 0:
        # Nothing to re-weight.
        return bias

    jitter = np.random.uniform(0, 0.5, size=n_out)
    scores = np.array(
        [
            extract_info(G, current_node, candidate_nodes[idx], similarity_matrix)
            for idx in out_node_idxs
        ],
        dtype=float,
    )
    scores = scores + jitter

    total = scores.sum()
    if total > 0:
        scaled = (scores / total) * n_out
    else:
        scaled = np.ones(n_out)

    bias[out_node_idxs] *= scaled
    return bias

def extract_info(G, n1, n2, sim_matrix,l=0.5):
    """
    scores n2 from n1's point of view using 1hop and 2hop similarity

    The score is the similarity between n1 and n2 (negative values count as
    0). If n2 has neighbors that are neither n1 nor neighbors of n1, the mean
    of n1's similarity to those nodes (negatives again counted as 0), scaled
    by 'l', is added.
    """
    first_hop = set(G.neighbors(n1))
    via_n2 = set(G.neighbors(n2)) - {n1}
    # Nodes reachable through n2 that n1 is not already adjacent to.
    second_hop = list(via_n2 - first_hop)

    direct = np.clip(sim_matrix[n1, n2], 0, None)
    if not second_hop:
        return direct

    second_hop_mean = np.clip(sim_matrix[n1, second_hop], 0, None).mean()
    return direct + l * second_hop_mean