import torch


def proj_norm(y):
    """Project ``y`` onto the probability simplex along its last dimension.

    For every vector ``v`` along the last dimension this returns
    ``argmin_x ||x - v||_2`` subject to ``x >= 0`` and ``sum(x) == 1``,
    using the sort-based algorithm (Duchi et al., 2008; Wang &
    Carreira-Perpinan, 2013):

    1. sort ``v`` descending into ``u`` and take prefix sums ``S_j``;
    2. ``rho`` is the largest 1-based ``j`` with ``u_j + (1 - S_j) / j > 0``;
    3. ``lambda = (1 - S_rho) / rho``;
    4. the result is ``max(v + lambda, 0)``.

    Args:
        y: tensor of shape ``(..., C)``; the original 2-D ``(B, C)`` layout
            is still supported. Non-floating inputs are cast to the default
            float dtype.

    Returns:
        Tensor with the same shape, device and (floating) dtype as ``y``
        whose last-dimension slices are non-negative and sum to 1.
    """
    if not torch.is_floating_point(y):
        y = y.to(torch.get_default_dtype())
    if y.numel() == 0:
        # Nothing to project; also avoids reducing over an empty last dim.
        return torch.zeros_like(y)

    num_classes = y.shape[-1]
    u, _ = torch.sort(y, dim=-1, descending=True)
    prefix = torch.cumsum(u, dim=-1)

    # 1-based positions j = 1..C, as integers (for indexing) and floats (for math).
    pos = torch.arange(1, num_classes + 1, device=y.device)
    pos_f = pos.to(y.dtype)

    support = u + (1 - prefix) / pos_f > 0
    # rho = largest j satisfying the condition (0 where it never holds). In
    # exact arithmetic j = 1 always holds; the clamp guards against round-off
    # or NaN inputs so the gather below stays in range.
    rho = torch.where(support, pos, torch.zeros_like(pos))
    rho = rho.max(dim=-1, keepdim=True).values.clamp(min=1)

    lam = (1 - prefix.gather(-1, rho - 1)) / rho.to(y.dtype)
    return torch.clamp(y + lam, min=0)


def celoss(y_hat, y, reduction='sum', eps=1e-7):
    """Cross-entropy between targets ``y`` and predicted probabilities ``y_hat``.

    Args:
        y_hat: predicted probabilities (not logits). Clipped to ``[eps, 1]``
            before the log so zero probabilities give a large finite loss
            instead of ``inf``/``nan``.
        y: target weights, broadcastable against ``y_hat``; presumably
            one-hot or soft labels over the last dim (confirm with callers).
        reduction: ``'sum'`` sums over all elements; ``'mean'`` sums over the
            last dim, then averages over the rest; ``'none'`` returns the
            per-sample losses (sum over the last dim only).
        eps: lower clip bound applied to ``y_hat``.

    Returns:
        A scalar tensor for ``'sum'``/``'mean'``, otherwise a tensor with the
        last dimension reduced away.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values
            (previously the function silently returned ``None``).
    """
    y_hat = torch.clip(y_hat, eps, 1)
    terms = -y * y_hat.log()
    if reduction == 'sum':
        return terms.sum()
    if reduction == 'mean':
        return terms.sum(dim=-1).mean()
    if reduction == 'none':
        return terms.sum(dim=-1)
    raise ValueError(
        f"unsupported reduction {reduction!r}; expected 'sum', 'mean' or 'none'"
    )