import torch 
import math
import torch.nn as nn
import numpy as np 
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer, required 

class AUCSquare_loss(nn.Module):
    """Pairwise squared-margin AUC surrogate computed on sigmoid scores.

    Every (positive, negative) pair contributes
    ``(margin - (sigmoid(p_pos) - sigmoid(p_neg))) ** 2``; the loss is the
    mean over all pairs.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: desired gap between positive and negative sigmoid scores.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.use_margin = use_margin
        self.margin = margin

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the mean squared pairwise margin residual.

        ``features``, ``margin`` and ``group_id`` are accepted for interface
        compatibility and ignored; ``self.margin`` is used instead.
        """
        scores = torch.sigmoid(pred)
        pos_scores = scores[label == 1].unsqueeze(1)  # one row per positive
        neg_scores = scores[label == 0].unsqueeze(0)  # one column per negative
        residual = self.margin - (pos_scores - neg_scores)
        return (residual ** 2).mean()


class IF_AUCSquare_loss(nn.Module):
    """Squared-hinge pairwise AUC loss on sigmoid scores.

    With ``use_margin=True`` only pairs that violate the margin, i.e.
    ``sigmoid(p_pos) - sigmoid(p_neg) < margin``, contribute, and the loss is
    the mean of ``(margin - diff) ** 2`` over those pairs. With
    ``use_margin=False`` every pair contributes.

    Args:
        use_margin: restrict the average to margin-violating pairs.
        margin: desired gap between positive and negative sigmoid scores.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.margin = margin
        self.use_margin = use_margin

    def forward(self, pred, label, features, margin=1, group_id=None):
        """Return the squared-hinge pairwise loss.

        ``features``, ``margin`` and ``group_id`` are unused; ``self.margin``
        is the margin applied.
        """
        pred = torch.sigmoid(pred)
        tp = pred[label == 1].unsqueeze(1)
        fp = pred[label == 0].unsqueeze(0)

        # out > 0 exactly when the pair is inside the margin (tp - fp < margin).
        out = self.margin - (tp - fp)
        if not self.use_margin:
            return torch.mean(out ** 2)

        # The previous filter `out < margin` kept correctly ranked pairs
        # (tp - fp > 0) and discarded every misranked one.
        violated = out[out > 0]
        if violated.numel() == 0:
            # No violating pair: return a graph-connected zero instead of nan.
            return out.sum() * 0.0
        return torch.mean(violated ** 2)


class AUCHinge_loss(nn.Module):
    """Pairwise hinge AUC loss on raw (pre-sigmoid) scores.

    Each (positive, negative) pair contributes
    ``relu(margin - (p_pos - p_neg))``; the result is the mean over pairs.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: hinge margin on the raw score difference.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.use_margin = use_margin
        self.margin = margin

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the mean pairwise hinge loss (extra arguments are ignored)."""
        # No sigmoid here: the hinge acts directly on the raw scores.
        positives = pred[label == 1].unsqueeze(1)
        negatives = pred[label == 0].unsqueeze(0)
        gap = positives - negatives
        hinge = torch.nn.functional.relu(self.margin - gap)
        return hinge.mean()

   
class AUCExponential_loss(nn.Module):
    """Pairwise exponential AUC surrogate on raw scores.

    Returns ``mean(exp(-(p_pos - p_neg)))`` over all (positive, negative)
    pairs.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.use_margin = use_margin
        self.margin = margin
        self.bce = nn.BCEWithLogitsLoss()  # kept as an attribute; unused by forward

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the mean pairwise exponential loss (extra arguments are ignored)."""
        gaps = pred[label == 1].unsqueeze(1) - pred[label == 0].unsqueeze(0)
        return torch.exp(-gaps).mean()

class AUCLogistic_loss(nn.Module):
    """Pairwise logistic AUC loss on raw scores.

    Returns ``mean(log(1 + exp(-(p_pos - p_neg))))`` over all
    (positive, negative) pairs.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.margin = margin
        self.use_margin = use_margin
        self.gamma = 2
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the mean pairwise logistic loss (extra arguments are ignored)."""
        tp = pred[label == 1].unsqueeze(1)
        fp = pred[label == 0].unsqueeze(0)
        # softplus(x) == log(1 + exp(x)) but does not overflow to inf for large x.
        return F.softplus(-(tp - fp)).mean()


class AUCComp_loss(nn.Module):
    """Alternating BCE / log-sum-exp AUC loss.

    Successive calls alternate. The first call (and every odd call) returns
    plain BCE-with-logits. The next call returns
    ``logsumexp(-s_p * p_pos) / s_p + logsumexp(s_n * p_neg) / s_n + BCE``.
    The toggle state lives in ``self.backend``.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.margin = margin
        self.use_margin = use_margin
        self.gamma = 2
        self.backend = 'ce'
        self.pos_scale = 1
        self.neg_scale = 1
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return either the BCE term or the AUC + BCE term, alternating per call."""
        if self.backend == 'ce':
            self.backend = 'auc_l'
            return self.bce(pred, label)

        self.backend = 'ce'
        tp = pred[label == 1].reshape(-1)
        fp = pred[label == 0].reshape(-1)

        # log(sum(exp(x))) computed with logsumexp: identical value, no overflow.
        pos_term = torch.logsumexp(-self.pos_scale * tp, dim=0) / self.pos_scale
        neg_term = torch.logsumexp(self.neg_scale * fp, dim=0) / self.neg_scale
        return pos_term + neg_term + self.bce(pred, label)

class AUCSimple_loss(nn.Module):
    """Log-sum-exp AUC surrogate on raw scores.

    Returns ``logsumexp(-s_p * p_pos) / s_p + logsumexp(s_n * p_neg) / s_n``.
    This is a smooth upper bound that pushes the lowest positive up and the
    highest negative down.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.margin = margin
        self.use_margin = use_margin

        self.pos_scale = 1
        self.neg_scale = 1

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the log-sum-exp AUC loss (extra arguments are ignored)."""
        tp = pred[label == 1].reshape(-1)
        fp = pred[label == 0].reshape(-1)
        # logsumexp == log(sum(exp(.))) without overflowing for large |pred|.
        pos_term = torch.logsumexp(-self.pos_scale * tp, dim=0) / self.pos_scale
        neg_term = torch.logsumexp(self.neg_scale * fp, dim=0) / self.neg_scale
        return pos_term + neg_term


class AUCPR_loss(nn.Module):
    """Log-sum-exp AUC loss plus a pointwise sigmoid ranking penalty.

    ``auc_loss = logsumexp(-s_p * p_pos) / s_p + logsumexp(s_n * p_neg) / s_n``
    ``rank_loss = mean(relu(margin - sigmoid(p_pos))) + mean(sigmoid(p_neg))``
    The result is ``auc_loss + lbd * rank_loss``.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: lower target for positive sigmoid scores in ``rank_loss``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.margin = margin
        self.use_margin = use_margin

        self.pos_scale = 1
        self.neg_scale = 1
        self.lbd = 0.2

        # alpha / beta are stored but not read by forward.
        self.alpha = 0.4
        self.beta = self.alpha - 0.2

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return ``auc_loss + lbd * rank_loss`` (extra arguments are ignored)."""
        tp = pred[label == 1].reshape(-1)
        fp = pred[label == 0].reshape(-1)

        # logsumexp == log(sum(exp(.))) without overflowing for large |pred|.
        auc_loss = (torch.logsumexp(-self.pos_scale * tp, dim=0) / self.pos_scale
                    + torch.logsumexp(self.neg_scale * fp, dim=0) / self.neg_scale)
        # sigmoid(fp) is already non-negative, so it needs no relu.
        rank_loss = (F.relu(self.margin - torch.sigmoid(tp)).mean()
                     + torch.sigmoid(fp).mean())
        return auc_loss + self.lbd * rank_loss
        


class AUCCircle_loss(nn.Module):
    """Circle-style AUC loss on raw scores.

    Returns ``log(1 + sum(exp(-p_pos)) * sum(exp(p_neg)))``. This is evaluated
    as ``softplus(logsumexp(-p_pos) + logsumexp(p_neg))``, which has the same
    value without overflowing.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.margin = margin
        self.use_margin = use_margin

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the circle AUC loss (extra arguments are ignored)."""
        tp = pred[label == 1].reshape(-1)
        fp = pred[label == 0].reshape(-1)
        # log(1 + A*B) == softplus(log A + log B); both logs come from logsumexp.
        return F.softplus(torch.logsumexp(-tp, dim=0) + torch.logsumexp(fp, dim=0))

class AUCUnivariate_Loss(nn.Module):
    """Product-of-means exponential AUC surrogate.

    Returns ``sum(exp(-p_pos)) * sum(exp(p_neg)) / (n_pos * n_neg)``, i.e. the
    pairwise exponential loss averaged over all (positive, negative) pairs.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.use_margin = use_margin
        self.margin = margin

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the averaged pairwise exponential loss (extra arguments ignored)."""
        positives = pred[label == 1]
        negatives = pred[label == 0]
        pos_count, neg_count = positives.size(0), negatives.size(0)

        pos_mass = torch.exp(-positives).sum()
        neg_mass = torch.exp(negatives).sum()
        return pos_mass * neg_mass / (pos_count * neg_count)


class AUCPointloss(nn.Module):
    """BCE plus a circle-style AUC term restricted to the score-overlap window.

    Scores are sorted. The AUC term uses only the samples between the
    lowest-scored positive and the highest-scored negative, inclusive:
    ``log(1 + sum(exp(pos_margin - p_pos)) * sum(exp(p_neg - neg_margin)))``.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.margin = margin
        self.use_margin = use_margin
        self.gamma = 2

        self.register_buffer('pos_margin', torch.tensor([0.2]))
        self.register_buffer('neg_margin', torch.tensor([0.]))

        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return ``AUC term + BCE``.

        Assumes ``pred`` and ``label`` are 1-D tensors of equal length (the
        sort is along dim 0 and ``label`` is gathered by the sort index).
        """
        logloss = self.bce(pred, label)

        pred_sorted, index = torch.sort(pred, dim=0)
        label_sorted = label[index]

        pos_index = torch.where(label_sorted == 1)[0]
        neg_index = torch.where(label_sorted == 0)[0]
        if pos_index.numel() == 0 or neg_index.numel() == 0:
            # No (positive, negative) pair exists: the AUC term is log(1 + 0) = 0.
            return logloss

        first_pos_index = pos_index[0]
        final_neg_index = neg_index[-1]

        # Inclusive window: the previous slice stopped one short and dropped
        # the highest-scored negative, which is the worst offender.
        window = slice(first_pos_index, final_neg_index + 1)
        pred_range = pred_sorted[window]
        label_range = label_sorted[window]

        true_pred = pred_range[label_range == 1]
        false_pred = pred_range[label_range == 0]

        tp_avg = torch.exp(self.pos_margin - true_pred).sum()
        fp_avg = torch.exp(false_pred - self.neg_margin).sum()
        return torch.log(1 + tp_avg * fp_avg) + logloss

class AUCPointlossV2(nn.Module):
    """Pointwise AUC loss: BCE with count-based pairwise reweighting.

    Each positive/negative gets a weight proportional to how many opposite-class
    samples lie within ``margin`` of it (margin-violating pairs). The weights
    are normalised by the total number of violating pairs. The loss is
    ``sum(-log sigmoid(p) - w_p * sigmoid(p)) + sum(-log(1 - sigmoid(n)) + w_n * sigmoid(n))``
    divided by the batch size.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; ``forward`` uses its own ``margin`` argument.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.margin = margin
        self.use_margin = use_margin

        self.register_buffer('pos_margin', torch.tensor([0.2]))
        self.register_buffer('neg_margin', torch.tensor([0.]))

        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the reweighted pointwise loss averaged over ``pred.size(0)``."""
        n = pred.size(0)
        sorted_tp, _ = torch.sort(pred[label == 1], dim=0)
        sorted_fp, _ = torch.sort(pred[label == 0], dim=0)

        # For each negative: number of positives scored <= negative + margin.
        fp_gt_tp_values = torch.searchsorted(sorted_tp, sorted_fp + margin, right=True)
        # For each positive: number of negatives scored < positive - margin ...
        tp_gt_fp_values = torch.searchsorted(sorted_fp, sorted_tp - margin, right=False)
        # ... hence the number of negatives scored >= positive - margin.
        tp_le_fp_values = sorted_fp.size(0) - tp_gt_fp_values

        # Total number of violating pairs. Clamp so that a batch with no
        # violations gives zero weights instead of 0/0 = nan (all counts are 0 then).
        pos_dot_neg = fp_gt_tp_values.sum().clamp(min=1)

        alpha_tp = tp_le_fp_values * n / pos_dot_neg
        alpha_fp = fp_gt_tp_values * n / pos_dot_neg

        # -log(sigmoid(x)) and -log(1 - sigmoid(x)) via logsigmoid for stability.
        tp_loss = -F.logsigmoid(sorted_tp) - torch.sigmoid(sorted_tp) * alpha_tp
        fp_loss = -F.logsigmoid(-sorted_fp) + torch.sigmoid(sorted_fp) * alpha_fp
        return (tp_loss.sum() + fp_loss.sum()) / n

class AUCPointwiseHinge_loss(nn.Module):
    """Pointwise smooth loss with per-sample weights from margin violations.

    For a negative ``n`` with weight ``a_n`` (fraction of positives scored
    <= n + margin) the term is ``log(exp(a_n * n) + exp((1 + a_n) * n))``.
    For a positive ``p`` with weight ``a_p`` (fraction of negatives scored
    >= p - margin) the term is ``log(exp(-a_p * p) + exp(-(1 + a_p) * p))``.
    The loss is the sum of all terms divided by the batch size.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; ``forward`` uses its own ``margin`` argument.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.margin = margin
        self.use_margin = use_margin

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the weighted pointwise loss.

        ``label`` is no longer modified in place. The previous version rewrote
        the caller's zeros to -1 as a side effect.
        """
        sorted_tp, _ = torch.sort(pred[label == 1], dim=0)
        sorted_fp, _ = torch.sort(pred[label == 0], dim=0)
        n_pos = sorted_tp.size(0)
        n_neg = sorted_fp.size(0)

        # Positives scored <= each negative + margin.
        fp_gt_tp_values = torch.searchsorted(sorted_tp, sorted_fp + margin, right=True)
        # Negatives scored >= each positive - margin.
        tp_gt_fp_values = torch.searchsorted(sorted_fp, sorted_tp - margin, right=False)
        tp_le_fp_values = n_neg - tp_gt_fp_values

        # max(., 1) avoids 0/0 when one class is absent (its counts are all 0 then).
        alpha_tp = tp_le_fp_values / max(n_neg, 1)
        alpha_fp = fp_gt_tp_values / max(n_pos, 1)

        # logaddexp == log(exp(a) + exp(b)) without overflow.
        fp_loss = torch.logaddexp(alpha_fp * sorted_fp, (1 + alpha_fp) * sorted_fp).sum()
        tp_loss = torch.logaddexp(-alpha_tp * sorted_tp, -(1 + alpha_tp) * sorted_tp).sum()
        return (fp_loss + tp_loss) / pred.size(0)
        

class AUCPolyloss(nn.Module):
    """BCE plus a bounded AUC term.

    With ``A = mean(exp(-p_pos))`` and ``B = mean(exp(p_neg))`` the extra term
    is ``1 - 1 / (1 + A * B)``, which lies in [0, 1).

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.use_margin = use_margin
        self.margin = margin
        self.gamma = 2

        self.register_buffer('pos_margin', torch.tensor([0.2]))
        self.register_buffer('neg_margin', torch.tensor([0.]))

        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return ``BCE + (1 - 1 / (1 + A * B))`` (extra arguments ignored)."""
        pos_mean = torch.exp(-pred[label == 1]).mean()
        neg_mean = torch.exp(pred[label == 0]).mean()

        bce_term = self.bce(pred, label)
        poly_term = 1 - 1 / (1 + pos_mean * neg_mean)
        return bce_term + poly_term


class AUCTP_at_FP_loss(nn.Module):
    """Sum of hinge losses around thresholds at several false-positive rates.

    For each rate r in (0.25, 0.5, 0.75), the threshold is the sorted negative
    probability at index ``int((1 - r) * n_neg)``. Positives below it and
    negatives above it are penalised linearly.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.use_margin = use_margin
        self.margin = margin
        self.gamma = 2

        self.register_buffer('pos_margin', torch.tensor([0.2]))
        self.register_buffer('neg_margin', torch.tensor([0.]))

        self.bce = nn.BCEWithLogitsLoss()

    def tp_at_fp_loss(self, sorted_tp, sorted_fp, fp_rate=0.25):
        """Hinge loss around the ``fp_rate`` threshold of ascending ``sorted_fp``."""
        cut = int((1 - fp_rate) * sorted_fp.size(0))
        threshold = sorted_fp[cut]
        relu = torch.nn.functional.relu
        pos_part = relu(threshold - sorted_tp).mean()
        neg_part = relu(sorted_fp - threshold).mean()
        return pos_part + neg_part

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the summed threshold losses (extra arguments are ignored)."""
        # Sort raw scores ascending, then map to probabilities.
        pos_probs = torch.sigmoid(torch.sort(pred[label == 1], dim=0)[0])
        neg_probs = torch.sigmoid(torch.sort(pred[label == 0], dim=0)[0])

        total = self.tp_at_fp_loss(pos_probs, neg_probs, 0.25)
        for rate in (0.5, 0.75):
            total = total + self.tp_at_fp_loss(pos_probs, neg_probs, rate)
        return total


        
       
class Polyloss(nn.Module):
    """Linear (poly-1) classification loss on sigmoid probabilities.

    Each element contributes ``1 - p`` when the label is 1 and ``p`` when it
    is 0, interpolating linearly for soft labels. The result is the mean.
    Constructor arguments are accepted for interface compatibility and unused.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()

    def forward(self, pred, label):
        """Return the mean linear loss for logits ``pred`` and targets ``label``."""
        prob = torch.sigmoid(pred)
        per_element = label * (1 - prob) + (1.0 - label) * prob
        return per_element.mean()


def binarize_and_smooth_labels(T, nb_classes, smoothing_const = 0.1):
    """Return label-smoothed one-hot targets for integer class labels.

    Each row holds ``1 - smoothing_const`` at the true class and
    ``smoothing_const / (nb_classes - 1)`` at every other class.
    See "Rethinking the Inception Architecture for Computer Vision", p. 6.

    Previously this used sklearn's ``label_binarize``. For ``nb_classes == 2``
    that returns a single column, so the targets broadcast incorrectly against
    2-class logits. It also always moved the result to CUDA.

    Args:
        T: tensor of class indices (any shape; flattened). Float labels such as
            0./1. are cast to integers.
        nb_classes: number of classes (must be >= 2).
        smoothing_const: probability mass spread over the non-target classes.

    Returns:
        float32 tensor of shape (T.numel(), nb_classes) on ``T``'s device.
    """
    labels = T.reshape(-1).long()
    is_target = F.one_hot(labels, num_classes=nb_classes).bool()
    off_value = smoothing_const / (nb_classes - 1)
    smoothed = torch.full(
        (labels.size(0), nb_classes), off_value, dtype=torch.float32, device=T.device
    )
    return smoothed.masked_fill(is_target, 1 - smoothing_const)


class AUCDML_Loss(nn.Module):
    """Proxy-based metric-learning loss with two learnable class proxies.

    Features and proxies are L2-normalised. Squared Euclidean distances to the
    proxies are turned into a softmax over ``-distance``. The loss is the
    cross-entropy against label-smoothed one-hot targets.

    Args:
        proxy_size: dimensionality of each proxy (must match the feature dim).
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, proxy_size, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.margin = margin
        self.use_margin = use_margin
        self.bce = nn.BCEWithLogitsLoss()
        self.proxies = nn.Parameter(torch.randn(2, proxy_size) / 8)
        self.smoothing_const = 0.1

    def forward(self, pred, label, features=None, pos_proxy=None, neg_proxy=None, margin=1, group_id=None):
        """Return the proxy metric loss.

        Only ``label`` and ``features`` are used. The other arguments are kept
        for interface compatibility. Earlier versions also computed per-class
        features, pair tensors and a BCE term, then discarded them; that dead
        work is removed.
        """
        P = F.normalize(self.proxies, p=2, dim=-1)
        X = F.normalize(features, p=2, dim=-1)
        D = torch.cdist(X, P) ** 2
        T = binarize_and_smooth_labels(label, len(P), self.smoothing_const)
        # Compared to proxy NCA, the positive proxy is included in the denominator.
        return torch.sum(-T * F.log_softmax(-D, -1), -1).mean()
        



     
class AUCSmoothHinge_loss(nn.Module):
    """Log-sum-exp smooth hinge over margin-violating sigmoid pairs.

    Pairs whose sigmoid gap ``sigmoid(p_pos) - sigmoid(p_neg)`` is below
    ``margin`` contribute ``exp(margin - gap)``. The loss is the log of their
    sum plus 1e-10, which gives log(1e-10) when no pair violates.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: gap below which a pair is penalised.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.use_margin = use_margin
        self.margin = margin

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return the smooth hinge loss (extra arguments are ignored)."""
        pos_prob = torch.sigmoid(pred[label == 1].unsqueeze(1))
        neg_prob = torch.sigmoid(pred[label == 0].unsqueeze(0))
        gap = pos_prob - neg_prob
        violating = gap[gap < self.margin]
        # 1e-10 keeps the log finite when the violating set is empty.
        return torch.log(torch.exp(self.margin - violating).sum() + 1e-10)
     

class IF_AUCExponential_loss(nn.Module):
    """Pairwise exponential AUC loss on raw logits.

    Returns ``mean(exp(-(logit_pos - logit_neg)))`` over all
    (positive, negative) pairs.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
        epsilon: stored on the module; not read by ``forward``.
        norm: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1, epsilon=1e-3, norm=1) -> None:
        super().__init__()
        self.margin     = margin
        self.use_margin = use_margin
        self.epsilon    = epsilon
        self.norm       = norm

    def forward(self, logit, label, features, margin=1, group_id=None):
        """Return the mean pairwise exponential loss.

        ``features``, ``margin`` and ``group_id`` are unused. The previous body
        referenced undefined ``tp``/``fp`` (NameError on every call). It also
        computed and discarded a norm of ``(sigmoid(logit) - label) * features``.
        """
        tp = logit[label == 1].unsqueeze(1)
        fp = logit[label == 0].unsqueeze(0)
        out = tp - fp
        return torch.mean(torch.exp(-1.0 * out))



class AUCCosloss(nn.Module):
    """Mean of ``exp(p_pos * p_neg)`` over all (positive, negative) pairs.

    Note that the pairwise interaction is a product of raw scores, not a
    difference.

    Args:
        use_margin: stored on the module; not read by ``forward``.
        margin: stored on the module; not read by ``forward``.
    """

    def __init__(self, use_margin=True, margin=0.1) -> None:
        super().__init__()
        self.use_margin = use_margin
        self.margin = margin
        self.gamma = 2

        self.register_buffer('pos_margin', torch.tensor([0.2]))
        self.register_buffer('neg_margin', torch.tensor([0.]))

        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, label, features=None, margin=1, group_id=None):
        """Return ``mean(exp(p_pos * p_neg))`` (extra arguments are ignored)."""
        pos_col = pred[label == 1].unsqueeze(1)
        neg_row = pred[label == 0].unsqueeze(0)
        return torch.exp(pos_col * neg_row).mean()





class FocalLoss(torch.nn.Module):
    """
    Focal Loss
    Reference: 
        https://amaarora.github.io/2020/06/29/FocalLoss.html

    Args:
        alpha: weight applied where ``targets == 0``; ``targets == 1`` get ``1 - alpha``.
        gamma: focusing exponent applied to ``(1 - pt)``.

    NOTE: this file defines ``FocalLoss`` twice; the later definition is the one
    bound to the name after import.
    """
    def __init__(self, alpha=.25, gamma=2):
        super(FocalLoss, self).__init__()
        # Previously hard-coded `.cuda()`, which crashed on CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.alpha = torch.tensor([alpha, 1-alpha], device=device)
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits ``inputs`` and 0/1 ``targets`` of the same shape."""
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        class_index = targets.type(torch.long).reshape(-1)
        # Per-element class weight, moved to the inputs' device and reshaped to
        # BCE_loss's shape so (N, 1) inputs don't broadcast against (N,) into (N, N).
        at = self.alpha.to(inputs.device).gather(0, class_index).reshape(BCE_loss.shape)
        pt = torch.exp(-BCE_loss)  # model probability of the true class
        F_loss = at*(1-pt)**self.gamma * BCE_loss
        return F_loss.mean()



class SimilarityLoss(nn.Module):
    """Margin loss on a pairwise similarity matrix driven by binary labels.

    For 0/1 labels the sign matrix ``label @ label.T * 2 - 1`` is +1 where
    both samples are positive and -1 otherwise; the identity is then
    subtracted from it. The loss is
    ``relu(margin - mean(+1 entries * sim) - mean(-1 entries * sim))``,
    where ``sim = sign_matrix * pred``.

    Args:
        margin: hinge margin (previously ignored and hard-coded to 0.2).
    """

    def __init__(self, margin=0.2) -> None:
        super().__init__()
        self.margin = margin

    def forward(self, pred, label):
        """Return the similarity margin loss.

        Args:
            pred: N x N similarity matrix.
            label: N x 1 (or N) float labels.
        """
        label = label.view(-1, 1)
        label_matrix = torch.mm(label, label.transpose(0, 1)) * 2 - 1
        # Build the identity on label's device (was hard-coded to .cuda()).
        label_matrix -= torch.eye(label.size(0), dtype=torch.float32, device=label.device)

        sim_matrix = label_matrix * pred

        positive_value = ((1 == label_matrix) * sim_matrix).mean()
        negative_value = ((-1 == label_matrix) * sim_matrix).mean()

        return torch.nn.functional.relu(self.margin - positive_value - negative_value)

 
class FocalLoss(torch.nn.Module):
    """
    Focal Loss
    Reference: 
        https://amaarora.github.io/2020/06/29/FocalLoss.html

    Args:
        alpha: weight applied where ``targets == 0``; ``targets == 1`` get ``1 - alpha``.
        gamma: focusing exponent applied to ``(1 - pt)``.

    NOTE: this re-defines, and therefore shadows, an earlier ``FocalLoss`` in this file.
    """
    def __init__(self, alpha=.25, gamma=2):
        super(FocalLoss, self).__init__()
        # Previously hard-coded `.cuda()`, which crashed on CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.alpha = torch.tensor([alpha, 1-alpha], device=device)
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits ``inputs`` and 0/1 ``targets`` of the same shape."""
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        class_index = targets.type(torch.long).reshape(-1)
        # Per-element class weight, moved to the inputs' device and reshaped to
        # BCE_loss's shape so (N, 1) inputs don't broadcast against (N,) into (N, N).
        at = self.alpha.to(inputs.device).gather(0, class_index).reshape(BCE_loss.shape)
        pt = torch.exp(-BCE_loss)  # model probability of the true class
        F_loss = at*(1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

class AUCMLoss(torch.nn.Module):
    """
    AUCM Loss with squared-hinge function: a novel loss function to directly optimize AUROC

    inputs:
        margin: margin term for AUCM loss, e.g., m in [0, 1]
        imratio: imbalance ratio, i.e., the ratio of number of positive samples to number of total samples
    outputs:
        loss value

    Reference:
        Yuan, Z., Yan, Y., Sonka, M. and Yang, T.,
        Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification.
        International Conference on Computer Vision (ICCV 2021)
    Link:
        https://arxiv.org/abs/2012.03173
    """
    def __init__(self, margin=1.0, imratio=None, device=None):
        super(AUCMLoss, self).__init__()
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.margin = margin
        self.p = imratio

        # a, b, alpha are plain leaf tensors with requires_grad=True, not nn.Parameters.
        # https://discuss.pytorch.org/t/valueerror-cant-optimize-a-non-leaf-tensor/21751
        def _zero_leaf():
            return torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device)

        self.a = _zero_leaf()
        self.b = _zero_leaf()
        self.alpha = _zero_leaf()

    def forward(self, y_pred, y_true):
        """Return the AUC-margin objective for scores ``y_pred`` and 0/1 labels ``y_true``."""
        if self.p is None:
            # Estimated from the first batch seen and then reused on later calls.
            self.p = (y_true == 1).float().sum() / y_true.shape[0]
        p = self.p

        # Flatten to column vectors so the masks align elementwise with the scores.
        y_pred = y_pred.reshape(-1, 1)
        y_true = y_true.reshape(-1, 1)
        pos_mask = (1 == y_true).float()
        neg_mask = (0 == y_true).float()

        pos_sq = (1 - p) * torch.mean((y_pred - self.a) ** 2 * pos_mask)
        neg_sq = p * torch.mean((y_pred - self.b) ** 2 * neg_mask)
        margin_term = 2 * self.alpha * (
            p * (1 - p) * self.margin
            + torch.mean(p * y_pred * neg_mask - (1 - p) * y_pred * pos_mask)
        )
        return pos_sq + neg_sq + margin_term - p * (1 - p) * self.alpha ** 2


class PESG(torch.optim.Optimizer):
    '''
    Proximal Epoch Stochastic Gradient (PESG) 
    Reference: 
        Yuan, Z., Yan, Y., Sonka, M. and Yang, T., 
        Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification. 
        International Conference on Computer Vision (ICCV 2021)
    Link:
        https://arxiv.org/abs/2012.03173

    Optimizes ``model.parameters()`` plus the auxiliary tensors ``a`` and ``b``.
    ``alpha`` is updated separately in ``step``. ``model_ref`` holds the
    proximal reference point and ``model_acc`` accumulates parameter values
    between calls to ``update_regularizer``.
    '''
    def __init__(self, 
                 model, 
                 a=None, 
                 b=None, 
                 alpha=None, 
                 imratio=0.1, 
                 margin=1.0, 
                 lr=0.1, 
                 gamma=500, 
                 clip_value=1.0, 
                 weight_decay=1e-5, 
                 device = None,
                 **kwargs):

        assert a is not None, 'Found no variable a!'
        assert b is not None, 'Found no variable b!'
        assert alpha is not None, 'Found no variable alpha!'

        self.p = imratio
        self.margin = margin
        self.model = model

        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:  
            self.device = device   

        self.lr = lr
        self.gamma = gamma
        self.clip_value = clip_value
        self.weight_decay = weight_decay

        self.a = a 
        self.b = b 
        self.alpha = alpha 

        self.model_ref = self.init_model_ref()
        self.model_acc = self.init_model_acc()

        self.T = 0            # steps since the last update_regularizer call
        self.step_counts = 0  # total steps taken

        def get_parameters(params):
            for p in params:
                yield p
        self.params = get_parameters(list(model.parameters())+[a,b])
        self.defaults = dict(lr=self.lr, 
                             margin=margin, 
                             gamma=gamma, 
                             p=imratio, 
                             a=self.a, 
                             b=self.b,
                             alpha=self.alpha,
                             clip_value=clip_value,
                             weight_decay=weight_decay,
                             model_ref = self.model_ref,
                             model_acc = self.model_acc
                             )

        super(PESG, self).__init__(self.params, self.defaults)

    def init_model_ref(self):
        """Create one small random (N(0, 0.01)) reference tensor per optimized variable."""
        self.model_ref = []
        for var in list(self.model.parameters())+[self.a, self.b]: 
            self.model_ref.append(torch.empty(var.shape).normal_(mean=0, std=0.01).to(self.device))
        return self.model_ref

    def init_model_acc(self):
        """Create one zero accumulator per optimized variable."""
        self.model_acc = []
        for var in list(self.model.parameters())+[self.a, self.b]: 
            self.model_acc.append(torch.zeros(var.shape, dtype=torch.float32,  device=self.device, requires_grad=False).to(self.device)) 
        return self.model_acc

    @property    
    def optim_steps(self):
        return self.step_counts

    @property
    def get_params(self):
        return list(self.model.parameters())

    def update_lr(self, lr):
        self.param_groups[0]['lr']=lr

    @torch.no_grad()
    def step(self):
        """Perform a single optimization step.

        Every parameter with a gradient takes a step of
        ``lr * (clip(grad) + (param - ref) / gamma + weight_decay * param)``,
        and its new value is added to ``model_acc``. ``alpha`` is then moved by
        ``lr * (2 * (margin + b - a) - 2 * alpha)`` and clamped to [0, 999].
        """
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            clip_value = group['clip_value']
            self.lr = group['lr']

            gamma = group['gamma']
            m = group['margin']

            model_ref = group['model_ref']
            model_acc = group['model_acc']

            a = group['a']
            b = group['b']
            alpha = group['alpha']

            # `param` (not `p`) so the loop does not shadow the imratio stored in group['p'].
            for i, param in enumerate(group['params']):
                if param.grad is None:
                    continue  
                param.data = param.data - group['lr']*( torch.clamp(param.grad.data , -clip_value, clip_value) + 1/gamma*(param.data - model_ref[i].data) + weight_decay*param.data)
                model_acc[i].data = model_acc[i].data + param.data

            alpha.data = alpha.data + group['lr']*(2*(m + b.data - a.data)-2*alpha.data)
            alpha.data  = torch.clamp(alpha.data,  0, 999)

        self.T += 1  
        self.step_counts += 1

    def zero_grad(self, set_to_none=None):
        """Clear gradients of the model and of ``a``, ``b`` and ``alpha``.

        ``set_to_none`` is forwarded to ``model.zero_grad`` when given, so
        callers using the standard ``Optimizer.zero_grad`` signature work.
        """
        if set_to_none is None:
            self.model.zero_grad()
        else:
            self.model.zero_grad(set_to_none=set_to_none)
        self.a.grad = None
        self.b.grad = None
        self.alpha.grad =None

    def update_regularizer(self, decay_factor=None):
        """Optionally decay the lr, then reset the proximal reference point.

        ``model_ref`` becomes the average of the parameter values accumulated
        since the last call, and the accumulators are zeroed. When no step has
        been taken since the last call (T == 0), ``model_ref`` is left
        unchanged. Previously it became 0/0 = nan.
        """
        if decay_factor is not None:
            self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor
            print ('Reducing learning rate to %.5f @ T=%s!'%(self.param_groups[0]['lr'], self.T))
        print ('Updating regularizer @ T=%s!'%(self.T))
        if self.T > 0:
            for i, acc in enumerate(self.model_acc):
                self.model_ref[i].data = acc.data/self.T
        for i, param in enumerate(self.model_acc):
            self.model_acc[i].data = torch.zeros(param.shape, dtype=torch.float32, device=self.device,  requires_grad=False).to(self.device)
        self.T = 0
        
  
 
class CompositionalLoss(torch.nn.Module):
    """  
        Compositional AUC Loss: a novel loss function to directly optimize AUROC
        inputs:
            margin: margin term for AUCM loss, e.g., m in [0, 1]
            imratio: imbalance ratio, i.e., the ratio of number of positive samples to number of total samples
            backend: loss returned by the first call ('ce' = BCE-with-logits,
                anything else = the AUC term); successive calls then alternate.
        outputs:
            loss  
        Reference:
            @inproceedings{
                            yuan2022compositional,
                            title={Compositional Training for End-to-End Deep AUC Maximization},
                            author={Zhuoning Yuan and Zhishuai Guo and Nitesh Chawla and Tianbao Yang},
                            booktitle={International Conference on Learning Representations},
                            year={2022},
                            url={https://openreview.net/forum?id=gPvB4pdu_Z}
                            }
    """
    def __init__(self, imratio=None,  margin=1, backend='ce', device=None):
        super(CompositionalLoss, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device   
        self.margin = margin
        self.p = imratio
        self.a = torch.ones(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 
        self.b = torch.zeros(1, dtype=torch.float32, device=self.device,  requires_grad=True).to(self.device) 
        self.alpha = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 
        self.L_AVG = F.binary_cross_entropy_with_logits  # with sigmoid
        # Honor the argument; it was previously overwritten with 'ce'.
        self.backend = backend

    def forward(self, y_pred, y_true):
        """Return BCE or the AUC-margin term, alternating on each call.

        The AUC branch applies a sigmoid to ``y_pred``. If ``imratio`` was not
        given, it is estimated once from the first AUC batch.
        """
        if len(y_pred) == 1:
            y_pred = y_pred.reshape(-1, 1)
        if len(y_true) == 1:
            y_true = y_true.reshape(-1, 1)
        if self.backend == 'ce':
            self.backend = 'auc'
            return self.L_AVG(y_pred, y_true)

        self.backend = 'ce'
        if self.p is None:
            self.p = (y_true==1).float().sum()/y_true.shape[0] 
        y_pred = torch.sigmoid(y_pred)
        self.L_AUC = (1-self.p)*torch.mean((y_pred - self.a)**2*(1==y_true).float()) + \
                     self.p*torch.mean((y_pred - self.b)**2*(0==y_true).float())   + \
                     2*self.alpha*(self.p*(1-self.p)*self.margin + \
                     torch.mean((self.p*y_pred*(0==y_true).float() - (1-self.p)*y_pred*(1==y_true).float())) )- \
                     self.p*(1-self.p)*self.alpha**2
        return self.L_AUC 



class PDSCA(torch.optim.Optimizer):
    """PDSCA optimizer for compositional (CE + AUC) training.

    Every call to :meth:`step` follows one of two update rules, selected by
    whether ``alpha`` received a gradient in the preceding backward pass:

    * ``alpha.grad is None``: SGD step with learning rate ``lr0``, followed
      (when ``beta1 != 0``) by a moving average of the weights kept in the
      per-parameter ``'weight_buffer'`` state.
    * ``alpha.grad is not None``: the update direction is smoothed by a moving
      average kept in ``'momentum_buffer'`` (when ``beta2 != 0``) and applied
      with learning rate ``lr``. Afterwards ``alpha`` is updated by
      ``lr * (2 * (margin + b - a) - 2 * alpha)`` and clamped to ``[0, 999]``.

    In both cases the update direction is the gradient clipped to
    ``[-clip_value, clip_value]``, plus a proximal term ``(p - p_ref) / gamma``,
    plus weight decay ``weight_decay * p``. The reference point ``p_ref`` is
    reset to the average of the iterates visited since the previous reset by
    :meth:`update_regularizer`.

    Presumably meant to be paired with a loss that alternates between a CE
    objective (which does not touch ``alpha``) and an AUC objective (which
    does), such as ``CompositionalLoss`` in this module.

    Args:
        model: module whose parameters are optimized.
        a, b, alpha: auxiliary AUC-loss variables; all three are required.
            ``a`` and ``b`` are optimized together with the model parameters;
            ``alpha`` is updated by the closed-form rule above.
        margin: margin used in the ``alpha`` update.
        lr: learning rate for AUC steps and for ``alpha``.
        lr0: learning rate for CE steps; defaults to ``lr``.
        gamma: the proximal term is scaled by ``1 / gamma``.
        beta1: weight of the newest iterate in the CE-step weight average.
        beta2: weight of the newest direction in the AUC-step direction average.
        clip_value: gradients are clamped to ``[-clip_value, clip_value]``.
        weight_decay: L2 penalty coefficient.
        device: device for the internal reference/accumulator tensors; when
            falsy, CUDA is used if available, otherwise CPU.

    Reference:
    @inproceedings{
                    yuan2022compositional,
                    title={Compositional Training for End-to-End Deep AUC Maximization},
                    author={Zhuoning Yuan and Zhishuai Guo and Nitesh Chawla and Tianbao Yang},
                    booktitle={International Conference on Learning Representations},
                    year={2022},
                    url={https://openreview.net/forum?id=gPvB4pdu_Z}
                    }
    """
    def __init__(self,
                 model,
                 a=None,
                 b=None,
                 alpha=None,
                 margin=1.0,
                 lr=0.1,
                 lr0=None,
                 gamma=500,
                 beta1=0.99,
                 beta2=0.999,
                 clip_value=1.0,
                 weight_decay=1e-5,
                 device='cuda',
                 **kwargs):

        # TODO: support a, b, alpha being None
        assert a is not None, 'Found no variable a!'
        assert b is not None, 'Found no variable b!'
        assert alpha is not None, 'Found no variable alpha!'

        self.margin = margin
        self.model = model

        # Only fall back to CUDA/CPU when no device was supplied; an explicit
        # device must not be replaced, and a missing one must not stay None.
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        if lr0 is None:
            lr0 = lr

        self.lr = lr
        self.lr0 = lr0
        self.gamma = gamma
        self.clip_value = clip_value
        self.weight_decay = weight_decay
        self.beta1 = beta1
        self.beta2 = beta2

        self.a = a
        self.b = b
        self.alpha = alpha

        # Proximal reference point and running sum of iterates: one tensor per
        # optimized variable, in the same order as the optimizer's params
        # (model parameters followed by a and b).
        self.model_ref = self.init_model_ref()
        self.model_acc = self.init_model_acc()

        self.T = 0      # iterations accumulated since the last update_regularizer()
        self.steps = 0  # total number of step() calls
        self.backend = 'ce'  # TODO

        if self.a is not None or self.b is not None:
            self.params = list(model.parameters()) + [self.a, self.b]
        else:
            self.params = list(model.parameters())
        self.defaults = dict(lr=self.lr,
                             lr0=self.lr0,
                             margin=margin,
                             gamma=gamma,
                             a=self.a,
                             b=self.b,
                             alpha=self.alpha,
                             clip_value=self.clip_value,
                             weight_decay=self.weight_decay,
                             beta1=self.beta1,
                             beta2=self.beta2,
                             model_ref=self.model_ref,
                             model_acc=self.model_acc)

        super(PDSCA, self).__init__(self.params, self.defaults)

    def __setstate__(self, state):
        super(PDSCA, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def init_model_ref(self):
        """Create the proximal reference tensors (small random normal values)."""
        self.model_ref = []
        for var in list(self.model.parameters()) + [self.a, self.b]:
            if var is not None:
                self.model_ref.append(torch.empty(var.shape).normal_(mean=0, std=0.01).to(self.device))
        return self.model_ref

    def init_model_acc(self):
        """Create zero-initialized accumulators for the running sum of iterates."""
        self.model_acc = []
        for var in list(self.model.parameters()) + [self.a, self.b]:
            if var is not None:
                self.model_acc.append(torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False).to(self.device))
        return self.model_acc

    @property
    def optim_steps(self):
        return self.steps

    @property
    def get_params(self):
        return list(self.model.parameters())

    def update_lr(self, lr):
        self.param_groups[0]['lr'] = lr

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            clip_value = group['clip_value']
            self.lr = group['lr']
            self.lr0 = group['lr0']
            gamma = group['gamma']
            m = group['margin']
            beta1 = group['beta1']
            beta2 = group['beta2']
            model_ref = group['model_ref']
            model_acc = group['model_acc']
            a = group['a']
            b = group['b']
            alpha = group['alpha']

            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                # Clipped gradient + proximal pull towards model_ref + weight decay.
                d_p = torch.clamp(p.grad.data, -clip_value, clip_value) + 1/gamma*(p.data - model_ref[i].data) + weight_decay*p.data
                if alpha.grad is None:  # sgd + moving p. # TODO: alpha=None mode
                    p.data = p.data - group['lr0']*d_p
                    if beta1 != 0:
                        param_state = self.state[p]
                        if 'weight_buffer' not in param_state:
                            buf = param_state['weight_buffer'] = torch.clone(p).detach()
                        else:
                            buf = param_state['weight_buffer']
                            buf.mul_(1-beta1).add_(p, alpha=beta1)
                        # Note: use buf(s) to compute the gradients w.r.t AUC loss can lead to a slight worse performance
                        p.data = buf.data
                else:  # auc + moving g. # TODO: alpha=None mode
                    if beta2 != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(1-beta2).add_(d_p, alpha=beta2)
                        d_p = buf
                    p.data = p.data - group['lr']*d_p
                # Running sum of iterates, averaged later by update_regularizer().
                model_acc[i].data = model_acc[i].data + p.data

            if alpha is not None and alpha.grad is not None:
                alpha.data = alpha.data + group['lr']*(2*(m + b.data - a.data)-2*alpha.data)
                alpha.data = torch.clamp(alpha.data, 0, 999)

        self.T += 1
        self.steps += 1
        return loss

    def zero_grad(self):
        """Clear gradients of the model, ``a``, ``b`` and ``alpha``."""
        self.model.zero_grad()
        if self.a is not None and self.b is not None:
            self.a.grad = None
            self.b.grad = None
        if self.alpha is not None:
            self.alpha.grad = None

    def update_regularizer(self, decay_factor=None):
        """Reset the proximal reference point to the average of recent iterates.

        Args:
            decay_factor: if given, ``lr`` and ``lr0`` of the first param group
                are divided by it before the reset.
        """
        if decay_factor is not None:
            self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor
            self.param_groups[0]['lr0'] = self.param_groups[0]['lr0']/decay_factor
            print ('Reducing learning rate to %.5f (%.5f) @ T=%s!'%(self.param_groups[0]['lr'], self.param_groups[0]['lr0'], self.steps))

        print ('Updating regularizer @ T=%s!'%(self.steps))
        # With no accumulated iterates there is nothing to average; dividing by
        # T == 0 would turn every reference tensor into NaN/inf.
        if self.T > 0:
            for ref, acc in zip(self.model_ref, self.model_acc):
                ref.data = acc.data/self.T
        for acc in self.model_acc:
            acc.data = torch.zeros(acc.shape, dtype=torch.float32, device=self.device, requires_grad=False).to(self.device)
        self.T = 0


class APLoss(torch.nn.Module):
    def __init__(self, data_len=None, margin=1.0, gamma=0.99, surrogate_loss='squared_hinge', device=None):
        """
        AP Loss with squared-hinge function: a novel loss function to directly optimize AUPRC.

        inputs:
            data_len: size of the per-sample moving-average buffers ``u_all``
                and ``u_pos``; every value of ``index_s`` passed to forward
                must be below it. Required despite the ``None`` default.
            margin: margin for squared hinge loss, e.g., m in [0, 1]
            gamma: factor for the moving average (weight of the newest estimate)
            surrogate_loss: surrogate function; only 'squared_hinge' is implemented
            device: device for the moving-average buffers; when falsy, CUDA is
                used if available, otherwise CPU.
        outputs:
            loss value
        Reference:
            Qi, Q., Luo, Y., Xu, Z., Ji, S. and Yang, T., 2021.
            Stochastic Optimization of Area Under Precision-Recall Curve for Deep Learning with Provable Convergence.
            arXiv preprint arXiv:2104.08736.
        Link:
            https://arxiv.org/abs/2104.08736
        Acknowledgement:
            Gang Li helps clean the code.
        """
        super(APLoss, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.u_all = torch.zeros((data_len, 1), dtype=torch.float32, device=self.device)
        self.u_pos = torch.zeros((data_len, 1), dtype=torch.float32, device=self.device)
        self.margin = margin
        self.gamma = gamma
        self.surrogate_loss = surrogate_loss

    def forward(self, y_pred, y_true, index_s):
        """Compute the AP surrogate loss for one mini-batch.

        Args:
            y_pred: prediction scores, shape (N,) or (N, 1).
            y_true: labels; entries equal to 1 are treated as positives.
            index_s: per-sample indices into the ``data_len``-sized moving
                average buffers ``u_all`` / ``u_pos``.

        Returns:
            Scalar loss. When the batch has no positives the loss is zero (still
            connected to ``y_pred``), since the AP estimator is undefined.

        Raises:
            NotImplementedError: if ``surrogate_loss`` is not 'squared_hinge'.
        """
        if self.surrogate_loss != 'squared_hinge':
            raise NotImplementedError('Unsupported surrogate_loss: {}'.format(self.surrogate_loss))

        if len(y_pred.shape) == 1:
            y_pred = y_pred.reshape(-1, 1)
        if len(y_true.shape) == 1:
            y_true = y_true.reshape(-1, 1)
        if len(index_s.shape) == 1:
            index_s = index_s.reshape(-1, 1)

        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        if f_ps.numel() == 0:
            # Without positives every term is a mean over an empty set (NaN).
            return (y_pred * 0).sum()
        index_ps = index_s[y_true == 1].reshape(-1)
        # Row k holds every score in the batch, compared against positive k.
        mat_data = y_pred.reshape(-1).repeat(len(f_ps), 1)
        pos_mask = (y_true == 1).reshape(-1)

        sur_loss = torch.max(self.margin - (f_ps - mat_data), torch.zeros_like(mat_data)) ** 2
        pos_sur_loss = sur_loss * pos_mask

        # The moving averages are statistics, not part of the model: detach the
        # new estimates so the buffers never capture (and keep alive) the
        # autograd graph of previous iterations.
        self.u_all[index_ps] = (1 - self.gamma) * self.u_all[index_ps] + self.gamma * sur_loss.mean(1, keepdim=True).detach()
        self.u_pos[index_ps] = (1 - self.gamma) * self.u_pos[index_ps] + self.gamma * pos_sur_loss.mean(1, keepdim=True).detach()

        # size of p: len(f_ps) * len(y_pred); constant weights, so gradients
        # flow only through sur_loss.
        p = (self.u_pos[index_ps] - self.u_all[index_ps] * pos_mask) / (self.u_all[index_ps] ** 2)
        p = p.detach()
        loss = torch.mean(p * sur_loss)
        return loss
    
class APLoss_SH_V1(torch.nn.Module):
    def __init__(self, data_len=None, margin=1.0,  beta=0.99, batch_size=128, device=None):
        """
        AP Loss with squared-hinge function: a novel loss function to directly optimize AUPRC.

        inputs:
            data_len: size of the per-sample moving-average buffers ``u_all``
                and ``u_pos`` (required despite the ``None`` default)
            margin: margin for squared hinge loss, e.g., m in [0, 1]
            beta: factor for the moving average (weight of the newest
                estimate), which also refers to gamma in the paper
            batch_size: unused
            device: device for the moving-average buffers; when falsy, CUDA is
                used if available, otherwise CPU.
        outputs:
            loss value
        Reference:
            Qi, Q., Luo, Y., Xu, Z., Ji, S. and Yang, T., 2021.
            Stochastic Optimization of Area Under Precision-Recall Curve for Deep Learning with Provable Convergence.
            Conference on Neural Information Processing Systems 2021 (NeurIPS2021)
        Link:
            https://arxiv.org/abs/2104.08736
        """
        super(APLoss_SH_V1, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.u_all = torch.zeros((data_len, 1), dtype=torch.float32, device=self.device)
        self.u_pos = torch.zeros((data_len, 1), dtype=torch.float32, device=self.device)
        self.margin = margin
        self.beta = beta

    def forward(self, y_pred, y_true, index_s):
        """Compute the AP surrogate loss for one mini-batch.

        Args:
            y_pred: prediction scores; ``y_pred[y_true == 1]`` must select the
                positives, so it is expected to have the same shape as ``y_true``.
            y_true: labels; entries equal to 1 are treated as positives.
            index_s: per-sample indices into the ``data_len``-sized moving
                average buffers.

        Returns:
            Scalar loss. When the batch has no positives the loss is zero (still
            connected to ``y_pred``), since the AP estimator is undefined.
        """
        index_s = index_s[y_true.squeeze() == 1]  # we only need pos indexes

        f_ps = y_pred[y_true == 1].reshape(-1)
        f_ns = y_pred[y_true == 0].reshape(-1)
        if f_ps.numel() == 0:
            # Without positives every term is a mean over an empty set (NaN).
            return (y_pred * 0).sum()

        # Every row: all positive scores first, then all negative scores.
        vec_dat = torch.cat((f_ps, f_ns), 0)
        mat_data = vec_dat.repeat(len(f_ps), 1)

        f_ps = f_ps.reshape(-1, 1)
        num_pos = f_ps.size(0)

        neg_mask = torch.ones_like(mat_data)
        neg_mask[:, 0:num_pos] = 0

        pos_mask = torch.zeros_like(mat_data)
        pos_mask[:, 0:num_pos] = 1

        sq_hinge = torch.max(self.margin - (f_ps - mat_data), torch.zeros_like(mat_data)) ** 2
        neg_loss = sq_hinge * neg_mask
        pos_loss = sq_hinge * pos_mask
        loss = pos_loss + neg_loss

        # Detach the new estimates so the buffers never capture (and keep
        # alive) the autograd graph of previous iterations.
        if num_pos == 1:
            self.u_pos[index_s] = (1 - self.beta) * self.u_pos[index_s] + self.beta * pos_loss.mean().detach()
            self.u_all[index_s] = (1 - self.beta) * self.u_all[index_s] + self.beta * loss.mean().detach()
        else:
            self.u_all[index_s] = (1 - self.beta) * self.u_all[index_s] + self.beta * loss.mean(1, keepdim=True).detach()
            self.u_pos[index_s] = (1 - self.beta) * self.u_pos[index_s] + self.beta * pos_loss.mean(1, keepdim=True).detach()

        # Constant per-pair weights; gradients flow only through `loss`.
        p = (self.u_pos[index_s] - self.u_all[index_s] * pos_mask) / (self.u_all[index_s] ** 2)
        p = p.detach()
        loss = torch.mean(p * loss)
        return loss

class SOAP(torch.optim.Optimizer):
    r"""SGD / Adam optimizer wrapper (combines SOAP_SGD and SOAP_ADAM).

    ``mode='sgd'`` performs SGD with optional weight decay, momentum,
    dampening and Nesterov momentum. ``mode='adam'`` performs Adam (optionally
    AMSGrad) with L2 weight decay added to the gradient.

    Args:
        params: an ``nn.Module`` (its ``parameters()`` are used) or an iterable
            of parameters / parameter-group dicts.
        lr: learning rate (required).
        weight_decay: L2 penalty added to the gradient.
        mode: ``'sgd'`` or ``'adam'`` (case-insensitive).
        momentum, dampening, nesterov: SGD-only options.
        betas, eps, amsgrad: Adam-only options.

    Raises:
        ValueError: on an invalid hyper-parameter or an unsupported ``mode``.
    """
    def __init__(self, params, lr=required, weight_decay=0, mode='sgd',
                 momentum=0, dampening=0, nesterov=False,  # sgd
                 betas=(0.9, 0.999), eps=1e-8, amsgrad=False,  # adam
                 ):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not isinstance(mode, str):
            raise ValueError("Invalid mode type: {}".format(mode))
        # step() only implements these two modes; reject anything else up
        # front instead of building an optimizer that never updates.
        if mode.lower() not in ('sgd', 'adam'):
            raise ValueError("Invalid mode: {}".format(mode))

        # Accept a module directly as well as an iterable of parameters/groups.
        try:
            params = params.parameters()
        except AttributeError:
            pass

        self.lr = lr
        self.mode = mode.lower()
        defaults = dict(lr=lr, weight_decay=weight_decay,
                        momentum=momentum, dampening=dampening, nesterov=nesterov,
                        betas=betas, eps=eps, amsgrad=amsgrad)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        super(SOAP, self).__init__(params, defaults)

    def __getstate__(self):
        # Optimizer.__getstate__ keeps only defaults/state/param_groups; carry
        # the mode (and lr) along so a pickled or deep-copied SOAP can step.
        state = super(SOAP, self).__getstate__()
        state['mode'] = self.mode
        state['lr'] = self.lr
        return state

    def __setstate__(self, state):
        super(SOAP, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.lr = group['lr']
            if self.mode == 'sgd':
                weight_decay = group['weight_decay']
                momentum = group['momentum']
                dampening = group['dampening']
                nesterov = group['nesterov']
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad
                    if weight_decay != 0:
                        d_p = d_p.add(p, alpha=weight_decay)  # d_p = d_p + p * weight_decay
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(momentum).add_(d_p, alpha=1 - dampening)  # v = v * momentum + (1 - dampening) * d_p
                        if nesterov:
                            d_p = d_p.add(buf, alpha=momentum)
                        else:
                            d_p = buf
                    p.add_(d_p, alpha=-group['lr'])

            elif self.mode == 'adam':
                for p in group['params']:
                    if p.grad is None:
                        continue
                    grad = p.grad
                    if grad.is_sparse:
                        raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                    amsgrad = group['amsgrad']
                    state = self.state[p]

                    # State initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if amsgrad:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                    if amsgrad:
                        max_exp_avg_sq = state['max_exp_avg_sq']
                    beta1, beta2 = group['betas']

                    state['step'] += 1
                    bias_correction1 = 1 - beta1 ** state['step']
                    bias_correction2 = 1 - beta2 ** state['step']

                    if group['weight_decay'] != 0:
                        grad = grad.add(p, alpha=group['weight_decay'])

                    # Decay the first and second moment running average coefficient
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    if amsgrad:
                        # Maintains the maximum of all 2nd moment running avg. till now
                        torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                        # Use the max. for normalizing running avg. of gradient
                        denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                    else:
                        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                    step_size = group['lr'] / bias_correction1

                    p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
