# losses.py

import math

import torch
import torch.nn.functional as F

import torch
import torch.nn.functional as F


def info_nce_loss(Z_a, Z_b, tau_nce):
    """
    Node-level InfoNCE loss between two sets of node embeddings.

    For every batch element, each row of Z_a is scored against every row
    of Z_b by dot product, scaled by 1 / tau_nce. Row i of Z_a is then
    classified against the rows of Z_b, with row i of Z_b as the correct
    class.

    Z_a, Z_b: [B, n, d]
    tau_nce: temperature dividing the similarity scores
    returns scalar (cross-entropy averaged over all B * n rows)
    """
    batch, num_nodes, _ = Z_a.shape

    # Pairwise dot products within each batch element, temperature-scaled.
    logits = (Z_a @ Z_b.permute(0, 2, 1)) / tau_nce           # [B, n, n]
    logits = logits.reshape(batch * num_nodes, num_nodes)     # [B*n, n]

    # The matching node index is the target class for every row.
    node_ids = torch.arange(num_nodes, device=Z_a.device)     # [n]
    labels = node_ids.expand(batch, num_nodes).reshape(-1)    # [B*n]

    return F.cross_entropy(logits, labels)


def loss_aug(emb_a, emb_b, emb_c, log_p, weight, V, tau_aug, eps=1e-8):
    """
    Graph-level KL Loss

    Computes mean over the batch of ``V * weight * (log_p - log_q)`` where

        log_q = sim(a, b) / tau_aug
                - (logsumexp_k(sim(a, c_k) / tau_aug) + log(V / K))

    and ``sim(x, y)`` is the sum of element-wise products of the two
    embeddings divided by n (i.e. the per-node dot product averaged over
    nodes).

    emb_a: [B, n, d]
    emb_b: [B, n, d]
    emb_c: [B, K, n, d]
    log_p: Tensor [B] or scalar log p_theta(b|a)
    weight: Tensor [B] or scalar = p_theta(b|a).detach()
    V: int, total number of KD genes (must be positive)
    tau_aug: float
    eps: unused; kept so existing callers that pass it keep working
    returns scalar

    Raises ValueError if V is not positive (log(V / K) is undefined).
    K == 0 raises ZeroDivisionError from the V / K division.
    """
    if V <= 0:
        raise ValueError(f"V must be positive, got {V}")

    _, n, _ = emb_a.shape
    K = emb_c.size(1)

    # Vectorized dot product, averaged over the n nodes
    sim_ab = (emb_a * emb_b).sum(dim=2).sum(dim=1) / n       # [B]
    sims_c = (emb_a.unsqueeze(1) * emb_c).sum(dim=(2, 3)) / n  # [B, K]

    # log(V / K) is a constant: compute it as a Python float rather than a
    # freshly allocated float32 tensor. This keeps full precision when the
    # embeddings are float64 and avoids a per-call host-to-device copy.
    log_scale = math.log(V / K)

    # Log numerator and denominator
    log_num = sim_ab / tau_aug                                     # [B]
    log_denom = torch.logsumexp(sims_c / tau_aug, dim=1) + log_scale  # [B]
    log_q = log_num - log_denom                                    # [B]

    # KL = |V| * weight * (log_p - log_q)
    kl = V * weight * (log_p - log_q)                              # [B]
    return kl.mean()

