"""
If you use LibAUC in your work, please cite our papers.
"""
import torch 
import torch.nn.functional as F

class AUCMLoss(torch.nn.Module):
    """
    AUCM Loss with squared-hinge function: a novel loss function to directly optimize AUROC
    
    inputs:
        margin: margin term for AUCM loss, e.g., m in [0, 1]
        imratio: imbalance ratio, i.e., the ratio of the number of positive samples to the
            number of total samples. If None, it is estimated from the FIRST batch passed to
            ``forward`` and that value is reused unchanged for every later batch.
    outputs:
        tuple ``(loss, real_loss_1, real_loss_2)``:
            loss: the min-max objective in ``a``, ``b`` and ``alpha``, divided by p * (1 - p)
            real_loss_1: only the squared-deviation part of ``loss`` (same normalization)
            real_loss_2: margin - mean(positive scores) + mean(negative scores)

    Notes:
        ``a``, ``b`` and ``alpha`` are plain tensors created with ``requires_grad=True`` on
        CUDA when available (CPU otherwise). They are not ``nn.Parameter`` objects, so they
        are not returned by ``self.parameters()``.
    
    Reference: 
        Yuan, Z., Yan, Y., Sonka, M. and Yang, T., 
        Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification. 
        International Conference on Computer Vision (ICCV 2021)
    Link:
        https://arxiv.org/abs/2012.03173
    """
    def __init__(self, margin=1.0, imratio=None):
        super(AUCMLoss, self).__init__()
        # Device for the auxiliary variables below: CUDA if available, otherwise CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin
        # Positive-class ratio p; None means "estimate it on the first forward call".
        self.p = imratio
        # Created directly on the target device so they remain leaf tensors (the trailing
        # .to() is then a no-op), see
        # https://discuss.pytorch.org/t/valueerror-cant-optimize-a-non-leaf-tensor/21751
        self.a = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device)
        self.b = torch.zeros(1, dtype=torch.float32, device=self.device,  requires_grad=True).to(self.device)
        # Dual variable of the min-max objective.
        self.alpha = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device)
        
    def forward(self, y_pred, y_true):
        # Estimate p from this batch only when no imratio was given. The result is cached
        # in self.p, so all later batches reuse the first batch's ratio.
        # NOTE(review): if that first batch has no positives, p == 0 and the divisions by
        # p * (1 - p) below divide by zero.
        if self.p is None:
           self.p = (y_true==1).float().sum()/y_true.shape[0]   
        # Scores of the positive / negative samples as column vectors.
        f_ps = y_pred[y_true==1].reshape(-1,1) 
        f_ns = y_pred[y_true==0].reshape(-1,1) 
        # Flatten both to column vectors so the masks below line up element-wise.
        y_pred = y_pred.reshape(-1, 1) # be careful about these shapes
        y_true = y_true.reshape(-1, 1) 
        # Squared deviations of positive scores from a and negative scores from b. The
        # masks zero out the other class, but each mean is over the whole batch.
        tmp = (1-self.p)*torch.mean((y_pred - self.a)**2*(1==y_true).float()) + \
                    self.p*torch.mean((y_pred - self.b)**2*(0==y_true).float()) 
        # Full objective: tmp + 2*alpha*(p(1-p)*m + mean(p*f*[neg] - (1-p)*f*[pos]))
        #                 - p(1-p)*alpha^2
        loss = tmp + 2*self.alpha*(self.p*(1-self.p)*self.margin + \
                    torch.mean((self.p*y_pred*(0==y_true).float() - (1-self.p)*y_pred*(1==y_true).float())) )- \
                    self.p*(1-self.p)*self.alpha**2
        loss = loss/((1-self.p)*self.p)
        # Class centroids for monitoring; torch.mean of an empty tensor is NaN, so these
        # are NaN when the batch lacks one of the classes.
        cen_ps = torch.mean(f_ps)
        cen_ns = torch.mean(f_ns)
        real_loss_1 = tmp/((1-self.p)*self.p) 
        real_loss_2 = self.margin - cen_ps + cen_ns 
        return loss, real_loss_1, real_loss_2
    
    
class CLLoss(torch.nn.Module):
    """
    Centroid-gap loss whose margin term is weighted by sigmoid(d).

    With f_ps / f_ns the scores of the positive / negative samples and
    tmp = -margin * (mean(f_ps) - mean(f_ns)), ``forward`` returns
    ``(loss, real_loss_1, real_loss_2)``:

        loss        = mean((f_ps - a)^2) + mean((f_ns - b)^2) + sigmoid(d) * tmp + c * tmp
        real_loss_1 = mean((f_ps - a)^2) + mean((f_ns - b)^2)
        real_loss_2 = tmp

    sigmoid(d) is the derivative of log(1 + exp(x)) at x = d, so the linear term
    presumably stands in for a log(1 + exp(tmp)) penalty, with ``d`` a smoothed
    statistic maintained by the caller -- TODO confirm.

    If the batch contains only one class the centroids are undefined; ``forward`` then
    prints a notice and returns a single scalar tensor (not a tuple): BCE-with-logits of
    ``y_pred`` against ``y_true``.

    inputs:
        margin: scale applied to the centroid gap in ``tmp``.
    """
    def __init__(self, margin=1.0):
        super(CLLoss, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin
        # Plain leaf tensors with requires_grad=True (not nn.Parameters, so they are not
        # returned by self.parameters()). a / b are subtracted from the positive /
        # negative scores; c multiplies the margin term.
        self.a = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        self.b = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        self.c = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        # Not trainable and never updated in this class; per the original note it holds
        # exp-smooth(cen_ps - cen_ns), presumably maintained by the caller.
        self.d = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=False)

    def forward(self, y_pred, y_true):
        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        cen_ps = torch.mean(f_ps)
        cen_ns = torch.mean(f_ns)
        tmp = -self.margin * (cen_ps - cen_ns)
        sq_terms = torch.mean((f_ps - self.a) ** 2) + torch.mean((f_ns - self.b) ** 2)
        # exp(d) / (1 + exp(d)) == sigmoid(d), kept in this form to preserve exact values.
        loss = sq_terms + (torch.exp(self.d) / (1 + torch.exp(self.d))) * tmp + self.c * tmp
        real_loss_1 = sq_terms
        real_loss_2 = tmp
        return loss, real_loss_1, real_loss_2


class CSQLoss(torch.nn.Module):
    """
    Centroid-gap loss whose margin term is weighted by 2 * d.

    With f_ps / f_ns the scores of the positive / negative samples and
    tmp = mean(f_ps) - mean(f_ns) - margin, ``forward`` returns
    ``(loss, real_loss_1, real_loss_2)``:

        loss        = mean((f_ps - a)^2) + mean((f_ns - b)^2) + 2 * d * tmp + c * tmp
        real_loss_1 = mean((f_ps - a)^2) + mean((f_ns - b)^2)
        real_loss_2 = -tmp

    2 * d is the derivative of x^2 at x = d, so the linear term presumably stands in for
    a squared penalty, with ``d`` a smoothed statistic maintained by the caller -- TODO
    confirm.

    If the batch contains only one class the centroids are undefined; ``forward`` then
    prints a notice and returns a single scalar tensor (not a tuple): BCE-with-logits of
    ``y_pred`` against ``y_true``.

    inputs:
        margin: target gap between the positive and negative centroids.
    """
    def __init__(self, margin=1.0):
        super(CSQLoss, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin
        # Plain leaf tensors with requires_grad=True (not nn.Parameters, so they are not
        # returned by self.parameters()). a / b are subtracted from the positive /
        # negative scores; c multiplies the margin term.
        self.a = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        self.b = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        self.c = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        # Not trainable, starts at 1 and is never updated in this class; per the original
        # note it holds exp-smooth(cen_ps - cen_ns), presumably maintained by the caller.
        self.d = torch.ones(1, dtype=torch.float32, device=self.device, requires_grad=False)

    def forward(self, y_pred, y_true):
        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        cen_ps = torch.mean(f_ps)
        cen_ns = torch.mean(f_ns)
        tmp = cen_ps - cen_ns - self.margin
        sq_terms = torch.mean((f_ps - self.a) ** 2) + torch.mean((f_ns - self.b) ** 2)
        loss = sq_terms + 2 * self.d * tmp + self.c * tmp
        real_loss_1 = sq_terms
        real_loss_2 = -tmp
        return loss, real_loss_1, real_loss_2

    
class CHLoss(torch.nn.Module):
    """
    Centroid-gap loss whose margin term is gated by the indicator (d > 0).

    With f_ps / f_ns the scores of the positive / negative samples and
    tmp = margin - (mean(f_ps) - mean(f_ns)), ``forward`` returns
    ``(loss, real_loss_1, real_loss_2)``:

        loss        = mean((f_ps - a)^2) + mean((f_ns - b)^2) + [d > 0] * tmp + c * tmp
        real_loss_1 = mean((f_ps - a)^2) + mean((f_ns - b)^2)
        real_loss_2 = tmp

    [d > 0] is a (sub)gradient of max(x, 0) at x = d, so the linear term presumably
    stands in for a hinge penalty max(tmp, 0), with ``d`` a smoothed statistic maintained
    by the caller -- TODO confirm.

    If the batch contains only one class the centroids are undefined; ``forward`` then
    prints a notice and returns a single scalar tensor (not a tuple): BCE-with-logits of
    ``y_pred`` against ``y_true``.

    inputs:
        margin: target gap between the positive and negative centroids.
    """
    def __init__(self, margin=1.0):
        super(CHLoss, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin
        # Plain leaf tensors with requires_grad=True (not nn.Parameters, so they are not
        # returned by self.parameters()). a / b are subtracted from the positive /
        # negative scores; c multiplies the margin term.
        self.a = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        self.b = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        self.c = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        # Not trainable and never updated in this class; per the original note it holds
        # exp-smooth(cen_ps - cen_ns), presumably maintained by the caller.
        self.d = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=False)

    def forward(self, y_pred, y_true):
        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        cen_ps = torch.mean(f_ps)
        cen_ns = torch.mean(f_ns)
        tmp = self.margin - (cen_ps - cen_ns)
        sq_terms = torch.mean((f_ps - self.a) ** 2) + torch.mean((f_ns - self.b) ** 2)
        # Boolean mask times tmp: tmp where d > 0, zero otherwise.
        loss = sq_terms + (self.d > 0) * tmp + self.c * tmp
        real_loss_1 = sq_terms
        real_loss_2 = tmp
        return loss, real_loss_1, real_loss_2

    
class CSHLoss(torch.nn.Module):
    """
    Centroid-gap loss whose margin term is weighted by 2 * [d > 0] * d.

    With f_ps / f_ns the scores of the positive / negative samples and
    tmp = margin - (mean(f_ps) - mean(f_ns)), ``forward`` returns
    ``(loss, real_loss_1, real_loss_2)``:

        loss        = mean((f_ps - a)^2) + mean((f_ns - b)^2) + 2 * [d > 0] * d * tmp + c * tmp
        real_loss_1 = mean((f_ps - a)^2) + mean((f_ns - b)^2)
        real_loss_2 = tmp

    2 * [d > 0] * d is the derivative of max(x, 0)^2 at x = d, so the linear term
    presumably stands in for a squared-hinge penalty max(tmp, 0)^2, with ``d`` a smoothed
    statistic maintained by the caller -- TODO confirm.

    If the batch contains only one class the centroids are undefined; ``forward`` then
    prints a notice and returns a single scalar tensor (not a tuple): BCE-with-logits of
    ``y_pred`` against ``y_true``.

    inputs:
        margin: target gap between the positive and negative centroids.
    """
    def __init__(self, margin=1.0):
        super(CSHLoss, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin
        # Plain leaf tensors with requires_grad=True (not nn.Parameters, so they are not
        # returned by self.parameters()). a / b are subtracted from the positive /
        # negative scores; c multiplies the margin term.
        self.a = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        self.b = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        self.c = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        # Not trainable and never updated in this class; per the original note it holds
        # exp-smooth(cen_ps - cen_ns), presumably maintained by the caller.
        self.d = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=False)

    def forward(self, y_pred, y_true):
        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        cen_ps = torch.mean(f_ps)
        cen_ns = torch.mean(f_ns)
        tmp = self.margin - (cen_ps - cen_ns)
        sq_terms = torch.mean((f_ps - self.a) ** 2) + torch.mean((f_ns - self.b) ** 2)
        loss = sq_terms + 2 * (self.d > 0) * self.d * tmp + self.c * tmp
        real_loss_1 = sq_terms
        real_loss_2 = tmp
        return loss, real_loss_1, real_loss_2

    
class PSQLoss(torch.nn.Module):
    """
    Pairwise square loss over every (positive, negative) score pair.

    For each pair it computes (f_pos - f_neg - margin) ** 2 and returns the UNREDUCED
    (num_pos, num_neg) matrix; the caller is responsible for reducing it.

    If the batch contains only one class there are no pairs; ``forward`` then prints a
    notice and returns BCE-with-logits of ``y_pred`` against ``y_true`` (a scalar).

    inputs:
        margin: target gap between positive and negative scores.
    """
    def __init__(self, margin=1.0):
        super(PSQLoss, self).__init__()
        # Kept for interface parity with the other losses; forward() does not use it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin

    def forward(self, y_pred, y_true):
        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        # (P, 1) - (1, N) broadcasts to the (P, N) matrix of f_pos[i] - f_neg[j]
        # without materializing repeated copies of the scores.
        difference = f_ps - f_ns.transpose(0, 1)
        return (difference - self.margin) ** 2

    
class PSMLoss(torch.nn.Module):
    """
    Pairwise sigmoid loss over every (positive, negative) score pair.

    Scores are first clamped to [-3, 3]. For each pair the loss is
    1 / (1 + exp(margin * (f_pos - f_neg))), i.e. sigmoid(-margin * (f_pos - f_neg)),
    and the mean over all pairs is returned. ``margin`` multiplies the score difference
    (it acts as a scale, not an additive margin).

    If the batch contains only one class there are no pairs; ``forward`` then prints a
    notice and returns BCE-with-logits of the clamped ``y_pred`` against ``y_true``.
    """
    def __init__(self, margin=1.0):
        super(PSMLoss, self).__init__()
        # Kept for interface parity with the other losses; forward() does not use it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin

    def forward(self, y_pred, y_true):
        y_pred = torch.clamp(y_pred, min=-3, max=3)
        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        # (P, 1) - (1, N) broadcasts to the (P, N) matrix of f_pos[i] - f_neg[j].
        difference = f_ps - f_ns.transpose(0, 1)
        pair_loss = 1. / (1 + torch.exp(self.margin * difference))
        return pair_loss.mean()

    
class PHLoss(torch.nn.Module):
    """
    Pairwise hinge loss over every (positive, negative) score pair.

    For each pair the loss is max(margin - (f_pos - f_neg), 0); the mean over all pairs
    is returned.

    If the batch contains only one class there are no pairs; ``forward`` then prints a
    notice and returns BCE-with-logits of ``y_pred`` against ``y_true``.

    inputs:
        margin: target gap between positive and negative scores.
    """
    def __init__(self, margin=1.0):
        super(PHLoss, self).__init__()
        # Kept for interface parity with the other losses; forward() does not use it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin

    def forward(self, y_pred, y_true):
        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        # (P, 1) - (1, N) broadcasts to the (P, N) matrix of f_pos[i] - f_neg[j].
        difference = f_ps - f_ns.transpose(0, 1)
        # 0-dim zero on the same device/dtype as the scores.
        hinge = torch.maximum(self.margin - difference, difference.new_zeros(()))
        return hinge.mean()

    
class PBHLoss(torch.nn.Module):
    """
    Pairwise barrier-hinge loss over every (positive, negative) score pair.

    For each pair difference z = f_pos - f_neg the loss is
        max(-b * (r + z) + r, max(r - z, b * (z - r)))
    and the mean over all pairs is returned.

    If the batch contains only one class there are no pairs; ``forward`` then prints a
    notice and returns BCE-with-logits of ``y_pred`` against ``y_true``.

    inputs:
        r: margin/offset of the hinge segments.
        b: slope of the two outer (barrier) segments.
    """
    def __init__(self, r=1.0, b=2.0):
        super(PBHLoss, self).__init__()
        # Kept for interface parity with the other losses; forward() does not use it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.r = r
        self.b = b

    def forward(self, y_pred, y_true):
        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        # (P, 1) - (1, N) broadcasts to the (P, N) matrix of f_pos[i] - f_neg[j].
        difference = f_ps - f_ns.transpose(0, 1)
        loss = torch.maximum(self.r - difference, self.b * (difference - self.r))
        loss = torch.maximum(-self.b * (self.r + difference) + self.r, loss)
        return loss.mean()

    
class PSHLoss(torch.nn.Module):
    """
    Pairwise squared-hinge loss over every (positive, negative) score pair.

    For each pair the loss is max(margin - (f_pos - f_neg), 0) ** 2; the mean over all
    pairs is returned.

    If the batch contains only one class there are no pairs; ``forward`` then prints a
    notice and returns BCE-with-logits of ``y_pred`` against ``y_true``.

    inputs:
        margin: target gap between positive and negative scores.
    """
    def __init__(self, margin=1.0):
        super(PSHLoss, self).__init__()
        # Kept for interface parity with the other losses; forward() does not use it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin

    def forward(self, y_pred, y_true):
        f_ps = y_pred[y_true == 1].view(-1, 1)
        f_ns = y_pred[y_true == 0].view(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        # (P, 1) - (1, N) broadcasts to the (P, N) matrix of f_pos[i] - f_neg[j].
        difference = f_ps - f_ns.transpose(0, 1)
        # 0-dim zero on the same device/dtype as the scores.
        hinge = torch.maximum(self.margin - difference, difference.new_zeros(()))
        return (hinge ** 2).mean()

class PSH(torch.nn.Module):
    """
    Unreduced pairwise squared-hinge loss over every (positive, negative) score pair.

    For each pair it computes max(margin - (f_pos - f_neg), 0) ** 2 and returns the
    UNREDUCED (num_pos, num_neg) matrix; the caller is responsible for reducing it.

    If the batch contains only one class there are no pairs; ``forward`` then prints a
    notice and returns BCE-with-logits of ``y_pred`` against ``y_true`` (a scalar).

    inputs:
        margin: target gap between positive and negative scores.
    """
    def __init__(self, margin=1.0):
        super(PSH, self).__init__()
        # Kept for interface parity with the other losses; forward() does not use it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin

    def forward(self, y_pred, y_true):
        f_ps = y_pred[y_true == 1].view(-1, 1)
        f_ns = y_pred[y_true == 0].view(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        # (P, 1) - (1, N) broadcasts to the (P, N) matrix of f_pos[i] - f_neg[j].
        difference = f_ps - f_ns.transpose(0, 1)
        # 0-dim zero on the same device/dtype as the scores.
        hinge = torch.maximum(self.margin - difference, difference.new_zeros(()))
        return hinge ** 2
 
 
class PLLoss(torch.nn.Module):
    """
    Pairwise logistic loss over every (positive, negative) score pair.

    For each pair the loss is log(1 + exp(-margin * (f_pos - f_neg))); the mean over all
    pairs is returned. ``margin`` multiplies the score difference (it acts as a scale,
    not an additive margin).

    If the batch contains only one class there are no pairs; ``forward`` then prints a
    notice and returns BCE-with-logits of ``y_pred`` against ``y_true``.
    """
    def __init__(self, margin=1.0):
        super(PLLoss, self).__init__()
        # Kept for interface parity with the other losses; forward() does not use it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin

    def forward(self, y_pred, y_true):
        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1, 1)
        if len(f_ns) == 0 or len(f_ps) == 0:
            print('switching to logitstic loss')
            # BCE-with-logits needs a floating target with exactly the input's shape;
            # integer labels, or (N,) labels against (N, 1) scores, would otherwise raise.
            target = y_true.to(device=y_pred.device, dtype=y_pred.dtype).reshape(y_pred.shape)
            return F.binary_cross_entropy_with_logits(y_pred, target)
        # (P, 1) - (1, N) broadcasts to the (P, N) matrix of f_pos[i] - f_neg[j].
        difference = f_ps - f_ns.transpose(0, 1)
        # softplus(x) == log(1 + exp(x)) but does not overflow to inf for large x.
        return F.softplus(-self.margin * difference).mean()




class APLoss_SH(torch.nn.Module):
    def __init__(self, data_len=None, margin=1.0,  beta=0.99, batch_size=128):
        """
        AP Loss with squared-hinge function: a novel loss function to directly optimize AUPRC.
    
        inputs:
            data_len: number of rows in the moving-average buffers ``u_all`` / ``u_pos``
                (must be an int); ``forward`` indexes these buffers with ``index_s``.
            margin: margin for squared hinge loss, e.g., m in [0, 1]
            beta: factor for the moving average, which also refers to gamma in the paper
            batch_size: accepted for backward compatibility; not used.
        outputs:
            loss  
        Reference:
            Qi, Q., Luo, Y., Xu, Z., Ji, S. and Yang, T., 2021. 
            Stochastic Optimization of Area Under Precision-Recall Curve for Deep Learning with Provable Convergence. 
            arXiv preprint arXiv:2104.08736.
        Link:
            https://arxiv.org/abs/2104.08736
        """
        super(APLoss_SH, self).__init__()
        # CUDA when available, otherwise CPU (a hard-coded .cuda() fails on CPU-only hosts).
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Moving-average buffers, shape (data_len, 1), float32, initialised to zero.
        self.u_all = torch.zeros(data_len, 1, device=self.device)
        self.u_pos = torch.zeros(data_len, 1, device=self.device)
        self.margin = margin
        self.beta = beta

    def forward(self, y_pred, y_true, index_s):
        """
        Compute the AP surrogate loss for one batch and update the moving averages.

        inputs:
            y_pred: model scores; positives are ``y_true == 1``, negatives ``y_true == 0``.
            y_true: labels indexed like ``y_pred``.
            index_s: indices into ``u_all`` / ``u_pos`` selecting one row per positive in
                the batch (presumably the dataset indices of the positives -- TODO confirm).
        outputs:
            scalar tensor: sum over (positive, sample) pairs of p * squared-hinge, where the
            weights p come from the moving averages and carry no gradient.
        """
        f_ps = y_pred[y_true == 1].reshape(-1)
        f_ns = y_pred[y_true == 0].reshape(-1)

        # Every row of mat_data holds all batch scores, positives first.
        vec_dat = torch.cat((f_ps, f_ns), 0)
        mat_data = vec_dat.repeat(len(f_ps), 1)

        f_ps = f_ps.reshape(-1, 1)
        num_pos = f_ps.size(0)

        # Column masks: the first num_pos columns are positives, the rest negatives.
        neg_mask = torch.ones_like(mat_data)
        neg_mask[:, 0:num_pos] = 0
        pos_mask = torch.zeros_like(mat_data)
        pos_mask[:, 0:num_pos] = 1

        # Squared hinge of margin - (f_pos_i - score_j) for every pair (i, j).
        hinge = torch.max(self.margin - (f_ps - mat_data), torch.zeros_like(mat_data)) ** 2
        neg_loss = hinge * neg_mask
        pos_loss = hinge * pos_mask
        loss = pos_loss + neg_loss

        # The moving averages only feed the weights p, which carry no gradient. Updating
        # them without autograd keeps the buffers from accumulating grad history that
        # chains from one call to the next.
        with torch.no_grad():
            if num_pos == 1:
                self.u_pos[index_s] = (1 - self.beta) * self.u_pos[index_s] + self.beta * (pos_loss.mean())
                self.u_all[index_s] = (1 - self.beta) * self.u_all[index_s] + self.beta * (loss.mean())
            else:
                self.u_all[index_s] = (1 - self.beta) * self.u_all[index_s] + self.beta * (loss.mean(1, keepdim=True))
                self.u_pos[index_s] = (1 - self.beta) * self.u_pos[index_s] + self.beta * (pos_loss.mean(1, keepdim=True))
            p = (self.u_pos[index_s] - (self.u_all[index_s]) * pos_mask) / (self.u_all[index_s] ** 2)

        return torch.sum(p * loss)
 
class BCELoss(torch.nn.Module):
    """
    Cross Entropy Loss with Sigmoid Function (binary cross-entropy on raw logits).

    Scores and labels are flattened to (N, 1), labels are cast to float32, and both are
    moved to ``self.device`` (CUDA when available, CPU otherwise) before calling
    ``F.binary_cross_entropy_with_logits`` with its default mean reduction.

    Reference: 
        https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
    """
    def __init__(self):
        super(BCELoss, self).__init__()
        # A hard-coded .cuda() would fail on CPU-only hosts; pick the device once here.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.criterion = F.binary_cross_entropy_with_logits  # with sigmoid

    def forward(self, y_pred, y_true):
        y_pred = y_pred.reshape(-1, 1).to(self.device)
        y_true = y_true.reshape(-1, 1).to(self.device, dtype=torch.float)
        return self.criterion(y_pred, y_true)
    
class CrossEntropyLoss(torch.nn.Module):
    """
    Cross Entropy Loss with Sigmoid Function (binary cross-entropy on raw logits).

    Despite the name, this applies ``F.binary_cross_entropy_with_logits`` element-wise
    with its default mean reduction. Inputs are not reshaped, so ``y_pred`` and
    ``y_true`` must already have the same shape. ``y_true`` is cast to float32 and moved
    to ``y_pred``'s device.

    Reference: 
        https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
    """
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.criterion = F.binary_cross_entropy_with_logits  # with sigmoid

    def forward(self, y_pred, y_true):
        # Follow y_pred's device instead of forcing CUDA: works on CPU-only hosts and
        # avoids a device mismatch when y_pred lives on the CPU.
        return self.criterion(y_pred, y_true.float().to(y_pred.device))
    
    
class FocalLoss(torch.nn.Module):
    """
    Focal Loss on raw logits (binary).

    Per element: at * (1 - pt) ** gamma * BCE, where BCE is the unreduced
    binary-cross-entropy-with-logits, pt = exp(-BCE), and ``at`` is ``alpha`` for
    target 0 and ``1 - alpha`` for target 1. The mean over all elements is returned.

    NOTE(review): with this lookup the positive class (target 1) is weighted by
    ``1 - alpha`` (0.75 at the default); confirm this is the intended convention.

    inputs:
        alpha: weight for target 0 (target 1 gets 1 - alpha)
        gamma: focusing exponent
    Reference: 
        https://amaarora.github.io/2020/06/29/FocalLoss.html
    """
    def __init__(self, alpha=.25, gamma=2):
        super(FocalLoss, self).__init__()
        # CUDA when available, otherwise CPU (a hard-coded .cuda() fails on CPU-only hosts).
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Lookup table indexed by the integer target: [weight for 0, weight for 1].
        self.alpha = torch.tensor([alpha, 1-alpha], device=device)
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # Per-element class weight, reshaped to match BCE_loss exactly: a flat (N,) weight
        # against (N, 1) logits would otherwise broadcast to an (N, N) matrix.
        alpha = self.alpha.to(inputs.device)
        at = alpha.gather(0, targets.detach().long().reshape(-1)).reshape(BCE_loss.shape)
        pt = torch.exp(-BCE_loss)
        F_loss = at*(1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

    
    
