import torch

def bregman_divergence(y, y_s, label, loss_func, berman_grad):
    """Return the Bregman divergence of ``loss_func`` between ``y`` and ``y_s``.

    Computes ``phi(y) - phi(y_s) - <grad, y - y_s>``, where
    ``phi(.) = loss_func(., label)`` and ``grad`` is ``berman_grad``.

    Args:
        y: Current outputs.
        y_s: Reference outputs; must have the same number of elements as ``y``.
        label: Targets passed unchanged to ``loss_func``.
        loss_func: Callable ``loss_func(outputs, label)``. Presumably it returns a
            scalar tensor; verify this against callers.
        berman_grad: Gradient of ``phi`` at ``y_s`` (assumed; the caller supplies
            it). It must have the same number of elements as ``y - y_s``.

    Returns:
        ``phi(y) - phi(y_s) - <berman_grad, y - y_s>``.
    """
    phi_y = loss_func(y, label)  # \phi(y)
    phi_ys = loss_func(y_s, label)  # \phi(y_s)
    # Flatten both operands so torch.dot sees 1-D vectors. Use reshape (not view)
    # so non-contiguous tensors, e.g. transposed outputs or gradients, also work.
    linear_term = torch.dot(berman_grad.reshape(-1), (y - y_s).reshape(-1))

    return phi_y - phi_ys - linear_term

def pbrf_loss(y_r, y_a, y, y_s, label_r, label, loss_func, lamb, theta, theta_s, berman_grad, num_trains):
    """Compute a proximal-Bregman-style objective and its remove/add loss terms.

    The objective is::

        D(y, y_s) - (remove_loss - add_loss) / num_trains
                  + (lamb / 2) * ||theta - theta_s||_2^2

    where ``D`` is ``bregman_divergence(y, y_s, label, loss_func, berman_grad)``.

    Args:
        y_r: Outputs scored against ``label_r`` to form ``remove_loss``.
        y_a: Outputs scored against ``label_r`` to form ``add_loss``.
        y, y_s: Current and reference outputs for the divergence term.
        label_r: Targets for the remove/add terms. Its first dimension is used
            as the number of influenced nodes.
        label: Targets for the divergence term.
        loss_func: Callable ``loss_func(outputs, labels)``.
        lamb: Weight of the proximity (damping) term.
        theta, theta_s: Current and reference parameter tensors. They must be
            broadcast-compatible for subtraction.
        berman_grad: Gradient used in the Bregman divergence's linear term.
        num_trains: Normaliser for the remove/add terms. Presumably the number
            of training examples; confirm with the caller.

    Returns:
        Tuple ``(pbrf_value, remove_loss, add_loss)``.
    """
    divergence = bregman_divergence(y, y_s, label, loss_func, berman_grad)

    num_influenced_nodes = label_r.shape[0]
    # Scale by the node count. If loss_func returns a mean (assumed, not verified
    # here), this turns it into a sum over the influenced nodes.
    remove_loss = loss_func(y_r, label_r) * num_influenced_nodes
    add_loss = loss_func(y_a, label_r) * num_influenced_nodes

    # Squared L2 distance computed directly. This avoids the deprecated torch.norm
    # and the needless sqrt-then-square round trip.
    proximity = (theta - theta_s).pow(2).sum()

    pbrf_value = divergence - 1 / num_trains * (remove_loss - add_loss) + (lamb / 2) * proximity

    return pbrf_value, remove_loss, add_loss

