import torch
import numpy as np

import torch
import numpy as np

import torch
import numpy as np

def _cascade_transmission_delays(rate, dist, delta):
    """Draw one random transmission delay per entry of ``rate``.

    Uses inverse-CDF sampling for the exponential ('exp'), Rayleigh ('ray')
    or power-law ('pow') transmission model. ``delta`` is only used by 'pow'.
    """
    # Clamp only to keep the division finite; the caller is responsible for
    # discarding entries whose rate is not positive.
    safe_rate = rate.clamp(min=1e-8)
    # -log(1 - U) with U ~ Uniform[0, 1) is a unit-rate exponential draw;
    # log1p keeps precision when U is tiny.
    unit_exp = -torch.log1p(-torch.rand_like(rate))
    if dist == 'exp':
        return unit_exp / safe_rate
    if dist == 'ray':
        return torch.sqrt(2 * unit_exp / safe_rate)
    if dist == 'pow':
        return delta * torch.exp(unit_exp / safe_rate)
    raise ValueError(f"Unknown dist '{dist}'")


def _simulate_cascade_chunk(rates_a, rates_b, p_select, horizon, batch_size,
                            dist, delta):
    """Simulate ``batch_size`` independent cascades; return a (batch_size, nd) tensor."""
    device = rates_a.device
    nd = rates_a.size(0)
    inf = float('inf')

    # Per cascade and per node: True -> use the A rate, False -> use the B rate.
    use_a = torch.bernoulli(p_select.unsqueeze(0).expand(batch_size, nd)).bool()

    # Explicit dtype: torch.full would otherwise infer an integer tensor when
    # the horizon is given as a Python int.
    times = torch.full((batch_size, nd), horizon,
                       dtype=torch.float32, device=device)
    uninfected = torch.ones((batch_size, nd), dtype=torch.bool, device=device)
    rows = torch.arange(batch_size, device=device)

    # One uniformly random source per cascade, infected at time 0.
    last = torch.randint(0, nd, (batch_size,), device=device)
    times[rows, last] = 0.0
    uninfected[rows, last] = False

    # A cascade that has stopped must never restart: without this mask a
    # finished row would redraw delays from its last node on every further
    # iteration (as long as other rows keep the loop alive) and could get
    # extra, spurious infections.
    active = torch.ones(batch_size, dtype=torch.bool, device=device)

    # At most nd - 1 further nodes can be infected after the source.
    for _ in range(nd - 1):
        rate = torch.where(use_a, rates_a[last], rates_b[last])  # (B, nd)
        t_last = times[rows, last].unsqueeze(1)                  # (B, 1)
        arrival = t_last + _cascade_transmission_delays(rate, dist, delta)

        # Only still-uninfected nodes reachable through a positive rate can
        # receive a new tentative infection time; a zero rate means "no edge".
        receivable = uninfected & (rate > 0) & active.unsqueeze(1)
        arrival = arrival.masked_fill(~receivable, inf)
        candidate = torch.min(times, arrival)

        # The next node to be infected is the uninfected one with the
        # earliest tentative time.
        pending = candidate.masked_fill(~uninfected, inf)
        next_time, next_node = pending.min(dim=1)

        active = active & (next_time < horizon)
        if not active.any():
            break

        times = torch.where(active.unsqueeze(1), candidate, times)
        last = torch.where(active, next_node, last)
        # For inactive rows `last` is unchanged and already marked infected.
        uninfected[rows, last] = False

    return times


def gen_cascades_gpu(
    A: np.ndarray,
    B_use: np.ndarray,
    P_pathway: np.ndarray,
    t: float,
    nc: int,
    dist: str = 'exp',
    delta: float = 1.0,
    device: str = None,
    chunk_size: int = None
) -> np.ndarray:
    """
    Generate cascades in parallel, sampling the pathway indicator Z internally.

    For every cascade a source node is drawn uniformly at random and infected
    at time 0. Infection then spreads node by node: the most recently infected
    node proposes infection times for the still-uninfected nodes, and the
    uninfected node with the earliest proposed time is infected next. A
    cascade stops as soon as that earliest time is not below ``t``; once
    stopped it is never resumed.

    Parameters
    ----------
    A : (nd, nd) ndarray
        Transmission-rate matrix used where Z = 1. Row = infecting node,
        column = target node. Non-positive entries never transmit.
    B_use : (nd, nd) ndarray
        Transmission-rate matrix used where Z = 0 (same layout as ``A``).
    P_pathway : (nd,) ndarray
        Per-node probability of using ``A`` (i.e. of Z = 1).
    t : float
        Observation window (time truncation).
    nc : int
        Number of cascades to generate.
    dist : {'exp', 'ray', 'pow'}
        Transmission-time model.
    delta : float
        Scale parameter, only used by 'pow'.
    device : {'cuda', 'cpu'} or None
        Torch device; if None, CUDA is used when available, otherwise CPU.
    chunk_size : int or None
        Number of cascades simulated per batch; None means all at once.

    Returns
    -------
    cascades : (nc, nd) float32 ndarray
        Infection time of every node in every cascade. Nodes that were not
        infected within the window hold the value ``t``.

    Raises
    ------
    ValueError
        If ``dist`` is not one of the supported models or ``chunk_size`` is
        not positive.
    """
    # Validate up front so a bad value is rejected even when no simulation
    # step would run (e.g. a single-node network).
    if dist not in ('exp', 'ray', 'pow'):
        raise ValueError(f"Unknown dist '{dist}'")
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    # 1) Device selection.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # 2) Move the inputs to tensors on the target device.
    rates_a = torch.as_tensor(A, dtype=torch.float32, device=device)
    rates_b = torch.as_tensor(B_use, dtype=torch.float32, device=device)
    p_select = torch.as_tensor(P_pathway, dtype=torch.float32, device=device)
    nc = int(nc)
    nd = rates_a.size(0)

    # 3) Output buffer; nothing to simulate for an empty request or network.
    cascades = np.empty((nc, nd), dtype=np.float32)
    if nc == 0 or nd == 0:
        return cascades

    # 4) Batch length.
    chunk_size = nc if chunk_size is None else min(int(chunk_size), nc)
    horizon = float(t)

    # 5) Simulate chunk by chunk and copy each result back to host memory.
    for start in range(0, nc, chunk_size):
        end = min(start + chunk_size, nc)
        chunk = _simulate_cascade_chunk(
            rates_a, rates_b, p_select, horizon, end - start, dist, delta
        )
        cascades[start:end] = chunk.cpu().numpy()

    return cascades


# # ========== Usage example ==========
# if __name__ == "__main__":
#     # Example parameters
#     nd, nc = 1000, 5000
#     t = 10.0
#     dist, delta = 'exp', 1.0
#     A = np.random.rand(nd, nd).astype(np.float32)
#     B = np.random.rand(nd, nd).astype(np.float32)
#     P = 0.5 * np.ones(nd, dtype=np.float32)

#     # Generate cascades in parallel
#     casc = gen_cascades_gpu(
#         A, B, P,
#         t=t, nc=nc,
#         dist=dist, delta=delta,
#         device='cuda',    # or 'cpu'
#         chunk_size=512    # tune to the available GPU memory
#     )
#     print("Done, shape:", casc.shape)
