import torch


def find_block_sum(inp, block_size):
    """Sum a 1-D tensor over consecutive blocks.

    Args:
        inp: 1-D tensor holding all blocks back to back.
        block_size: 1-D integer tensor (or sequence) of non-negative block lengths; they are
            expected to add up to ``len(inp)``.

    Returns:
        Tensor of shape ``(len(block_size),)`` with dtype ``inp.dtype``; entry ``i`` is the sum
        of block ``i`` (0 for an empty block).
    """
    sizes = torch.as_tensor(block_size, device=inp.device).long()
    # Segment id of every element, e.g. sizes [2, 3] -> [0, 0, 1, 1, 1].
    block_indices = torch.repeat_interleave(torch.arange(len(sizes), device=inp.device), sizes)
    # Scatter-add the sums directly instead of computing a mean and multiplying it back by the
    # count, which added an extra rounding step and relied on the beta index_reduce_ API.
    block_sum = torch.zeros((len(sizes),), dtype=inp.dtype, device=inp.device)
    return block_sum.index_add(0, block_indices, inp)


def find_block_sum_seq(inp, block_size, find_mean=False):
    """Sum (or average) ``inp`` over consecutive blocks, looping over the blocks in Python.

    Args:
        inp: tensor whose blocks are laid out back to back along dim 0 (1-D in this module).
        block_size: 1-D integer tensor (or sequence) of block lengths.
        find_mean: if True, return per-block means instead of sums (an empty block then
            yields NaN, as ``torch.mean`` of an empty tensor does).

    Returns:
        Tensor of shape ``(len(block_size),)`` with dtype ``inp.dtype``.
    """
    sizes = torch.as_tensor(block_size, device=inp.device).long()
    # Integer offsets. Prefixing with a default-dtype torch.zeros(1) made the cumsum run in
    # float32, which rounds offsets above 2**24 and shifts elements into the neighbouring block.
    # A single .tolist() also avoids a device sync for every slice bound.
    offsets = torch.cat([sizes.new_zeros(1), torch.cumsum(sizes, dim=0)]).tolist()
    reduce_fn = torch.mean if find_mean else torch.sum
    block_sum = torch.zeros((len(sizes),), dtype=inp.dtype, device=inp.device)
    for i in range(len(sizes)):
        block_sum[i] = reduce_fn(inp[offsets[i]: offsets[i + 1]])
    return block_sum


def get_zdist_weight_tsne(zdist, degrees_of_freedom=1):
    """Unnormalised Student-t kernel used by t-SNE.

    Computes ``(1 + zdist**2 / degrees_of_freedom) ** (-(degrees_of_freedom + 1) / 2)``
    element-wise. The input is not modified.
    """
    exponent = (degrees_of_freedom + 1.0) / -2.0
    scaled_sq = zdist ** 2 / degrees_of_freedom
    return (scaled_sq + 1.) ** exponent


def get_zdist_prob_student_t(zdist, size, degrees_of_freedom=1, eps=1e-9):
    """Turn concatenated per-block distances into per-block Student-t probabilities.

    Args:
        zdist: 1-D tensor of distances; the blocks of all graphs concatenated.
        size: 1-D integer tensor with the number of entries in each block.
        degrees_of_freedom: Student-t degrees of freedom.
        eps: lower bound applied to every probability so later ``torch.log`` calls stay finite.

    Returns:
        Tensor like ``zdist``: Student-t weights divided by their block total, floored at ``eps``.
    """
    dist = get_zdist_weight_tsne(zdist, degrees_of_freedom)

    block_sum = find_block_sum(dist, size)
    # Broadcast each block total back onto the entries of that block.
    block_indices = torch.repeat_interleave(torch.arange(len(size), device=dist.device), size)
    q_ = dist / block_sum[block_indices]
    # Build the floor on q_'s own device/dtype; a hard-coded .cuda() tensor fails without CUDA
    # and can land on a different GPU than q_.
    floor = torch.tensor(eps, dtype=q_.dtype, device=q_.device)
    return torch.maximum(q_, floor)


def get_zdist_weight_umap(zdist, _a=1.93, _b=0.79):
    """UMAP low-dimensional similarity curve ``1 / (1 + a * d**(2b))``, applied element-wise."""
    exponent = 2 * _b
    denominator = 1 + _a * zdist ** exponent
    return 1 / denominator

def get_zdist_weight_umap_normalized(zdist, size, _a=1.93, _b=0.79, eps=1e-9):
    """UMAP weights of concatenated per-block distances, normalised to sum to 1 within each block.

    Args:
        zdist: 1-D tensor of distances; the blocks of all graphs concatenated.
        size: 1-D integer tensor with the number of entries in each block.
        _a, _b: UMAP curve parameters forwarded to ``get_zdist_weight_umap``.
        eps: lower bound applied to every normalised weight.

    Returns:
        Tensor like ``zdist`` with block-normalised weights floored at ``eps``.
    """
    dist = get_zdist_weight_umap(zdist, _a, _b)
    block_sum = find_block_sum(dist, size)
    # Broadcast each block total back onto the entries of that block.
    block_indices = torch.repeat_interleave(torch.arange(len(size), device=dist.device), size)
    q_ = dist / block_sum[block_indices]
    # Floor created on q_'s device/dtype rather than via a hard-coded .cuda() call.
    floor = torch.tensor(eps, dtype=q_.dtype, device=q_.device)
    return torch.maximum(q_, floor)


def find_block_ce_seq(q, p, block_size):
    """Mean binary cross-entropy ``BCE(input=q, target=p)`` computed separately per block.

    Args:
        q: 1-D predictions in [0, 1] (required by ``binary_cross_entropy``).
        p: 1-D targets in [0, 1], aligned with ``q``.
        block_size: 1-D integer tensor (or sequence) of block lengths.

    Returns:
        Tensor of shape ``(len(block_size),)`` with one mean BCE per block.
    """
    sizes = torch.as_tensor(block_size, device=q.device).long()
    # Integer offsets; a float32 cumsum rounds offsets above 2**24.
    offsets = torch.cat([sizes.new_zeros(1), torch.cumsum(sizes, dim=0)]).tolist()
    out = torch.zeros((len(sizes),), dtype=q.dtype, device=q.device)
    for i in range(len(sizes)):
        start, end = offsets[i], offsets[i + 1]
        out[i] = torch.nn.functional.binary_cross_entropy(q[start:end], p[start:end])
    return out


def loss_fn_kl_t_pdist(net, n_view_graphs, edge_only_graph, zdist, batch_zdist_size, df=1):
    '''KL(p || q) loss between Student-t pair probabilities of a target and the network output.

    zdist: target distances in pdist form (torch.pdist of each graph, concatenated).
    batch_zdist_size: number of pair entries contributed by each graph.
    df: Student-t degrees of freedom, used for p, q and the loss scale.

    Returns (kl_loss, z_hat) where z_hat is the raw network output.
    '''
    p = get_zdist_prob_student_t(zdist, batch_zdist_size, degrees_of_freedom=df)
    z_hat = net.forward(edge_only_graph, n_view_graphs)

    # Integer node offsets built on the node-count tensor's own device. A CPU float32
    # torch.zeros(1) prefix made the cumsum float and cannot be concatenated with a GPU tensor.
    node_counts = n_view_graphs.batch_num_nodes().long()
    offsets = torch.cat([node_counts.new_zeros(1), torch.cumsum(node_counts, dim=0)]).tolist()
    zhat_dist = torch.cat([torch.pdist(z_hat[offsets[i]: offsets[i + 1]])
                           for i in range(n_view_graphs.batch_size)])

    q = get_zdist_prob_student_t(zhat_dist, batch_zdist_size, degrees_of_freedom=df)
    kl = p * (torch.log(p) - torch.log(q))
    kl_sum = find_block_sum_seq(kl, batch_zdist_size)
    # Scale factor 2 * (df + 1) / df applied to the mean per-graph KL.
    kl_loss = 2 * (df + 1) / df * kl_sum.mean()
    return kl_loss, z_hat


def get_pdist_from_cdist(dist_sq):
    """Flatten the strict upper triangle of a square pairwise matrix in ``torch.pdist`` order.

    Despite the parameter name, any (n, n) matrix works; this module passes plain (non-squared)
    distance matrices. Entries are read row by row: (0, 1), (0, 2), ..., (1, 2), ...

    Returns:
        1-D tensor of length ``n * (n - 1) // 2``.
    """
    n = dist_sq.shape[0]
    # Build the indices directly on the matrix's device (avoids a host-to-device copy).
    rows, cols = torch.triu_indices(n, n, offset=1, device=dist_sq.device)
    return dist_sq[rows, cols]

def compute_knn(x, n_neighbors):
    """Return the ``n_neighbors`` nearest neighbours (Euclidean) of every row of ``x``.

    Args:
        x: point matrix, assumed (n, d).
        n_neighbors: neighbours kept per point, excluding the point itself; ``topk`` needs
            ``n_neighbors + 1 <= n``.

    Returns:
        knn_idx: (n, n_neighbors) neighbour indices, nearest first.
        knn_dists: (n, n_neighbors) matching distances.
        dist: (n, n) full pairwise distance matrix.
    """
    dist = torch.cdist(x, x, p=2)
    # Rank on the distances themselves: squaring is monotonic, so the order is the same. This
    # skips a square/sqrt round trip over the full n x n matrix that only added rounding (and
    # squares can underflow/overflow for extreme distances).
    knn_dists, knn_idx = torch.topk(dist, n_neighbors + 1, largest=False)
    # Column 0 is dropped as the point itself (distance 0).
    # NOTE(review): with duplicate points the tie may put another point in column 0 — confirm.
    return knn_idx[:, 1:], knn_dists[:, 1:].contiguous(), dist


def get_umap_weight_from_knn(cdist, knn_idx):
    """Look up each row's neighbour distances in ``cdist`` and map them to UMAP weights.

    Args:
        cdist: square pairwise distance matrix; row i holds distances from point i.
        knn_idx: (n, k) neighbour indices per row.

    Returns:
        (n, k) tensor whose entry (i, j) is ``get_zdist_weight_umap(cdist[i, knn_idx[i, j]])``.
    """
    row_ids = torch.arange(cdist.size(0), device=cdist.device)[:, None]  # (n, 1), broadcasts over k
    neighbour_dist = cdist[row_ids, knn_idx]
    return get_zdist_weight_umap(neighbour_dist)


def get_umap_weight_from_x_knn(z, zhat, n_neighbors):
    """Local (kNN) and global (all-pair) UMAP weights for a reference ``z`` and prediction ``zhat``.

    The kNN graph is found in ``z`` and reused for ``zhat``, so local p and q line up entry by entry.

    Returns:
        (local_p, local_q, global_p, global_q): the local ones are flattened (n * n_neighbors,)
        vectors; the global ones follow ``torch.pdist`` pair order.
    """
    knn_idx, z_knn_dist, z_cdist = compute_knn(z, n_neighbors)
    zhat_cdist = torch.cdist(zhat, zhat, p=2)

    local_p = get_zdist_weight_umap(z_knn_dist).view(-1)
    local_q = get_umap_weight_from_knn(zhat_cdist, knn_idx).view(-1)

    global_p = get_zdist_weight_umap(get_pdist_from_cdist(z_cdist))
    global_q = get_zdist_weight_umap(get_pdist_from_cdist(zhat_cdist))
    return local_p, local_q, global_p, global_q


def smooth_knn_dist_vectorized(knn_dists, k, tol=1e-5, max_iter=64):
    """
    Vectorised (all points at once) estimate of sigma / rho for every point.

    For row i, rho[i] is the distance in column 0, and sigma[i] is found by bisection so that
    sum_j exp(-(knn_dists[i, j] - rho[i]) / sigma[i]) is within ``tol`` of log2(k).

    Args:
      knn_dists: (n_samples, k) distances from each point to its k neighbours; column 0 is
        treated as the nearest (assumes rows sorted ascending — TODO confirm for all callers).
      k: number of neighbours, as a Python number or a tensor.
      tol: stop once every row's sum is within ``tol`` of log2(k).
      max_iter: maximum number of bisection steps.

    Returns:
      sigmas: (n_samples,) sigma of each point
      rhos:   (n_samples,) rho of each point
    """
    device = knn_dists.device
    dtype = knn_dists.dtype
    n_samples = knn_dists.size(0)

    # 1) rho[i] = knn_dists[i, 0]
    rhos = knn_dists[:, 0].clone()

    # 2) Bisection bracket [lo, hi] and starting guess mid, in the distances' dtype.
    lo = torch.zeros(n_samples, dtype=dtype, device=device)
    hi = knn_dists.max(dim=1).values
    mid = torch.ones(n_samples, dtype=dtype, device=device)

    # Target sum log2(k); as_tensor accepts an int as well as a tensor k.
    target = torch.log2(torch.as_tensor(k, dtype=dtype, device=device))

    shifted = knn_dists - rhos.unsqueeze(1)  # loop-invariant
    for _ in range(max_iter):
        # 3) Per-row exp-sum for the current sigma guess, shape (n_samples,).
        psum = torch.exp(-shifted / mid.unsqueeze(1)).sum(dim=1)

        # 4) Check convergence before moving mid, so the returned sigma is exactly the one
        #    whose sum met the tolerance.
        if torch.all(torch.abs(psum - target) < tol):
            break

        # 5) For non-negative shifted distances the sum grows with sigma:
        #    too large -> lower hi, too small -> raise lo.
        greater = psum > target
        lo = torch.where(greater, lo, mid)
        hi = torch.where(greater, mid, hi)
        mid = (lo + hi) * 0.5

    return mid, rhos

def build_fuzzy_graph_vectorized(knn_idx, knn_dists, sigmas, rhos):
    """
    Vectorised construction of the symmetric fuzzy kNN graph.

    Directed edge (i, knn_idx[i, j]) gets weight 1 when its distance is <= rhos[i], otherwise
    exp(-(distance - rhos[i]) / sigmas[i]). The directed graph W is then symmetrised as
    W + W^T - W * W^T (element-wise product).

    Args:
      knn_idx:   (n, k) neighbour indices
      knn_dists: (n, k) neighbour distances
      sigmas:    (n,) per-point sigma
      rhos:      (n,) per-point rho
    Returns:
      Coalesced sparse COO tensor of shape (n, n).
    """
    n_points, n_nbrs = knn_idx.shape

    # Flatten every (source, neighbour) edge together with the source's rho / sigma.
    src = torch.arange(n_points, device=knn_idx.device).repeat_interleave(n_nbrs)
    dst = knn_idx.reshape(-1)
    edge_dist = knn_dists.reshape(-1)
    edge_rho = rhos.repeat_interleave(n_nbrs)
    edge_sigma = sigmas.repeat_interleave(n_nbrs)

    decayed = torch.exp(-(edge_dist - edge_rho) / edge_sigma)
    weights = torch.where(edge_dist <= edge_rho, torch.ones_like(edge_dist), decayed)

    directed = torch.sparse_coo_tensor(torch.stack([src, dst], dim=0), weights,
                                       (n_points, n_points)).coalesce()
    transposed = directed.transpose(0, 1)
    return (directed + transposed - directed.multiply(transposed)).coalesce()

def graph_sym2pdist(graph_sym):
    """Read a square sparse graph into a vector in ``torch.pdist`` pair order.

    The graph is densified and its strict upper triangle (diagonal excluded) is read row by
    row: (0, 1), (0, 2), ..., (1, 2), ...
    """
    dense_graph = graph_sym.to_dense()
    side = dense_graph.size(0)
    upper_rows, upper_cols = torch.triu_indices(side, side, offset=1)
    return dense_graph[upper_rows, upper_cols]

def get_zero_index(vals):
    """Split the positions of ``vals`` into exactly-zero and non-zero entries.

    Returns:
        (zero_idx, nonzero_idx): index tensors along dim 0. NaN counts as non-zero.
    """
    zero_positions = torch.nonzero(vals == 0, as_tuple=True)[0]
    nonzero_positions = torch.nonzero(vals != 0, as_tuple=True)[0]
    return zero_positions, nonzero_positions


def get_umap_n_neighbor_p(z, n_neighbor):
    """UMAP fuzzy-graph membership strengths of ``z``, read out in ``torch.pdist`` pair order.

    Builds the kNN graph of ``z`` with ``n_neighbor`` neighbours, estimates per-point sigma/rho,
    symmetrises the fuzzy graph and returns its strict upper triangle (0 for non-neighbour pairs).
    """
    # compute_knn returns (indices, distances, full distance matrix); unpacking it into two
    # names raised ValueError.
    knn_idx, knn_dists, _ = compute_knn(z, n_neighbors=n_neighbor)
    sigmas, rhos = smooth_knn_dist_vectorized(knn_dists, k=n_neighbor)
    graph = build_fuzzy_graph_vectorized(knn_idx, knn_dists, sigmas, rhos)
    return graph_sym2pdist(graph)


def get_weight_by_type(z, zhat, batch_graph_size, batch_size, type, n_neighbors=None):
    """Compute pairwise weights of the reference embedding ``z`` (p) and prediction ``zhat`` (q).

    Nodes of the ``batch_size`` graphs are stored contiguously: graph ``i`` owns the next
    ``batch_graph_size[i]`` rows of ``z`` and ``zhat``. Pairwise quantities are only formed
    within a graph and concatenated in graph order.

    Args:
        z: reference embedding rows.
        zhat: predicted embedding rows, aligned with ``z``.
        batch_graph_size: 1-D integer tensor of node counts per graph.
        batch_size: number of graphs in the batch.
        type: weighting scheme; one of the names in the assert below.
        n_neighbors: per-node neighbour counts; only each graph's first-node entry is read.
            Needed for ``'umap_n_neighbor'`` and ``'umap_knn'``.

    Returns:
        ``(p, q)`` for every type except ``'umap_knn'``, which returns
        ``(p, q, knn_block_size, global_p, global_q)``: kNN weights, the number of kNN entries
        per graph, and all-pair UMAP weights.
    """
    assert type in ['student_t_norm', 'student_t', 'umap', 'umap_norm', 'umap_n_neighbor', 'umap_knn', 'none']
    # Exact integer node offsets on the sizes' own device. A float32 torch.zeros(1) prefix made
    # the cumsum float and cannot be concatenated with a GPU tensor; one .tolist() also avoids
    # a device sync per slice.
    sizes = torch.as_tensor(batch_graph_size).long()
    offsets = torch.cat([sizes.new_zeros(1), torch.cumsum(sizes, dim=0)]).tolist()
    spans = [(offsets[i], offsets[i + 1]) for i in range(batch_size)]

    if type == 'umap_knn':
        # Local (kNN) and global (all-pair) UMAP weights for both p and q, graph by graph.
        p, q, global_p, global_q, knn_block_size = [], [], [], [], []
        for start, end in spans:
            _p, _q, _global_p, _global_q = get_umap_weight_from_x_knn(
                z[start:end], zhat[start:end], n_neighbors[start])
            assert _p.shape[0] == _q.shape[0]
            knn_block_size.append(_p.shape[0])
            p.append(_p)
            q.append(_q)
            global_p.append(_global_p)
            global_q.append(_global_q)
        p = torch.cat(p)
        q = torch.cat(q)
        global_p = torch.cat(global_p)
        global_q = torch.cat(global_q)
        knn_block_size = torch.tensor(knn_block_size).to(q.device)
        return p, q, knn_block_size, global_p, global_q

    zhat_dist = torch.cat([torch.pdist(zhat[start:end]) for start, end in spans])
    if type == 'umap_n_neighbor':
        p = torch.cat([get_umap_n_neighbor_p(z[start:end], n_neighbors[start])
                       for start, end in spans])
        q = get_zdist_weight_umap(zhat_dist)
        return p, q

    zdist = torch.cat([torch.pdist(z[start:end]) for start, end in spans])
    # Exact integer pair counts n * (n - 1) / 2 per graph; true division went through float32.
    batch_zdist_size = (sizes * (sizes - 1) // 2).to(z.device)
    if type == 'none':
        p = zdist
        q = zhat_dist
    elif type == 'student_t_norm':
        p = get_zdist_prob_student_t(zdist, batch_zdist_size, degrees_of_freedom=1)
        q = get_zdist_prob_student_t(zhat_dist, batch_zdist_size, degrees_of_freedom=1)
    elif type == 'student_t':
        p = get_zdist_weight_tsne(zdist, degrees_of_freedom=1)
        q = get_zdist_weight_tsne(zhat_dist, degrees_of_freedom=1)
    elif type == 'umap':
        p = get_zdist_weight_umap(zdist)
        q = get_zdist_weight_umap(zhat_dist)
    elif type == 'umap_norm':
        p = get_zdist_weight_umap_normalized(zdist, batch_zdist_size)
        q = get_zdist_weight_umap_normalized(zhat_dist, batch_zdist_size)
    else:
        raise NotImplementedError
    return p, q

def find_block_umap_ce_seq(q, p, block_size, eps=1e-6):
    """Per-block UMAP-style cross-entropy, scaled by 0.1.

    Within each block, pairs with ``p != 0`` contribute the attractive term ``-p * log(q)``.
    Pairs with ``p == 0`` contribute the repulsive term ``-log(1 - q)``, which pushes q toward 0.

    Args:
        q: 1-D predicted similarities in (0, 1].
        p: 1-D target memberships, aligned with ``q``.
        block_size: 1-D integer tensor (or sequence) of block lengths.
        eps: added inside the repulsive log so that q == 1 stays finite.

    Returns:
        Tensor of shape ``(len(block_size),)`` with one loss per block.
    """
    sizes = torch.as_tensor(block_size, device=q.device).long()
    # Integer offsets; a float32 cumsum rounds offsets above 2**24.
    offsets = torch.cat([sizes.new_zeros(1), torch.cumsum(sizes, dim=0)]).tolist()
    out = torch.zeros((len(sizes),), dtype=q.dtype, device=q.device)
    for i in range(len(sizes)):
        _q = q[offsets[i]: offsets[i + 1]]
        _p = p[offsets[i]: offsets[i + 1]]
        is_edge = _p != 0
        attract = -1 * _p[is_edge] * torch.log(_q[is_edge])
        # Non-edges must be repelled: -(1 - p) * log(1 - q) with p == 0. Using -log(q) here
        # would pull unrelated pairs together even harder than true neighbours.
        repel = -torch.log(1.0 - _q[~is_edge] + eps)
        out[i] = 0.1 * (torch.sum(attract) + torch.sum(repel))
    return out


def _get_loss_by_type(p, q, batch_zdist_size, type, df=1):
    '''
    Reduce concatenated per-graph weights into a scalar loss of the requested kind.

    p: prob/weight from z (reference)
    q: prob/weight from zhat (prediction)
    batch_zdist_size: number of entries each graph contributes to p / q
    type: 'kl', 'bce', 'l2' or 'umap_ce'
    df: only used to scale the 'kl' loss by 2 * (df + 1) / df
    '''
    assert type in ['kl', 'bce', 'l2', 'umap_ce']

    def _kl():
        pointwise = p * (torch.log(p) - torch.log(q))
        per_graph = find_block_sum_seq(pointwise, batch_zdist_size)
        return 2 * (df + 1) / df * per_graph.mean()

    def _bce():
        return find_block_ce_seq(q, p, batch_zdist_size).mean()

    def _l2():
        squared_err = (q - p) ** 2
        return find_block_sum_seq(squared_err, batch_zdist_size, find_mean=True).mean()

    def _umap_ce():
        return find_block_umap_ce_seq(q, p, batch_zdist_size).mean()

    reducers = {'kl': _kl, 'bce': _bce, 'l2': _l2, 'umap_ce': _umap_ce}
    if type not in reducers:
        raise NotImplementedError
    return reducers[type]()


def get_loss_by_type(z, zhat, batch_graph_size, batch_size, type, weight_type, n_neighbors=None, df=1, t=0):
    '''
    Distillation loss between reference embedding ``z`` and prediction ``zhat``.

    p: prob/weight from z
    q: prob/weight from zhat
    type: loss kind ('kl', 'bce', 'l2', 'umap_ce', 'umap_knn_ce', 'umap_knn_l2')
    weight_type: weighting scheme passed to get_weight_by_type
    n_neighbors: per-node neighbour counts, used by the UMAP kNN weight types
    df: Student-t degrees of freedom, only used to scale the 'kl' loss
    t: step counter; scales the global term of 'umap_knn_l2' by (t / 100 + 1e-6)
    '''
    assert type in ['kl', 'bce', 'l2', 'umap_ce', 'umap_knn_ce', 'umap_knn_l2']
    if type == 'umap_knn_ce':
        # 'umap_knn' weights come back as a 5-tuple (p, q, sizes, global_p, global_q); only the
        # first three are needed. Unpacking the tuple into three names raised ValueError.
        p, q, knn_block_size = get_weight_by_type(z, zhat, batch_graph_size, batch_size,
                                                  type=weight_type, n_neighbors=n_neighbors)[:3]
        ce = -1 * p * torch.log(q)
        loss = find_block_sum_seq(ce, knn_block_size, find_mean=False).mean()
    elif type == 'umap_knn_l2':
        p, q, knn_block_size, global_p, global_q = get_weight_by_type(
            z, zhat, batch_graph_size, batch_size, type=weight_type, n_neighbors=n_neighbors)
        knn_loss = _get_loss_by_type(p, q, knn_block_size, 'l2')
        # Exact integer pair counts n * (n - 1) / 2 (true division would go through float).
        pair_counts = (batch_graph_size * (batch_graph_size - 1) // 2).long().to(z.device)
        global_loss = _get_loss_by_type(global_p, global_q, pair_counts, 'l2')
        # The global term's weight grows linearly with t.
        loss = knn_loss + 1.0 * (t / 100 + 1e-6) * global_loss
    else:
        p, q = get_weight_by_type(z, zhat, batch_graph_size, batch_size, type=weight_type, n_neighbors=n_neighbors)
        pair_counts = (batch_graph_size * (batch_graph_size - 1) // 2).long().to(z.device)
        loss = _get_loss_by_type(p, q, pair_counts, type, df=df)
    return loss


def query_weight_type(loss_type, method='tsne'):
    """Map a loss name to the weighting scheme consumed by ``get_weight_by_type``.

    Args:
        loss_type: loss name such as 'kl' or 'l2'.
        method: 'tsne' or 'umap'; selects the lookup table.

    Returns:
        The weight-type string registered for ``loss_type`` under ``method``.

    Raises:
        KeyError: ``loss_type`` is not registered for ``method``.
        NotImplementedError: ``method`` is neither 'tsne' nor 'umap'.
    """
    if method == 'tsne':
        table = {
            'kl': 'student_t_norm',
            'l2': 'student_t',
        }
    elif method == 'umap':
        table = {
            # NOTE(review): an earlier comment says this may be switched to 'umap_norm'.
            'kl': 'student_t_norm',
            'l2': 'none',
            'bce': 'umap',
            'umap_ce': 'umap_n_neighbor',
            'umap_knn_ce': 'umap_knn',
            'umap_knn_l2': 'umap_knn',
        }
    else:
        raise NotImplementedError
    return table[loss_type]


def loss_fn_pdist_w_umap(net, n_view_graphs, edge_only_graph, tsne_z, umap_z,
                         tsne_loss='kl', umap_loss='kl',
                         df=1):
    '''Joint t-SNE + UMAP distillation loss for one batch of graphs.

    The network output is split by column: the first two columns are matched against
    ``tsne_z`` using ``tsne_loss``, and the remaining columns against ``umap_z`` using
    ``umap_loss``. ``df`` is accepted for interface compatibility but is not used here.

    Returns (tsne_loss_value + umap_loss_value, z_hat).
    '''
    z_hat = net.forward(edge_only_graph, n_view_graphs)
    tsne_part, umap_part = z_hat[:, :2], z_hat[:, 2:]

    graph_sizes = n_view_graphs.batch_num_nodes()
    n_graphs = n_view_graphs.batch_size

    tsne_weight = query_weight_type(tsne_loss, method='tsne')
    umap_weight = query_weight_type(umap_loss, method='umap')

    tsne_value = get_loss_by_type(tsne_z, tsne_part, graph_sizes, n_graphs,
                                  type=tsne_loss, weight_type=tsne_weight)
    umap_value = get_loss_by_type(umap_z, umap_part, graph_sizes, n_graphs,
                                  type=umap_loss, weight_type=umap_weight,
                                  n_neighbors=n_view_graphs.ndata['umap_n_neighbor'])
    return tsne_value + umap_value, z_hat


def loss_fn_pdist(net, n_view_graphs, edge_only_graph, z,
                         loss_type, method,
                         df=1, t=1):
    '''Single-target distillation loss for one batch of graphs.

    Runs the network, looks up the weighting scheme for ``loss_type`` under ``method``
    ('tsne' or 'umap') and compares ``z`` with the network output graph by graph.
    ``df`` is accepted for interface compatibility but is not used here; ``t`` is forwarded
    to ``get_loss_by_type``.

    Returns (loss, z_hat).
    '''
    z_hat = net.forward(edge_only_graph, n_view_graphs)
    loss = get_loss_by_type(
        z,
        z_hat,
        n_view_graphs.batch_num_nodes(),
        n_view_graphs.batch_size,
        type=loss_type,
        weight_type=query_weight_type(loss_type, method=method),
        n_neighbors=n_view_graphs.ndata['umap_n_neighbor'],
        t=t,
    )
    return loss, z_hat