import numpy as np
from data_gen import gen_cascade

# Seed NumPy's global RNG once, before any of the sampling below, so every
# network and cascade drawn by this script is reproducible.
np.random.seed(234)

# Experiment settings. The study is run for dist in {"exp", "ray", "pow"} and
# nc in {500, 1000, 1500, 2000}; change the values below to switch runs.
dist = "pow"    # distribution name forwarded to gen_cascade (confirm meaning in data_gen)
delta = 1.0     # forwarded to gen_cascade unchanged
nd, num_case = 200, 20    # number of nodes; number of repeated cases
nc, t = 2000, 10          # cascades per case; forwarded to gen_cascade as `t`


# Generate a low-rank and sparse latent diffusion network psi:
#   B = B_1 @ B_2.T, where each (nd x rank) factor has ceil(B_sparsity * rank * nd)
#   non-zero entries drawn from U(1, 2) at positions sampled without replacement.
# B_use is B with its diagonal zeroed.
rank = 5
B_sparsity = 0.1
non_zero_num = np.ceil(B_sparsity*rank*nd)  # np.float64, kept for any downstream use
_n_nonzero = int(non_zero_num)


def _sparse_factor(n_rows, n_cols, n_nonzero):
    """Return a sparse (n_rows, n_cols) factor and the flat positions of its non-zeros.

    Positions are sampled first (without replacement), then the U(1, 2) values.
    The global RNG is seeded at the top of the script, so this call order must be
    kept to reproduce the same network.
    """
    n_entries = n_rows * n_cols
    flat = np.zeros(n_entries)
    flat_index = np.random.choice(np.arange(n_entries), size=n_nonzero, replace=False)
    flat[flat_index] = np.random.uniform(1, 2, n_nonzero)
    return flat.reshape(n_rows, n_cols), flat_index


low_rank_B_1, non_zero_index_1 = _sparse_factor(nd, rank, _n_nonzero)
low_rank_B_2, non_zero_index_2 = _sparse_factor(nd, rank, _n_nonzero)
B = low_rank_B_1 @ low_rank_B_2.T
# Remove self-loops: subtract the diagonal so B_use[i, i] == 0.
B_use = B - np.diag(np.diag(B))


# Generate a sparse theta (A) with fixed overlap with psi (B_use).
# For each column i, with "psi-parents" = rows j where B_use[j, i] > 0:
#   * no psi-parents: set `non_overlap_size` random entries A[j, i] (j != i)
#     to U(1, 5) weights;
#   * otherwise: copy `overlap_size` psi-parents into A[:, i] and add
#     `non_overlap_size` entries from rows that are neither psi-parents nor i,
#     all with U(2, 5) weights.
# Then every row with no positive entry in A + B_use gets one random A-entry.
# NOTE(review): the 1-3 parent and >3 parent cases used to be separate but
# identical branches; if different behavior was intended for dense columns,
# reintroduce a branch here.
overlap_size = 1
non_overlap_size = 1
A = np.zeros((nd,nd))
all_nodes = np.arange(nd)
psi_parent_mask = B_use > 0
# B_use is never modified below, so per-column parent counts are computed once.
psi_parent_counts = psi_parent_mask.sum(axis=0)
for i in range(nd):
    num = psi_parent_counts[i]
    if num == 0:
        # Weights are drawn before the row index; keep this order, it fixes the
        # seeded RNG stream.
        weights = np.random.uniform(1, 5, non_overlap_size)
        rows = np.random.choice(np.setdiff1d(all_nodes, i), size=non_overlap_size)
        A[rows, i] = weights
    else:
        parents = np.flatnonzero(psi_parent_mask[:, i])
        select_index = np.random.choice(parents, size=overlap_size)
        A[select_index, i] = np.random.uniform(2, 5, overlap_size)
        non_parents = np.setdiff1d(all_nodes, np.append(parents, i))
        remain_index = np.random.choice(non_parents, size=non_overlap_size)
        A[remain_index, i] = np.random.uniform(2, 5, non_overlap_size)

# Row counts of positive entries in A + B_use. Iteration i writes only row i, and
# rows are visited once in order, so counts taken up front equal the live counts.
row_counts = ((A + B_use) > 0).sum(axis=1)
for i in range(nd):
    if row_counts[i] == 0:
        # Weight before column index (RNG-order sensitive, see above).
        weight = np.random.uniform(1, 5, 1)
        col = np.random.choice(np.setdiff1d(all_nodes, i), size=1)
        A[i, col] = weight


# Generate cascade samples. For each of `num_case` cases, draw `nc` cascades.
# For each cascade, Z[j] ~ Bernoulli(P_pathway[j]) independently per node, and
# column j of Theta is taken from A when Z[j] == 1 and from B_use otherwise.
P_pathway = 0.5 * np.ones(nd)
for ii in range(num_case):
    # NOTE(review): Z_record and cascades are rebuilt for every case and nothing
    # visible here stores them, so only the last case survives the loop;
    # confirm they are persisted per case.
    Z_record = np.zeros((nc, nd))
    cascades = np.zeros((nc, nd))
    for i in range(nc):
        Z = np.random.binomial(1, P_pathway, size=nd)
        Z_record[i, :] = Z
        # Broadcasting scales column j by Z[j] / (1 - Z[j]). This is exactly
        # A @ diag(Z) + B_use @ diag(1 - Z), without two O(nd^3) matrix products.
        Theta = A * Z + B_use * (1 - Z)

        cascades[i, :] = gen_cascade(Theta, t, nd, dist, delta)

    
