from math import sqrt
import numpy as np
import torch as tc
import torch.nn as nn
from src.utils.printer import dprint

def dlr(x, y, eps=1e-12):
    """Compute the per-sample Difference-of-Logits-Ratio (DLR) loss.

    For each row of ``x`` this returns::

        -(z_y - z_other) / (z_(1) - z_(3) + eps)

    where ``z_y`` is the score at the target index, ``z_(k)`` is the k-th
    largest score in the row, and ``z_other`` is the best competing score:
    ``z_(2)`` when the top slot after sorting belongs to the target,
    ``z_(1)`` otherwise.

    Args:
        x: 2-D tensor of scores, one row per sample. Must have at least three
            columns because the denominator uses the third-largest score
            (fewer columns raise ``IndexError``).
        y: 1-D integer tensor of target indices, one per row of ``x``.
        eps: Small constant added to the denominator to avoid division by
            zero when the top three scores coincide. Defaults to ``1e-12``.

    Returns:
        1-D tensor with one loss value per row of ``x``.
    """
    x_sorted, ind_sorted = x.sort(dim=1)
    # Build the row index with torch on x's device rather than via NumPy, so
    # the advanced indexing below stays entirely within torch.
    rows = tc.arange(x.shape[0], device=x.device)
    z_y = x[rows, y]
    # If the target occupies the top slot after sorting, the strongest
    # competitor is the runner-up; otherwise it is the top score itself.
    target_is_top = ind_sorted[:, -1] == y
    z_other = tc.where(target_is_top, x_sorted[:, -2], x_sorted[:, -1])
    return -(z_y - z_other) / (x_sorted[:, -1] - x_sorted[:, -3] + eps)


class MixUpLoss(nn.Module):
    """Cross-entropy loss that optionally blends two target sets.

    ``forward(x, y)`` accepts either:

    * a plain target, passed straight to ``nn.CrossEntropyLoss``; or
    * a tuple ``(y1, y2, lam)``, in which case the result is
      ``CE(x, y1) * lam + CE(x, y2) * (1.0 - lam)``.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, y):
        # Plain (non-tuple) targets: ordinary cross-entropy.
        if not isinstance(y, tuple):
            return self.criterion(x, y)

        first_target, second_target, lam = y
        first_loss = self.criterion(x, first_target)
        second_loss = self.criterion(x, second_target)
        # Weighted combination of the two cross-entropy terms.
        return first_loss * lam + second_loss * (1.0 - lam)




