import numpy as np
import random
from inv_sampling import inv_exp, inv_pow, inv_ray
import warnings

"""

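Function gen_cascade takes inputs
alpha -- an (nd x nd) square matrix whose (j, i)-th entry is the transmission rate from node j to node i
t -- length of the observation window; nodes not infected before time t are reported with time t
nd -- number of nodes
dist -- "exp", "pow" or "ray", the distribution to sample transmission times from
delta -- extra parameter passed only to the "pow" sampler (default 1.0)
and outputs
cas -- a sample of a single cascade: an array of nd infection times, with the source node at time 0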

Function gen_trans_matrix takes inputs
nd -- number of nodes
rt_low -- lower bound of non-zero transmission rates
rt_high -- upper bound of non-zero transmission rates
sparsity -- level of sparsity, default is 0.9

"""

def gen_cascade(alpha, t, nd, dist="exp", delta=1.):
    """Sample a single cascade over ``nd`` nodes within the window ``[0, t)``.

    A source node is chosen uniformly at random and infected at time 0.
    Then, repeatedly, the most recently infected node proposes an infection
    time for every still-uninfected node (drawn with the sampler selected by
    ``dist``). Each uninfected node keeps the earliest proposal seen so far,
    and the node with the earliest proposal is infected next. Sampling stops
    when no uninfected node has a proposed time below ``t``, or when every
    node is infected.

    Parameters
    ----------
    alpha : array_like, shape (nd, nd)
        Transmission rates; ``alpha[j][i]`` is the rate from node j to node i.
    t : float
        Length of the observation window. Nodes not infected before ``t``
        keep the value ``t`` in the output.
    nd : int
        Number of nodes.
    dist : {"exp", "pow", "ray"}
        Transmission-time distribution, sampled via ``inv_exp``, ``inv_pow``
        or ``inv_ray`` respectively.
    delta : float
        Extra argument forwarded to ``inv_pow`` only; ignored otherwise.

    Returns
    -------
    numpy.ndarray, shape (nd,)
        Infection time per node: 0.0 for the source, ``t`` for nodes never
        infected inside the window.

    Raises
    ------
    ValueError
        If ``dist`` is not one of "exp", "pow" or "ray".
    """
    # Resolve the sampler up front so an unsupported ``dist`` fails with a
    # clear message instead of "'NoneType' object is not callable" later.
    if dist == "exp":
        sampler = inv_exp
    elif dist == "pow":
        sampler = inv_pow
    elif dist == "ray":
        sampler = inv_ray
    else:
        raise ValueError(
            f"unknown dist {dist!r}; expected 'exp', 'pow' or 'ray'")

    # Accept nested lists as well as arrays: the boolean-mask indexing of a
    # row below requires an ndarray.
    alpha = np.asarray(alpha)

    # Pick the source node and initialise every other node at the window end.
    cas = np.ones(nd) * t
    uninfected = np.ones(nd, dtype=np.bool_)
    source_idx = np.random.randint(0, nd)
    cas[source_idx], uninfected[source_idx] = 0.0, False
    last_idx = source_idx

    # Stop when everyone is infected, so the sampler is never called on an
    # empty selection.
    while uninfected.any():
        rates = alpha[last_idx][uninfected]
        if dist == "pow":
            ti = sampler(rates, cas[last_idx], delta)
        else:
            ti = sampler(rates, cas[last_idx])

        # Earliest proposed infection time for each uninfected node, over all
        # infected nodes so far.
        candidates = np.minimum(cas[uninfected], ti)
        if not np.any(candidates < t):
            break
        cas[uninfected] = candidates

        # Map the position of the earliest candidate back to a node index.
        # argmin picks the first of tied minima; the previous
        # int(np.argwhere(...)) crashed on ties and relied on a deprecated
        # ndim>0 array-to-scalar conversion.
        uninfected_nodes = np.flatnonzero(uninfected)
        last_idx = int(uninfected_nodes[np.argmin(candidates)])
        uninfected[last_idx] = False

    return cas