import numpy as np
from walks import *


def generate_walks(G, number_of_walks_per_node, walk_length, walk_type, similarity_matrix, ratio=0.8, **kwargs):
    """
    Generates walks for input graph

    params:
        G (nx.Graph): input graph. Only ``G.nodes()`` is read here; the graph
            itself is forwarded to 'get_walk'.
        number_of_walks_per_node (int): number of walks generated for every
            selected start node
        walk_length (int): forwarded to 'get_walk'
        walk_type (string): one of "random", "random_plus", "brn" (biased
            random walk) or "brn_plus". The '_plus' variants are aided by
            relations of node signals and require 'similarity_matrix'.
        similarity_matrix (NxN or None): corr, cov, etc. matrix carrying the
            node signal (feature) similarities. It is only used by the '_plus'
            walk types and may be None otherwise.
        ratio (float [0,1]): probability that each node is used as a start
            node; this is used to prevent overfitting. ratio <= 0 selects no
            nodes and ratio >= 1 selects every node.
        **kwargs: forwarded to 'get_walk':
            p, q: parameters for biased random walk ("brn", "brn_plus")
            prob: used by "random_plus" and "brn_plus"
            degree: used by "brn_plus"
    return:
        walks (dict): a dict (keys: selected nodes, values: list of walks
            starting from the key node)
    raises:
        ValueError: propagated from 'get_walk' for an unknown walk_type, or
            when a '_plus' walk type is requested with similarity_matrix=None
    """
    nodes = G.nodes()
    print(f"About walk: {walk_type}, {list(kwargs.items())}")

    walks = {}
    for node in nodes:
        # Randomly sub-sample start nodes. np.random.rand() lies in [0, 1), so
        # '>=' makes ratio=0 select no node and ratio=1 select every node.
        if np.random.rand() >= ratio:
            continue
        walks[node] = [
            get_walk(G, node, walk_length, walk_type, similarity_matrix, **kwargs)
            for _ in range(number_of_walks_per_node)
        ]
    return walks


def get_walk(G, node, walk_length, walk_type, similarity_matrix, **kwargs):
    """
    Build one walk of kind 'walk_type' that starts at 'node'.

    Helper of 'generate_walks'. Supported walk types and the kwargs each reads:
        "random":      none
        "random_plus": prob (default 0.5); needs similarity_matrix
        "brn":         p (default 1.0), q (default 0.01)
        "brn_plus":    p, q, prob (default 0.5), degree (default 1);
                       needs similarity_matrix

    raises:
        ValueError: if walk_type is not one of the above, or if a '_plus'
            type is requested while similarity_matrix is None
    """
    # Validate up front so each walk kind below can simply return.
    if walk_type not in ("random", "random_plus", "brn", "brn_plus"):
        raise ValueError("Invalid walk type!")
    if walk_type in ("random_plus", "brn_plus") and similarity_matrix is None:
        raise ValueError("'similarity_matrix' is required!")

    if walk_type == "random":
        return random_walk(G, node, walk_length)

    if walk_type == "random_plus":
        return random_walk_plus(G, node, walk_length, similarity_matrix, kwargs.get("prob", 0.5))

    # Remaining kinds are the biased ("brn") walks, which share p and q.
    return_p = kwargs.get("p", 1.0)
    in_out_q = kwargs.get("q", 0.01)
    if walk_type == "brn":
        return biased_random_walk(G, node, walk_length, return_p, in_out_q)

    return biased_random_walk_plus(
        G,
        node,
        walk_length,
        return_p,
        in_out_q,
        similarity_matrix,
        kwargs.get("prob", 0.5),
        degree=kwargs.get("degree", 1),
    )