


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.tensor as tensor



class PULoss(nn.Module):
    """Positive-unlabeled risk with a non-negative correction (nnPU-style).

    ``forward`` estimates a positive risk and a "judge" risk (unlabeled risk
    minus the prior-weighted positive-as-negative risk). The two outcomes are:

    * If the judge risk exceeds ``-beta``, it returns
      ``prior * positive_risk + max(0, judge_risk)``.
    * Otherwise it returns ``-gamma * judge_risk``.

    Args:
        prior: Stored as ``self.prior``. NOTE(review): ``forward`` does not
            read it; the prior actually used is ``forward``'s ``theta``
            argument. Confirm whether that is intended.
        loss: Stored as ``self.loss_func``. NOTE(review): ``forward`` does not
            use it; it hard-codes ``sigmoid(-z)`` instead.
        gamma: Scale of the corrective term returned when the judge risk is
            ``<= -beta``.
        beta: Threshold below which (as ``-beta``) the judge risk triggers the
            corrective term.
    """

    def __init__(self, prior, loss=(lambda x: torch.sigmoid(x)), gamma=1, beta=0,
                 ):
        super().__init__()
        self.prior = prior
        self.gamma = gamma
        self.beta = beta
        self.loss_func = loss
        # Constants are created device-agnostically and moved to the input's
        # device in ``forward``. Hard-coding device='cuda' made the class
        # unusable without a GPU.
        self.unlabeled = torch.tensor([-1, 1.])
        # Lower bound on the per-batch sample counts (avoids 0/0 -> NaN).
        self.min_count = torch.tensor(1.)

    def forward(self, output, target, test=False, theta=0.9):
        """Compute the PU risk for one batch.

        Args:
            output: Raw scores; softmax is taken over the last dimension.
                NOTE(review): the alternating [-1, 1] sign pattern presumably
                assumes two classes per row, i.e. shape (batch, 2). Confirm.
            target: Same shape as ``output``. The positive mask is ``target``
                with negative entries set to 0, and the unlabeled mask is
                ``1 - positive``. Presumably the values are in {-1, 1}; confirm.
            test: Unused.
            theta: Class prior used in the risk estimate.

        Returns:
            A 0-dim tensor. This is the non-negative PU risk when
            ``judge_risk > -beta``, otherwise ``-gamma * judge_risk``.
        """
        assert (output.shape == target.shape)
        device = output.device

        flat_target = target.reshape(-1)

        # Sign pattern [-1, 1, -1, 1, ...] aligned with the flattened output.
        unlabeled_ts = self.unlabeled.to(device)
        unlabeled_ts = unlabeled_ts.repeat(flat_target.numel() // unlabeled_ts.numel())

        # Positive mask: target with negative entries zeroed. Unlabeled mask:
        # its complement.
        positive = flat_target.clamp(min=0)
        unlabeled = 1 - positive

        # Floor the counts at min_count so a batch with no positive (or no
        # unlabeled) samples gives a finite risk instead of 0/0 = NaN.
        min_count = self.min_count.to(positive)
        n_positive = torch.max(min_count, positive.sum())
        n_unlabeled = torch.max(min_count, unlabeled.sum())

        # Map softmax probabilities from [0, 1] to [-1, 1].
        logit_value = 2 * torch.softmax(output, dim=-1).reshape(-1) - 1

        def loss(z):
            # Surrogate loss actually used (self.loss_func is not consulted).
            return torch.sigmoid(-z)

        positive_z = logit_value * positive * flat_target
        positive_negative_z = logit_value * positive * unlabeled_ts
        unlabeled_z = logit_value * unlabeled * flat_target

        positive_risk = (loss(positive_z) * positive).sum() / n_positive
        positive_negative_risk = (loss(positive_negative_z) * positive).sum() / n_positive
        unlabeled_risk = (loss(unlabeled_z) * unlabeled).sum() / n_unlabeled

        prior = theta
        judge_risk = unlabeled_risk - prior * positive_negative_risk

        if judge_risk > -self.beta:
            # Non-negative correction: the judge term is clamped at 0. Using
            # torch.clamp keeps the result 0-dim and on the input's device.
            return prior * positive_risk + torch.clamp(judge_risk, min=0)
        # Judge risk fell to or below -beta: return a corrective term scaled
        # by gamma (gamma=1 reproduces the previous -judge_risk).
        return -self.gamma * judge_risk