import numpy as np
from data_gen import gen_cascade

# Seed NumPy's legacy global RNG so every draw below (U, V, A and all
# cascades) is reproducible from run to run.
np.random.seed(234)

# nd: matrix dimension (U, V are nd x rank; A, B are nd x nd).
# rank: number of latent factors, i.e. columns of U and V.
nd, rank = 100, 15
# non_zero_num: factors assigned to each randomly-populated row of U / V.
# p: probability that a row past the first `rank` rows is populated at all.
non_zero_num, p = 1, 0.8

def _sparse_factor(n_rows, n_factors, prob, per_row, *, fill_empty_rows=False):
    """Draw a sparse nonnegative (n_rows x n_factors) factor matrix.

    The first ``n_factors`` rows form a diagonal with weights drawn from
    U(0.2, 1.5). Each later row, with probability ``prob``, gets ``per_row``
    distinct factor columns set to U(0.2, 1.5) weights; otherwise it stays
    all-zero. If ``fill_empty_rows`` is true, a second pass gives every
    still-all-zero row ``per_row`` random factor weights.

    Draws come from the legacy global ``np.random`` stream, and the order of
    the draws matters: the script is seeded, so reordering them would change
    every downstream matrix.
    """
    mat = np.zeros((n_rows, n_factors))
    for k in range(n_factors):
        mat[k, k] = np.random.uniform(0.2, 1.5)
    for r in range(n_factors, n_rows):
        if np.random.rand() < prob:
            cols = np.random.choice(n_factors, size=per_row, replace=False)
            mat[r, cols] = np.random.uniform(0.2, 1.5, per_row)
    if fill_empty_rows:
        for r in range(n_rows):
            if np.all(mat[r, :] == 0):
                cols = np.random.choice(n_factors, size=per_row, replace=False)
                mat[r, cols] = np.random.uniform(0.2, 1.5, per_row)
    return mat


# --- Generate U (nd x rank) ---
# NOTE(review): unlike V, empty rows of U are not back-filled, so some rows
# of B = U @ V.T can be all zero. The later "every row has a nonzero" pass
# over A + B_use covers this. Confirm the asymmetry is intended.
U = _sparse_factor(nd, rank, p, non_zero_num)

# --- Generate V (nd x rank) ---
V = _sparse_factor(nd, rank, p, non_zero_num, fill_empty_rows=True)

# psi is represented by B = U @ V.T. It is low-rank (at most `rank`) and
# sparse because U and V are sparse.
B = U @ V.T
# B_use is B with its diagonal zeroed; off-diagonal entries are unchanged.
B_use = B.copy()
np.fill_diagonal(B_use, 0)


# Generate sparse theta (stored as A, nd x nd). In each column, A's support
# overlaps psi's (B_use's) support in a fixed number of entries.
overlap_size = 1        # entries per column placed on B_use's support
non_overlap_size = 1    # entries per column placed off B_use's support

A = np.zeros((nd, nd))
for col in range(nd):
    # Row indices where this column of B_use is positive (computed once).
    support = np.where(B_use[:, col] > 0)[0]

    if support.size == 0:
        # B_use contributes nothing here, so only place non-overlapping
        # entries, avoiding the diagonal.
        others = np.setdiff1d(np.arange(nd), np.array([col]))
        rows = np.random.choice(others, size=non_overlap_size, replace=False)
        A[rows, col] = np.random.uniform(1, 2, non_overlap_size)
        continue

    if overlap_size > 0:
        n_shared = min(overlap_size, support.size)
        shared = np.random.choice(support, size=n_shared, replace=False)
        # Put A's shared weights at whichever end of [0.1, 2] is farther
        # from B_use's mean weight there, so the two matrices differ on
        # their common support.
        b_mean = np.mean(B_use[shared, col])
        b_closer_to_two = abs(b_mean - 0.1) > abs(b_mean - 2)
        A[shared, col] = (
            np.random.uniform(0.1, 0.2, n_shared)
            if b_closer_to_two
            else np.random.uniform(1.9, 2, n_shared)
        )

    if non_overlap_size > 0:
        # Candidate rows lie outside B_use's support and off the diagonal.
        free = np.setdiff1d(np.setdiff1d(np.arange(nd), support), np.array([col]))
        if free.size >= non_overlap_size:
            rows = np.random.choice(free, size=non_overlap_size, replace=False)
            A[rows, col] = np.random.uniform(1, 2, non_overlap_size)
        elif free.size > 0:
            A[free, col] = np.random.uniform(1, 2, free.size)

# Ensure every row of A + B_use has at least one positive entry. Both
# matrices are nonnegative, so this is the union of their supports. A row
# with no positive entry gets one off-diagonal entry in A.
# Writing row i of A never affects whether any other row is empty, so the
# empty rows are found once up front. This avoids rebuilding the full
# nd x nd sum on every iteration; the draws happen in the same order.
row_has_edge = np.any((A + B_use) > 0, axis=1)
for i in np.flatnonzero(~row_has_edge):
    choices = np.setdiff1d(np.arange(nd), np.array([i]))
    selected = np.random.choice(choices, size=1, replace=False)
    A[i, selected] = np.random.uniform(1, 2, 1)

# Experiment settings. The experiment is run with dist in {"exp", "ray",
# "pow"} and nc in {2000, 3000, 4000, 5000}; set the values below per run.
num_case, nc = 20, 5000   # data sets (outer loop) and cascades per data set
t, delta = 10, 1          # passed through to gen_cascade unchanged
dist = "exp"              # distribution name passed to gen_cascade

# Generate cascade samples. For each case, draw nc cascades. Each cascade
# draws a 0/1 indicator Z per column (probability 0.5). Column j of the
# weight matrix comes from A when Z[j] == 1 and from B_use when Z[j] == 0.
# NOTE(review): cascades and Z_record are reallocated on every case, and
# nothing visible here stores them. Unless code after this point persists
# them, only the last case survives. Confirm.
for ii in range(num_case):
    P_pathway = 0.5 * np.ones(nd)
    Z_record = np.zeros((nc, nd))

    cascades = np.zeros((nc, nd))
    for i in range(nc):
        Z = np.random.binomial(1, P_pathway, size=nd)
        Z_record[i, :] = Z
        # Column-wise selection. Z broadcasts along the last axis, so this
        # equals A @ diag(Z) + B_use @ diag(1 - Z) without two dense
        # nd x nd matrix products per cascade.
        Theta = np.where(Z.astype(bool), A, B_use)

        cascades[i, :] = gen_cascade(Theta, t, nd, dist, delta)