import numpy as np
import torch
from data_gen_gpu import gen_cascades_gpu

# Seed both RNGs so the generated ground truth and cascades are reproducible.
np.random.seed(234)
torch.manual_seed(234)

# Distribution name and its parameter, forwarded as dist=/delta= to
# gen_cascades_gpu (presumably the transmission-time law -- confirm there).
dist = "exp"
delta = 1.0

nd = 2000        # number of nodes; A and B are nd x nd
num_case = 5     # number of cascade files written by the loop at the bottom
nc = 50000       # forwarded as nc= to gen_cascades_gpu (presumably #cascades)
t = 10           # forwarded as t= to gen_cascades_gpu (presumably the horizon)

# B is the product of two sparse non-negative (nd, rank) factors, each with
# exactly ceil(B_sparsity * rank * nd) entries drawn from U[1, 2).
rank = 5
B_sparsity = 0.04
non_zero_num = np.ceil(B_sparsity * rank * nd)


def _sparse_factor(n_rows, n_cols, n_nonzero):
    """Return ``(support, factor)`` for one random sparse factor.

    ``support`` holds ``n_nonzero`` distinct flat indices into an
    ``n_rows * n_cols`` grid. ``factor`` is that grid reshaped to
    ``(n_rows, n_cols)``, with the supported cells filled from U[1, 2) and all
    other cells set to zero. The global NumPy RNG is consumed as a
    without-replacement ``choice`` followed by ``uniform``.
    """
    n_cells = n_rows * n_cols
    support = np.random.choice(n_cells, size=n_nonzero, replace=False)
    flat = np.zeros(n_cells)
    flat[support] = np.random.uniform(1, 2, n_nonzero)
    return support, flat.reshape(n_rows, n_cols)


non_zero_index_1, low_rank_B_1 = _sparse_factor(nd, rank, int(non_zero_num))
non_zero_index_2, low_rank_B_2 = _sparse_factor(nd, rank, int(non_zero_num))
B = low_rank_B_1 @ low_rank_B_2.T
# B_use is B with its diagonal zeroed (self-loops removed).
B_use = B.copy()
np.fill_diagonal(B_use, 0)

## generate A with different overlap with B
# Each column i of A gets a few positive entries (never on the diagonal):
#  * if column i of B_use has no positive entry, `non_overlap_size` rows are
#    drawn from all other nodes, with values from U[1, 5);
#  * otherwise `overlap_size` rows are drawn from B_use's support in that
#    column and `non_overlap_size` rows from outside it, with values from U[2, 5).
# NOTE(review): np.random.choice defaults to replace=True, so a column can end
# up with fewer distinct entries than requested. This is kept as-is because
# changing it would change the seeded dataset.
overlap_size = 1
non_overlap_size = 2
A = np.zeros((nd, nd))
all_nodes = np.arange(nd)

# B_use is never modified below, so its per-column positive counts are
# loop-invariant. Recomputing them each iteration made this loop O(nd**3).
b_col_counts = (B_use > 0).sum(axis=0)

for i in range(nd):
    if b_col_counts[i] == 0:
        # The values are drawn before the row indices, matching the evaluation
        # order of `A[choice(...), i] = uniform(...)` (RHS first). This keeps
        # the seeded random stream, and therefore A, unchanged.
        vals = np.random.uniform(1, 5, non_overlap_size)
        rows = np.random.choice(np.setdiff1d(all_nodes, i), size=non_overlap_size)
        A[rows, i] = vals
    else:
        # The former `0 < num <= 3` and `num > 3` branches were identical.
        b_support = np.where(B_use[:, i] > 0)[0]
        select_index = np.random.choice(b_support, size=overlap_size)
        A[select_index, i] = np.random.uniform(2, 5, overlap_size)
        off_support = np.setdiff1d(np.setdiff1d(all_nodes, b_support), i)
        remain_index = np.random.choice(off_support, size=non_overlap_size)
        A[remain_index, i] = np.random.uniform(2, 5, non_overlap_size)

# Guarantee every row of A + B_use has at least one positive entry.
# Iteration i only writes row i, so the counts of the remaining rows cannot
# change. They are computed once, and the empty rows are visited in
# ascending order, exactly as the original full scan did.
row_counts = ((A + B_use) > 0).sum(axis=1)
for i in np.where(row_counts == 0)[0]:
    # The value is drawn before the column index (RHS-first order), as before.
    val = np.random.uniform(1, 5, 1)
    col = np.random.choice(np.setdiff1d(all_nodes, i), size=1)
    A[i, col] = val

# Persist the ground-truth matrices as float32 and verify the round trip.
# NOTE(review): B (diagonal included) is saved, while the cascades below are
# generated from B_use (diagonal zeroed). Confirm this is intended.
np.savez_compressed("true_A.npz", A=A.astype(np.float32, copy=False))
np.savez_compressed("true_B.npz", B=B.astype(np.float32, copy=False))
# np.load on an .npz returns an NpzFile that keeps the file open. Using it as
# a context manager closes the handle deterministically once the array is read.
with np.load("true_A.npz") as npz:
    A_test = npz["A"]
with np.load("true_B.npz") as npz:
    B_test = npz["B"]
# allclose's default tolerances absorb the float64 -> float32 rounding.
print(np.allclose(A, A_test), np.allclose(B, B_test))


# Run on the GPU when one is available; gen_cascades_gpu also accepts 'cpu'.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Generate `num_case` independent cascade files from the same A / B_use.
for ii in range(num_case):

    print(ii)

    # Every node gets a pathway probability of 0.5. This is rebuilt each
    # iteration in case gen_cascades_gpu modifies it in place (unknown here).
    P_pathway = 0.5 * np.ones(nd)
    cascades = gen_cascades_gpu(
        A, B_use, P_pathway,
        t=t, nc=nc,
        dist=dist, delta=delta,
        device=device,     # 'cuda' or 'cpu'
        chunk_size=1000,   # tune to the available GPU memory
    )

    out_path = dist + "_cascade_{}.npz".format(ii + 1)
    np.savez_compressed(out_path, cascades=cascades.astype(np.float32, copy=False))
    # Reload to verify the float32 round trip; the context manager closes
    # the NpzFile's underlying file handle.
    with np.load(out_path) as npz:
        print(np.allclose(npz["cascades"], cascades))


    
