import torch 
import math
import torch.nn.functional as F

class pAUCloss(torch.nn.Module):
    """One-way partial-AUC style pairwise loss with per-positive thresholds.

    For every (positive, negative) score pair in the batch the logistic
    surrogate ``log(1 + exp(-(f_pos - f_neg)))`` is computed. Pair losses that
    are below the threshold ``lamb`` of their positive sample are zeroed. The
    result is combined with a ``k_value * sum(lamb) / n_neg`` term on the
    thresholds.

    Args:
        n_pos: Stored as ``self.n_pos``; not read anywhere in this class.
        n_neg: Divisor of the threshold term in ``forward``.
        k_value: Multiplier of the threshold term in ``forward``.
    """

    def __init__(self, n_pos=100, n_neg=100, k_value=100):
        super().__init__()
        self.n_pos = n_pos
        self.n_neg = n_neg
        self.k_value = k_value

    def __individualloss(self, h):
        """Return the logistic surrogate ``log(1 + exp(-h))`` elementwise."""
        # softplus(-h) is mathematically identical but does not overflow to
        # inf when h is very negative (exp(-h) overflows float32 for -h > ~88).
        return F.softplus(-h)

    def __diff_matrix(self, x, y):
        """Return the matrix ``out[i, j] = x[i] - y[j]``."""
        # Broadcasting instead of materialising two repeated copies.
        return x.reshape(-1, 1) - y.reshape(1, -1)

    def __split_pos_index(self, y):
        """Return ``torch.where`` indices of labels treated as positive (> 0.1)."""
        index = torch.where(y > 0.1)
        return index

    def __split_neg_index(self, y):
        """Return ``torch.where`` indices of labels treated as negative (< 0.1)."""
        # NOTE(review): a label exactly equal to 0.1 is in neither split.
        index = torch.where(y < 0.1)
        return index

    def forward(self, y_pred, y_true, lamb):
        """Compute the thresholded pairwise loss for one batch.

        Args:
            y_pred: Predicted scores, indexed with the same indices as
                ``y_true`` (assumed 1-D scores, one per sample — TODO confirm).
            y_true: Labels; ``> 0.1`` is positive, ``< 0.1`` is negative.
            lamb: Thresholds, indexed by the row index of each positive
                sample of this batch (``lamb[index_pos[0]]``). Presumably one
                threshold per batch position — verify against the caller.

        Returns:
            Scalar tensor ``k_value * sum(lamb_pos) / n_neg +
            sum(kept pair losses) / batch_size ** 2``.
        """
        index_pos = self.__split_pos_index(y_true)
        index_neg = self.__split_neg_index(y_true)

        y_pred_pos = y_pred[index_pos]
        y_pred_neg = y_pred[index_neg]

        # (n_pos_batch, n_neg_batch) matrix of pairwise surrogate losses.
        loss_1 = self.__individualloss(self.__diff_matrix(y_pred_pos, y_pred_neg))

        n_pos_batch = index_pos[0].shape[0]
        lamb_batch = lamb[index_pos[0]]

        # Drop pairs whose loss is below their positive's threshold.
        # NOTE(review): kept pairs contribute their full loss, not
        # ``loss - lamb``; confirm this is the intended objective.
        hinge_1 = loss_1 - lamb_batch.reshape((n_pos_batch, 1))
        # Non-mutating masking keeps the autograd graph free of in-place writes.
        loss_1 = loss_1.masked_fill(hinge_1 < 0, 0)

        # The pair sum is normalised by the full batch size squared, not by
        # n_pos_batch * n_neg_batch.
        batch = y_pred.shape[0]
        loss_total = self.k_value * torch.sum(lamb_batch) / self.n_neg + torch.sum(loss_1) / (batch ** 2)

        return loss_total
        

class pAUC_mini(torch.nn.Module):
    """Pairwise partial-AUC surrogate over a ranked band of negative scores.

    The negative scores are sorted in descending order. Only those at ranks
    ``floor(num_neg * alpha1)`` (inclusive) to ``ceil(num_neg * alpha2)``
    (exclusive) are kept. The loss is the mean pairwise surrogate between every
    positive score and every kept negative score.

    Args:
        alpha1: Lower fraction of the negative ranking (floored to a rank).
        alpha2: Upper fraction of the negative ranking (ceiled to a rank).
        num_neg: Count used to turn ``alpha1``/``alpha2`` into ranks;
            presumably the number of negatives per batch — TODO confirm.
        loss_type: Pairwise surrogate; only ``'logistic'`` is implemented.
    """

    def __init__(self, alpha1, alpha2, num_neg, loss_type = 'logistic'):
        super().__init__()
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.num_neg = num_neg
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def __individualloss(self, h):
        """Return the logistic surrogate ``log(1 + exp(-h))`` elementwise."""
        # softplus(-h) is identical but does not overflow for very negative h.
        return F.softplus(-h)

    def forward(self, f_ps, f_ns):
        """Return the mean surrogate loss over (positive, kept-negative) pairs.

        Args:
            f_ps: Positive scores (flattened with ``view``).
            f_ns: Negative scores (flattened with ``view``).

        Raises:
            ValueError: if ``loss_type`` is not ``'logistic'``.
        """
        f_ps = f_ps.view(-1, 1)
        f_ns = f_ns.view(-1)

        k1 = math.floor(self.num_neg * self.alpha1)
        k2 = math.ceil(self.num_neg * self.alpha2)
        # Keep the negatives ranked k1 .. k2-1 by descending score.
        f_ns_sort, _ = torch.sort(f_ns, descending=True)
        vec_dat = f_ns_sort[k1:k2]

        if self.loss_type == 'logistic':
            # Broadcast to (n_pos, k2 - k1): diff[i, j] = f_ps[i] - vec_dat[j].
            loss = self.__individualloss(f_ps - vec_dat.view(1, -1))
        else:
            raise ValueError(
                f"Unsupported loss_type {self.loss_type!r}; only 'logistic' is implemented"
            )

        return torch.mean(loss)

class AUCMLoss(torch.nn.Module):
    """AUC-margin min-max surrogate loss with auxiliary scalars ``a``, ``b``, ``alpha``.

    ``a``, ``b`` and ``alpha`` are plain leaf tensors with ``requires_grad``
    (not registered ``nn.Parameter``s). They therefore do not appear in
    ``self.parameters()``; presumably the training loop optimises them directly
    — TODO confirm.

    Args:
        margin: Margin term of the loss.
        imratio: Fraction of labels equal to 1. If ``None``, it is estimated
            from the first batch passed to ``forward`` and then reused for all
            later batches.
        device: Device for ``a``, ``b`` and ``alpha``. ``None`` (default)
            selects ``"cuda"`` when available and ``"cpu"`` otherwise, instead
            of failing on CPU-only machines.
    """

    def __init__(self, margin=1.0, imratio=None, device=None):
        super().__init__()
        self.margin = margin
        self.p = imratio
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.a = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        self.b = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        self.alpha = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)

    def forward(self, y_pred, y_true):
        """Compute the AUC-margin objective for one batch.

        Args:
            y_pred: Scores; reshaped to a column.
            y_true: Labels; entries equal to 1 are positive, equal to 0 negative.

        Returns:
            One-element tensor with the objective value.
        """
        if self.p is None:
            # Estimated once, from the first batch, and cached on the module.
            self.p = (y_true == 1).float().sum() / y_true.shape[0]

        y_pred = y_pred.reshape(-1, 1)
        y_true = y_true.reshape(-1, 1)
        pos_mask = (y_true == 1).float()
        neg_mask = (y_true == 0).float()
        p = self.p

        loss = (
            (1 - p) * torch.mean((y_pred - self.a) ** 2 * pos_mask)
            + p * torch.mean((y_pred - self.b) ** 2 * neg_mask)
            + 2 * self.alpha * (
                p * (1 - p) * self.margin
                + torch.mean(p * y_pred * neg_mask - (1 - p) * y_pred * pos_mask)
            )
            - p * (1 - p) * self.alpha ** 2
        )
        return loss
    

class CrossEntropyBinaryLoss(torch.nn.Module):
    """Binary cross-entropy evaluated directly on logits.

    Thin wrapper around ``F.binary_cross_entropy_with_logits``, which applies
    the sigmoid internally, called with its default reduction.
    """

    def __init__(self):
        super().__init__()
        # Kept as an attribute; forward dispatches through it.
        self.criterion = F.binary_cross_entropy_with_logits

    def forward(self, y_pred, y_true):
        """Return the BCE-with-logits loss of ``y_pred`` (logits) vs ``y_true``."""
        criterion = self.criterion
        loss = criterion(y_pred, y_true)
        return loss