import torch
import torch.nn as nn


def hinge_loss(t, y_j, y_k):
    """Return the element-wise hinge ``max(|y_j - y_k| - t, 0)``.

    Args:
        t: Tensor subtracted from the absolute difference. ``ranking_loss``
            passes the predicted score gap here.
        y_j: Tensor broadcastable with ``t``.
        y_k: Tensor broadcastable with ``t`` and ``y_j``.

    Returns:
        Tensor of non-negative penalties (broadcast shape of the inputs).
    """
    # Build the zero floor on t's own device/dtype rather than forcing
    # `.cuda()`, so the loss also works for CPU tensors / CPU-only machines.
    zero_mat = torch.zeros_like(t)
    dis = torch.abs(y_j - y_k)
    return torch.max(dis - t, zero_mat)

def ranking_loss(y_hat, y, pi, t):
    """Weighted pairwise hinge ranking loss over the columns of ``y_hat``.

    For every column pair ``(j, k)`` with ``j < k``, a row contributes only
    when the true values differ by more than ``t``. The column with the
    larger true value should score higher by at least the true gap. Any
    shortfall is penalised through ``hinge_loss`` and scaled by
    ``1 / pi[winner, loser]``.

    Args:
        y_hat: Predicted scores of shape ``(b, q)``.
        y: Targets, indexed like ``y_hat`` (presumably ``(b, q)``).
        pi: Pair weights indexed as ``pi[j, k]``. Assumed to be a ``(q, q)``
            table of non-zero values; TODO confirm against callers.
        t: Threshold on ``|y_j - y_k|``. Pairs whose label gap does not
            exceed it are ignored.

    Returns:
        Scalar tensor: the summed loss divided by ``b`` and by the number of
        unordered column pairs ``q * (q - 1) / 2``.
    """
    b, q = y_hat.shape
    if q < 2:
        # No column pairs exist. Without this guard, `rank_loss` would stay a
        # Python float (no `.sum()`) and the normaliser would divide by zero.
        return y_hat.new_zeros(())
    rank_loss = 0.0
    for j in range(q - 1):
        for k in range(j + 1, q):
            y_j, y_k = y[:, j], y[:, k]
            # Gate on the *magnitude of the label gap*. The previous code
            # compared abs(y_j > y_k) -- a 0/1 boolean -- against t.
            gate = torch.abs(y_j - y_k) > t
            gap_jk = y_hat[:, j] - y_hat[:, k]
            # Case y_j > y_k: column j should outscore column k.
            rank_loss += gate * (1 / pi[j, k]) * (y_j > y_k).mul(
                hinge_loss(gap_jk, y_j, y_k))
            # Case y_j < y_k: column k should outscore column j.
            rank_loss += gate * (1 / pi[k, j]) * (y_j < y_k).mul(
                hinge_loss(-gap_jk, y_j, y_k))
    return (rank_loss.sum() / b) / (q * (q - 1)) * 2