import torch
import torch.nn.functional as F


def subgraph_contrastive_loss(h, z, adj, tau):
    """Combine an inter-view and an intra-view contrastive term into one loss.

    Pairwise similarities are ``exp(cos(a, b) / tau)``. For every row ``i``
    of ``z`` two ratios are formed, and the loss is ``-mean(log(l1 * l2))``:

    * ``l1`` (inter-view): similarity between ``z[i]`` and the rows of
      ``h[2]``, weighted by ``adj`` with self-loops added, divided by the
      total similarity between ``z[i]`` and all rows of ``h[2]``.
    * ``l2`` (intra-view): similarity between ``z[i]`` and the rows of ``z``,
      weighted by ``adj`` (no self-loops), divided by the total similarity to
      every *other* row of ``z``. Rows where the numerator or denominator is
      zero (e.g. a row with no neighbours in ``adj``) use ``l2 = 1``, so they
      add nothing to the loss from this term.

    Args:
        h: Indexable container; only ``h[2]`` is used. It must be a 2-D tensor
            with the same number of columns as ``z`` (``torch.mm`` requires it).
        z: 2-D tensor of row embeddings.
        adj: Square matrix whose size matches the number of rows of ``z`` and
            ``h[2]``; used as multiplicative weights on the similarities.
        tau: Temperature that divides the cosine similarities.

    Returns:
        Scalar tensor: the mean negative log of ``l1 * l2``.
    """
    v = h[2]

    z_norm = F.normalize(z)
    v_norm = F.normalize(v)

    # Build the identity on adj's own device (previously hard-coded to
    # cuda:0, which broke CPU use and any other GPU).
    adj_s = adj + torch.eye(adj.shape[0], device=adj.device)

    sim_node = torch.exp(torch.mm(z_norm, v_norm.T) / tau)
    sim_graph = torch.exp(torch.mm(z_norm, z_norm.T) / tau)

    l1 = sim_node.mul(adj_s).sum(1) / sim_node.sum(1)

    l2_numerator = sim_graph.mul(adj).sum(1)
    # Zero the diagonal exactly before summing (rather than sum - diag) so a
    # dominant self-similarity at small tau does not cause cancellation error.
    l2_denominator = (sim_graph - torch.diag_embed(sim_graph.diag())).sum(1)

    valid = (l2_numerator != 0) & (l2_denominator != 0)
    # torch.where still backpropagates through the unselected branch; a raw
    # num / 0 there produces NaN gradients. Divide by 1 on masked rows so the
    # discarded branch stays finite. Forward values are unchanged.
    safe_denominator = torch.where(
        valid, l2_denominator, torch.ones_like(l2_denominator)
    )
    l2 = torch.where(
        valid, l2_numerator / safe_denominator, torch.ones_like(l2_numerator)
    )

    return -torch.log(l1 * l2).mean()

