import numpy as np
import torch
from torch_geometric.utils import to_dense_adj
from torch_geometric.utils import erdos_renyi_graph
from torch_geometric.utils import stochastic_blockmodel_graph
from torch_geometric.utils import barabasi_albert_graph
from data_gen import gen_cascade

# Seed both NumPy and PyTorch so every random draw below is reproducible.
np.random.seed(234)
torch.manual_seed(234)

# Simulation settings.
dist = "exp"     # forwarded unchanged to gen_cascade
nd = 200         # side length of every (nd, nd) network matrix in this script
num_case = 20    # number of repetitions of the cascade-generation loop
nc = 2000        # number of cascades (rows) drawn per repetition
t = 10           # forwarded unchanged to gen_cascade
delta = 1.       # forwarded unchanged to gen_cascade


# Generate a low-rank and sparse latent diffusion network psi
rank = 5
B_sparsity = 0.1
non_zero_num = np.ceil(B_sparsity * rank * nd)


def _draw_sparse_factor():
    """Draw one sparse (nd, rank) factor of psi.

    int(non_zero_num) distinct flat positions are picked without
    replacement and filled with Uniform(1, 2) values; every other entry
    stays 0. The global NumPy RNG is consumed in that order (positions
    first, then values).

    Returns:
        A pair (flat positions that were filled, factor reshaped to
        (nd, rank)).
    """
    count = int(non_zero_num)
    flat = np.zeros(rank * nd)
    picked = np.random.choice(np.arange(rank * nd), size=count, replace=False)
    flat[picked] = np.random.uniform(1, 2, count)
    return picked, flat.reshape(nd, rank)


non_zero_index_1, low_rank_B_1 = _draw_sparse_factor()
non_zero_index_2, low_rank_B_2 = _draw_sparse_factor()

# psi is the product of the two factors, with its diagonal removed.
B = low_rank_B_1 @ low_rank_B_2.T
B_use = B - np.diag(B.diagonal())
# 0/1 indicator of the strictly positive entries of psi.
supp_B = (B_use > 0) + 0


### random graph model for theta
### make sure that the overlap between theta and psi is controlled
# Resample the whole graph until its support overlap with psi (supp_B)
# falls inside [0.049, 0.051].
overlap = 0
while (overlap < 0.049) or (overlap > 0.051):

    # Passing max_num_nodes pins the dense adjacency to (nd, nd) even when the
    # highest-index nodes have no edge.  Without it the size is inferred from
    # the largest node index present in edge_index, which previously forced a
    # retry loop that threw away every sample whose last node was isolated.
    edge_index = erdos_renyi_graph(nd, 0.001)
    adj_rand = to_dense_adj(edge_index, max_num_nodes=nd)[0].numpy()

    # Ensure every row, then every column, has at least 2 edges.
    # The column pass runs on the transposed view, so view[i] is column i of
    # adj_rand and writes through the view land in adj_rand itself.
    for view in (adj_rand, adj_rand.T):
        for i in range(nd):
            missing_edges = int(2 - np.sum(view[i]))
            if missing_edges <= 0:
                continue
            # Exclude the self-loop and nodes already connected to node i.
            available = view[i] <= 0
            available[i] = False
            candidate_nodes = np.where(available)[0]
            # In case the candidate set is smaller than the missing edges
            missing_edges = min(missing_edges, candidate_nodes.size)
            # Randomly select nodes from candidates to add an edge
            new_edges = np.random.choice(candidate_nodes, size=missing_edges, replace=False)
            view[i, new_edges] = 1

    # Compute the support matrix for theta
    supp_A = (adj_rand > 0).astype(int)

    # Overlap with psi: shared support size relative to each support,
    # keeping the larger of the two ratios.
    shared = np.sum(supp_A * supp_B)
    overlap = max(
        shared / np.sum(supp_A > 0),
        shared / np.sum(supp_B > 0)
    )
    # print(f"Current overlap: {overlap}")

# After obtaining the desired overlap, assign random weights in [1, 5)
# to the edges; non-edges stay 0.
A_rand = np.random.uniform(1, 5, size=(nd, nd)) * adj_rand


### community structure for theta
# Four blocks of 50 nodes; denser inside a block than between blocks.
block_sizes = [50, 50, 50, 50]
edge_probs = [[0.05, 0.01, 0.01, 0.01],
              [0.01, 0.05, 0.01, 0.01],
              [0.01, 0.01, 0.05, 0.01],
              [0.01, 0.01, 0.01, 0.05]]
# Scale all probabilities down by a common factor.
edge_probs = np.array(edge_probs) * 0.2

# Resample the whole graph until its support overlap with psi (supp_B)
# falls inside [0.049, 0.051].
overlap = 0
while (overlap < 0.049) or (overlap > 0.051):

    # Passing max_num_nodes pins the dense adjacency to (nd, nd) even when the
    # highest-index nodes have no edge.  Without it the size is inferred from
    # the largest node index present in edge_index, which previously forced a
    # retry loop that threw away every sample whose last node was isolated.
    edge_index = stochastic_blockmodel_graph(block_sizes, edge_probs)
    adj_com = to_dense_adj(edge_index, max_num_nodes=nd)[0].numpy()

    # Ensure every row, then every column, has at least 2 edges.
    # The column pass runs on the transposed view, so view[i] is column i of
    # adj_com and writes through the view land in adj_com itself.
    for view in (adj_com, adj_com.T):
        for i in range(nd):
            missing_edges = int(2 - np.sum(view[i]))
            if missing_edges <= 0:
                continue
            # Exclude the self-loop and nodes already connected to node i.
            available = view[i] <= 0
            available[i] = False
            candidate_nodes = np.where(available)[0]
            # In case the candidate set is smaller than the missing edges
            missing_edges = min(missing_edges, candidate_nodes.size)
            # Randomly select nodes from candidates to add an edge
            new_edges = np.random.choice(candidate_nodes, size=missing_edges, replace=False)
            view[i, new_edges] = 1

    # Compute the support matrix for theta
    supp_A = (adj_com > 0).astype(int)

    # Overlap with psi: shared support size relative to each support,
    # keeping the larger of the two ratios.
    shared = np.sum(supp_A * supp_B)
    overlap = max(
        shared / np.sum(supp_A > 0),
        shared / np.sum(supp_B > 0)
    )
    # print(f"Current overlap: {overlap}")

# After obtaining the desired overlap, assign random weights in [1, 5)
# to the edges; non-edges stay 0.
A_com = np.random.uniform(1, 5, size=(nd, nd)) * adj_com


### scale-free structure for theta
# Resample the whole graph (and re-run the pruning below) until its support
# overlap with psi (supp_B) falls inside [0.049, 0.051].
overlap = 0
while (overlap < 0.049) or (overlap > 0.051):

    # Passing max_num_nodes pins the dense adjacency to (nd, nd) regardless of
    # which node indices appear in edge_index, so no size-check retry loop is
    # needed.
    edge_index = barabasi_albert_graph(num_nodes=nd, num_edges=1)
    adj_scale = to_dense_adj(edge_index, max_num_nodes=nd)[0].numpy()

    # Ensure every row, then every column, has at least 2 edges.
    # The column pass runs on the transposed view, so view[i] is column i of
    # adj_scale and writes through the view land in adj_scale itself.
    # These one-directional additions make adj_scale asymmetric.
    for view in (adj_scale, adj_scale.T):
        for i in range(nd):
            missing_edges = int(2 - np.sum(view[i]))
            if missing_edges <= 0:
                continue
            # Exclude the self-loop and nodes already connected to node i.
            available = view[i] <= 0
            available[i] = False
            candidate_nodes = np.where(available)[0]
            # In case the candidate set is smaller than the missing edges
            missing_edges = min(missing_edges, candidate_nodes.size)
            # Randomly select nodes from candidates to add an edge
            new_edges = np.random.choice(candidate_nodes, size=missing_edges, replace=False)
            view[i, new_edges] = 1

    # Greedy pruning: repeatedly remove one edge (both directions) whose
    # removal keeps every row/column degree >= 2 and puts the overlap inside
    # the target range; stop once a full scan accepts nothing.
    pruned = True
    while pruned:
        pruned = False
        for i in range(nd):
            for j in range(i+1, nd):  # only check one direction
                if adj_scale[i, j] != 1:
                    continue

                # Remember the reverse entry before the trial.  The matrix is
                # asymmetric here, so (j, i) may be 0; restoring it to 1 on a
                # rejected trial would silently insert a brand-new edge.
                reverse_entry = adj_scale[j, i]

                # Tentatively remove the edge
                adj_scale[i, j] = 0
                adj_scale[j, i] = 0

                # Check if all rows/columns still have at least 2 edges
                degrees_ok = (np.min(np.sum(adj_scale, axis=1)) >= 2 and
                              np.min(np.sum(adj_scale, axis=0)) >= 2)
                if degrees_ok:
                    # Check if the overlap is within the target range
                    supp_A = (adj_scale > 0).astype(int)
                    shared = np.sum(supp_A * supp_B)
                    new_overlap = max(
                        shared / np.sum(supp_A > 0),
                        shared / np.sum(supp_B > 0)
                    )
                    if 0.049 <= new_overlap <= 0.051:
                        pruned = True
                        # Accept removal and break to restart scanning from beginning
                        break

                # Rejected (degree too low or overlap off-target):
                # restore exactly the state from before the trial.
                adj_scale[i, j] = 1
                adj_scale[j, i] = reverse_entry
            if pruned:
                break  # restart scanning after a successful removal

    # Compute the support matrix for theta
    supp_A = (adj_scale > 0).astype(int)

    # Overlap with psi: shared support size relative to each support,
    # keeping the larger of the two ratios.
    shared = np.sum(supp_A * supp_B)
    overlap = max(
        shared / np.sum(supp_A > 0),
        shared / np.sum(supp_B > 0)
    )
    # print(f"Current overlap: {overlap}")

# After obtaining the desired overlap, assign random weights in [1, 5)
# to the edges; non-edges stay 0.
A_scale = np.random.uniform(1, 5, size=(nd, nd)) * adj_scale


# Generate cascade samples
# Choose which theta structure drives the simulation.
A = A_scale

for ii in range(num_case):

    # Each node is switched on independently with probability 0.5.
    P_pathway = 0.5 * np.ones(nd)
    # NOTE(review): Z_record and cascades are re-created on every repetition
    # and nothing in this part of the loop stores them - confirm they are
    # consumed before the next repetition starts.
    Z_record = np.zeros((nc, nd))

    cascades = np.zeros((nc, nd))
    for i in range(nc):
        Z = np.random.binomial(1, P_pathway, size=nd)
        Z_record[i, :] = Z
        # Column j of Theta is taken from A when Z[j] == 1 and from B_use when
        # Z[j] == 0.  Broadcasting Z across the rows gives the same values as
        # right-multiplying by np.diag(Z) / np.diag(1 - Z), without building
        # diagonal matrices or doing dense (nd, nd) matrix products.
        Theta = A * Z + B_use * (1 - Z)

        cascades[i, :] = gen_cascade(Theta, t, nd, dist, delta)


