class ReweightedInfoNCELoss_full(_NCE):
    """Importance-weighted InfoNCE loss whose denominator sums over every row.

    With ``w = exp(log_weights_p)`` and ``B = T_joint.shape[0]``, the returned
    value is the mean over columns ``j`` of::

        w[j] * (log sum_k w[k] * exp(T_product[k, j] - T_joint[j]) - log B)

    NOTE(review): presumably ``T_joint[j]`` is the critic score on the j-th
    positive pair and ``T_product[k, j]`` the score of row-sample ``k`` paired
    with column-sample ``j`` -- confirm against ``_NCE`` and the caller.
    """

    def forward(
        self,
        T_joint: torch.Tensor,
        T_product: torch.Tensor,
        log_weights_p: torch.Tensor,
        log_weights_q: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the scalar weighted InfoNCE loss.

        Args:
            T_joint: scores whose leading dimension defines the batch size ``B``.
            T_product: ``B * B`` scores, reshaped (as a view) to ``(B, B)``.
            log_weights_p: log weights; entry ``k`` shifts row ``k`` of the
                score grid and entry ``j`` scales column ``j``'s term
                (presumably shape ``(B,)`` -- confirm).
            log_weights_q: ignored by this implementation.

        Returns:
            0-dim tensor: the mean of the per-column weighted terms.
        """
        n = T_joint.shape[0]
        # Shifting row k by log w[k] turns the column-wise logsumexp below
        # into a w-weighted sum over rows.
        weighted_scores = T_product.view((n, n)) + log_weights_p[:, None]
        # Column j: log sum_k w[k] * exp(T_product[k, j] - T_joint[j]).
        log_denominator = (weighted_scores - T_joint.unsqueeze(0)).logsumexp(dim=0)
        per_column = log_weights_p.exp() * (log_denominator - math.log(n))
        return per_column.mean()


def negatives_only_logweights(log_weights: torch.Tensor, eps: float = 1e-12):
    """Build a ``(B, B)`` matrix of log weights for leave-one-out negatives.

    Entry ``[k, i]`` equals::

        log_weights[k] + log((B - 1) / (S_i + eps)),
        S_i = sum_{j != i} exp(log_weights[j])

    For each column ``i``, the off-diagonal entries are therefore the weights
    of the other samples, rescaled so that they sum to roughly ``B - 1``. The
    diagonal ``[i, i]`` follows the same formula; callers overwrite it.

    Args:
        log_weights: 1-D tensor of log weights, one per batch element.
        eps: floor added (in linear space) to each leave-one-out sum. It keeps
            the normalizer finite when every other weight is zero.

    Returns:
        Tensor of shape ``(B, B)`` with the dtype and device of ``log_weights``.

    Raises:
        ValueError: if ``B < 2``, because there are no negatives to normalize.
    """
    B = log_weights.shape[0]
    if B < 2:
        raise ValueError(
            f"negatives_only_logweights needs at least 2 samples, got {B}"
        )
    # Compute the leave-one-out logsumexp directly over a diagonal-masked copy
    # rather than as log(total - w_i). Subtracting the dominant term from the
    # total cancels catastrophically when one weight dominates (e.g. float32
    # with log-weight gaps of ~16+). The eps floor below would then silently
    # hide the error.
    self_mask = torch.eye(B, dtype=torch.bool, device=log_weights.device)
    others = log_weights.unsqueeze(0).expand(B, B).masked_fill(self_mask, -math.inf)
    # Row i holds log_weights[j] for every j != i, so this is log S_i.
    log_sum_minus_i = torch.logsumexp(others, dim=1)
    log_eps = torch.tensor(eps, device=log_weights.device, dtype=log_weights.dtype).log()
    # log(S_i + eps), computed stably in log space.
    log_den = torch.logaddexp(log_sum_minus_i, log_eps)
    log_gamma = math.log(B - 1) - log_den
    # Broadcast: column i gets log_gamma[i], row k gets log_weights[k].
    return log_gamma.unsqueeze(0) + log_weights.unsqueeze(1)


class ReweightedInfoNCELoss_validLOO(_NCE):
    """Weighted InfoNCE loss whose denominator uses only the off-diagonal pairs.

    For column ``j`` the denominator sums ``exp(T_product[k, j] - T_joint[j])``
    over rows ``k != j``. Each term is weighted by ``w[k]``, rescaled (via
    ``negatives_only_logweights``) so that the weights of column ``j`` total
    roughly ``B - 1``. The diagonal term is excluded entirely, so the
    logsumexp is centred by ``log(B - 1)``. The per-column result is then
    multiplied by ``w[j]`` and averaged.

    NOTE(review): presumably ``T_product[j, j]`` is the positive-pair score
    that this variant deliberately leaves out -- confirm against the caller.
    """

    def forward(
        self,
        T_joint: torch.Tensor,
        T_product: torch.Tensor,
        log_weights_p: torch.Tensor,
        log_weights_q: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the scalar leave-one-out weighted InfoNCE loss.

        Args:
            T_joint: scores whose leading dimension defines the batch size ``B``.
            T_product: ``B * B`` scores, reshaped (as a view) to ``(B, B)``.
            log_weights_p: log weights, presumably shape ``(B,)`` -- confirm.
            log_weights_q: ignored by this implementation.

        Returns:
            0-dim tensor. ``B`` must be at least 2, because ``log(B - 1)`` is
            otherwise undefined.
        """
        n = T_joint.shape[0]
        scores = T_product.view((n, n))
        log_w = negatives_only_logweights(log_weights_p)
        # exp(-inf) == 0, which removes the k == j term from every column's sum.
        log_w.fill_diagonal_(-float("inf"))
        shifted = scores + log_w - T_joint[None, :]
        lse = torch.logsumexp(shifted, dim=0)
        # n - 1 negatives remain per column once the diagonal is gone.
        return (log_weights_p.exp() * (lse - math.log(n - 1))).mean()


class ReweightedInfoNCELoss_biased(_NCE):
    """Weighted InfoNCE loss that keeps the diagonal term at unit weight.

    Rows ``k != j`` of column ``j`` are weighted by ``w[k]``, rescaled (via
    ``negatives_only_logweights``) so that they total roughly ``B - 1``. The
    diagonal entry ``T_product[j, j]`` enters with weight exactly 1. The total
    weight per column is therefore about ``B``, and the logsumexp is centred by
    ``log(B)``. The per-column result is multiplied by ``w[j]`` and averaged.
    """

    def forward(
        self,
        T_joint: torch.Tensor,
        T_product: torch.Tensor,
        log_weights_p: torch.Tensor,
        log_weights_q: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the scalar weighted InfoNCE loss with a unit-weight diagonal.

        Args:
            T_joint: scores whose leading dimension defines the batch size ``B``.
            T_product: ``B * B`` scores, reshaped (as a view) to ``(B, B)``.
            log_weights_p: log weights, presumably shape ``(B,)`` -- confirm.
            log_weights_q: ignored by this implementation.

        Returns:
            0-dim tensor: the mean of the per-column weighted terms.
        """
        n = T_joint.shape[0]
        scores = T_product.view((n, n))
        log_w = negatives_only_logweights(log_weights_p)
        # log-weight 0 == weight 1: the diagonal score is kept unscaled.
        log_w.fill_diagonal_(0.0)
        shifted = scores + log_w - T_joint[None, :]
        lse = torch.logsumexp(shifted, dim=0)
        return (log_weights_p.exp() * (lse - math.log(n))).mean()
