import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConfidenceRanking(nn.Module):
    """Per-class ranking loss between student and teacher logits.

    With ``d = y_s - y_t`` (element-wise, per class), the loss is the sum of:

    * a logistic term: binary cross-entropy between ``sigmoid(d / T)`` and the
      one-hot target, scaled by ``T ** 2`` and weighted by ``r1``. It pushes
      ``d`` up on the ground-truth class and down on every other class;
    * a squared term pulling ``y_s`` towards ``y_t + margin`` on the
      ground-truth class and ``y_t - margin`` elsewhere, weighted by ``r2``;
    * a hinge term ``relu(y_s - (y_t +/- margin))`` with unit weight.

    Args:
        T: temperature dividing the logit difference in the logistic term.
        r1: weight of the logistic term.
        r2: weight of the squared term.
        margin: offset added to (gt class) / subtracted from (other classes)
            the teacher logits in the squared and hinge terms.
        type: stored as ``self.type``; not read by ``forward``.
    """

    # Floor for log-probabilities, log(1e-10). Clamping in log space is the
    # intended "clamp p to [1e-10, 1 - 1e-10]"; clamping p directly loses the
    # upper bound in float32 (1 - 1e-10 rounds to 1.0, so log(1 - p) = -inf).
    _LOG_EPS = math.log(1e-10)

    def __init__(self, T, r1=1, r2=1, margin=2.0, type='log'):
        super().__init__()
        self.T = T
        self.type = type
        self.margin = margin
        self.r1 = r1
        self.r2 = r2

    def forward(self, y_s, y_t, target):
        """Compute the combined ranking loss.

        Args:
            y_s: 2-D tensor of shape (bs, n_cls); student logits (assumed).
            y_t: teacher logits, same shape as ``y_s`` (assumed).
            target: integer class indices, one per row of ``y_s``.

        Returns:
            A scalar tensor.
        """
        bs, n_cls = y_s.size()
        # Build the mask on the logits' device and dtype; a hard-coded .cuda()
        # would break CPU-only runs.
        gt = F.one_hot(target, n_cls).to(device=y_s.device, dtype=y_s.dtype)

        # Logit-based logistic loss, computed in log space for stability:
        # log(sigmoid(z)) = logsigmoid(z), log(1 - sigmoid(z)) = logsigmoid(-z).
        z = (y_s - y_t) / self.T
        log_p = F.logsigmoid(z).clamp(min=self._LOG_EPS)
        log_1mp = F.logsigmoid(-z).clamp(min=self._LOG_EPS)
        log_loss = -(gt * log_p + (1 - gt) * log_1mp) * self.T ** 2

        # +1 on the ground-truth class, -1 on every other class.
        sign = 2 * gt - 1
        shifted_t = y_t + sign * self.margin

        # Logit-based hinge loss.
        # NOTE(review): for the ground-truth class this penalizes y_s rising
        # above y_t + margin (a one-sided cap), while the logistic term pushes
        # y_s above y_t there -- confirm this direction is intended.
        hinge_loss = F.relu(y_s - shifted_t)

        # Logit-based squared loss towards y_t +/- margin.
        mse_loss = (y_s - shifted_t) ** 2

        return self.r1 * log_loss.mean() + self.r2 * mse_loss.mean() + hinge_loss.mean()


class RelationalConfidenceRanking(nn.Module):
    """Pairwise ranking loss on student vs. teacher logit gaps.

    For every sample and every ordered class pair (i, j) in which exactly one
    class is the ground truth, the student's gap ``y_s[j] - y_s[i]`` is
    compared with the teacher's gap ``y_t[j] - y_t[i]``. A logistic loss
    (scaled by ``T ** 2``) pushes the student's gap in favour of the
    ground-truth class beyond the teacher's.

    Args:
        T: temperature dividing the gap difference.
        r1: weight of the returned loss.
        r2: stored as ``self.r2``; not read by ``forward``.
        margin: stored as ``self.margin``; not read by ``forward``.
        type: stored as ``self.type``; not read by ``forward``.
    """

    # Floor for log-probabilities, log(1e-10). Clamping in log space is the
    # intended "clamp p to [1e-10, 1 - 1e-10]"; clamping p directly loses the
    # upper bound in float32 (1 - 1e-10 rounds to 1.0, so log(1 - p) = -inf).
    _LOG_EPS = math.log(1e-10)

    def __init__(self, T, r1=1, r2=1, margin=2, type='log'):
        super().__init__()
        self.T = T
        self.type = type
        self.margin = margin
        self.r1 = r1
        self.r2 = r2

    def forward(self, y_s, y_t, target):
        """Compute the pairwise logistic ranking loss.

        Args:
            y_s: 2-D tensor of shape (bs, n_cls); student logits (assumed).
            y_t: teacher logits with bs * n_cls elements (viewed like y_s).
            target: integer class indices, bs of them.

        Returns:
            A scalar tensor; 0 when there are no ranking pairs (n_cls == 1).
        """
        bs, n_cls = y_s.size()
        # Build the mask on the logits' device; a hard-coded .cuda() would
        # break CPU-only runs.
        gt = F.one_hot(target, n_cls).to(y_s.device)

        # cls_diff[b, i, j] = gt[b, j] - gt[b, i]: +1 when j is the
        # ground-truth class and i is not, -1 for the transposed pair, else 0.
        cls_diff = gt.view(bs, 1, n_cls) - gt.view(bs, n_cls, 1)

        # Pairwise gaps: *_diff[b, i, j] = y[b, j] - y[b, i].
        y_s_diff = y_s.view(bs, 1, n_cls) - y_s.view(bs, n_cls, 1)
        y_t_diff = y_t.view(bs, 1, n_cls) - y_t.view(bs, n_cls, 1)
        z = (y_s_diff - y_t_diff) / self.T

        # pn pairs: -log(sigmoid(z)), computed as -logsigmoid(z).
        pn_log_loss = -F.logsigmoid(z[cls_diff == 1]).clamp(min=self._LOG_EPS) * self.T ** 2
        # np pairs: -log(1 - sigmoid(z)) == -logsigmoid(-z). Because cls_diff
        # and z are antisymmetric in (i, j), these mirror the pn pairs.
        np_log_loss = -F.logsigmoid(-z[cls_diff == -1]).clamp(min=self._LOG_EPS) * self.T ** 2

        # Guard against 0 / 0 when no pairs exist (e.g. n_cls == 1).
        n_pairs = pn_log_loss.numel() + np_log_loss.numel()
        log_loss = (pn_log_loss.sum() + np_log_loss.sum()) / max(n_pairs, 1)
        return self.r1 * log_loss
