import torch
import torch.nn.functional as F

from utility.args import Args

# Command-line options read by the loss functions in this module (via Args).
Args.add_argument("--label_smoothing", type=float, help="Smoothing for smooth_crossentropy")
# Read by equal_crossentropy (multi-head branch) as Args.ECE_label_smoothing.
Args.add_argument("--ECE_label_smoothing", type=float, help="Smoothing for equal_crossentropy")
# Read by getEqualLoss. No default is passed here, so presumably Args.equalLoss is None
# unless the flag is given — NOTE(review): confirm against utility.args.
Args.add_argument("--equalLoss", type=str, choices=['cross', 'punishHit', 'inverted_cross'], help="Which 'equal loss' (regularizing loss) to choose.")

def crossentropy(pred, targets, smoothing=0):
    """Label-smoothed cross-entropy, computed as KL(target || softmax(pred)).

    The target distribution puts ``1 - smoothing`` on the true class and spreads
    ``smoothing`` evenly over the other ``n_class - 1`` classes. With
    ``smoothing == 0`` this equals the ordinary cross-entropy; otherwise it differs
    from the cross-entropy against the smoothed target by that target's entropy,
    which does not depend on ``pred``.

    Args:
        pred: logits, with classes along dim 1.
        targets: integer class indices, one per row of ``pred`` (must be int64,
            as required by ``Tensor.scatter_``).
        smoothing: probability mass moved off the true class.

    Returns:
        Tensor of losses summed over the last dimension (one value per row for 2-D ``pred``).
    """
    n_class = pred.size(1)
    if n_class > 1:
        off_value = smoothing / (n_class - 1)
        on_value = 1.0 - smoothing
    else:
        # A single class has nowhere to move mass to: avoid dividing by zero and
        # use a plain one-hot target (the loss is then 0).
        off_value, on_value = 0.0, 1.0

    one_hot = torch.full_like(pred, fill_value=off_value)
    one_hot.scatter_(dim=1, index=targets.unsqueeze(1), value=on_value)
    log_prob = F.log_softmax(pred, dim=1)

    return F.kl_div(input=log_prob, target=one_hot, reduction='none').sum(-1)

def punishHitLoss(pred, subTargets, **kwargs):
    """Penalize probability mass assigned to the given sub-target of each row.

    For row ``i`` this returns ``-log(1 + 1e-5 - p_i)``, where ``p_i`` is the
    softmax probability (over dim 1) of class ``subTargets[i]``. The small
    epsilon keeps the log finite even when ``p_i == 1``. Extra keyword
    arguments are accepted and ignored.
    """
    epsilon = 1e-5
    probabilities = F.softmax(pred, dim=1)
    rows = torch.arange(subTargets.shape[0])
    hit_prob = probabilities[rows, subTargets]
    return -torch.log(1 + epsilon - hit_prob)

def equal_crossentropy(pred, targets, **kwargs):
    """Regularizing loss comparing ``softmax(pred)`` against a spread-out target.

    Single-head (``Args.singleHead``): returns
    ``-sum(log_softmax(pred, dim=1), dim=-1) / Args.subLabels``. When ``pred``
    has exactly ``Args.subLabels`` columns, this is the cross-entropy against a
    uniform target.

    Multi-head: target ``t`` owns the columns
    ``[t * subLabels, (t + 1) * subLabels)``. Let ``s = Args.ECE_label_smoothing``.
    The target distribution puts ``(1 - s) / subLabels`` on each of those columns
    and ``s / (n_columns - subLabels)`` on every other column. The returned value
    is ``KL(target || softmax(pred))`` summed over the last dimension.

    ``kwargs`` is accepted and ignored (presumably so every equal-loss function
    can be called the same way via getEqualLoss).
    """
    if Args.singleHead:
        return -torch.sum(F.log_softmax(pred, dim=1), dim=-1) / Args.subLabels

    n_sub = Args.subLabels
    smoothing = Args.ECE_label_smoothing
    n_other = pred.size(1) - n_sub

    # (B, 1) + (n_sub,) broadcasts to (B, n_sub): the column indices of each
    # target's sub-labels. The int64 arange is created directly on targets'
    # device (no host->device copy). It also promotes int32 targets, so the
    # index satisfies scatter_'s int64 requirement.
    offsets = torch.arange(n_sub, dtype=torch.int64, device=targets.device)
    indices = targets.unsqueeze(1) * n_sub + offsets

    if n_other > 0:
        off_value = smoothing / n_other
        on_value = (1 - smoothing) / n_sub
    else:
        # No columns outside the target's block: nothing to smooth onto, and
        # dividing by n_other would raise ZeroDivisionError.
        off_value, on_value = 0.0, 1 / n_sub

    one_hot = torch.full_like(pred, fill_value=off_value)
    one_hot.scatter_(dim=1, index=indices, value=on_value)
    log_prob = F.log_softmax(pred, dim=1)

    return F.kl_div(input=log_prob, target=one_hot, reduction='none').sum(-1)

def _log_one_minus_softmax(logits):
    """Return ``log(1 - softmax(logits, dim=1))`` without forming ``1 - p``.

    Uses ``log(1 - p_i) = logsumexp_{j != i}(x_j) - logsumexp_j(x_j)``. The
    leave-one-out logsumexp is assembled from prefix and suffix ``logcumsumexp``
    scans along dim 1. The result therefore stays finite (with usable gradients)
    even when softmax saturates to exactly 1.0 in floating point.
    """
    n = logits.size(1)
    prefix = torch.logcumsumexp(logits, dim=1)
    suffix = torch.logcumsumexp(logits.flip(1), dim=1).flip(1)
    empty = torch.full_like(logits.narrow(1, 0, 1), float("-inf"))
    # before[:, i] = logsumexp(x[:, :i]), after[:, i] = logsumexp(x[:, i+1:]);
    # -inf stands for an empty range.
    before = torch.cat([empty, prefix.narrow(1, 0, n - 1)], dim=1)
    after = torch.cat([suffix.narrow(1, 1, n - 1), empty], dim=1)
    others = torch.logaddexp(before, after)
    return others - torch.logsumexp(logits, dim=1, keepdim=True)


def inverted_equal_crossentropy(pred, targets, **kwargs):
    """Return ``-sum(log(1 - softmax(pred, dim=1)), dim=-1)`` (single-head only).

    The log terms are computed in a numerically stable way (see
    ``_log_one_minus_softmax``). Computing ``torch.log(1 - softmax(pred))``
    directly turns into ``log(0) = -inf`` once a probability rounds to 1.0,
    which gives an infinite loss and NaN gradients. ``targets`` and ``kwargs``
    are accepted and ignored.

    Raises:
        NotImplementedError: when ``Args.singleHead`` is false.
    """
    if Args.singleHead:
        return -torch.sum(_log_one_minus_softmax(pred), dim=-1)
    raise NotImplementedError("EqualLoss 'inverted_cross' not Implemented for multiHead")

def getEqualLoss():
    match Args.equalLoss:
        case "punishHit":
            equalLoss = punishHitLoss
        case "cross":
            equalLoss = equal_crossentropy
        case "inverted_cross":
            equalLoss = inverted_equal_crossentropy

    return equalLoss
