import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import Union


class Loss(nn.Module):
    """Configurable training loss for binary prediction maps.

    ``preds`` passed to :meth:`forward` are treated as raw logits: they are
    fed directly to the logits-based helpers, while the Tversky term in the
    ``"Hybrid"`` loss receives ``preds.sigmoid()``.

    Args:
        loss_function: ``"WCE"`` for :func:`weighted_bce_loss`, or
            ``"Hybrid"`` for :func:`focal_tversky_loss_plus_plus` on sigmoid
            probabilities plus ``focal_weight`` times :func:`focal_loss` on
            the logits.
    """

    # Weight of the focal term in the "Hybrid" loss. Kept as a class attribute
    # (not an __init__ argument) so previously pickled modules still work;
    # override on a subclass or instance to retune.
    focal_weight = 0.0001

    def __init__(self, loss_function="WCE"):
        super(Loss, self).__init__()
        self.name = loss_function

    def forward(self, preds, labels):
        """Compute the configured loss of ``preds`` (logits) against ``labels``.

        Raises:
            NameError: If ``self.name`` is neither ``"WCE"`` nor ``"Hybrid"``
                (NameError is kept for compatibility with existing callers).
        """
        if self.name == "WCE":
            return weighted_bce_loss(preds, labels)
        if self.name == "Hybrid":
            tversky_term = focal_tversky_loss_plus_plus(preds.sigmoid(), labels)
            return tversky_term + self.focal_weight * focal_loss(preds, labels)
        raise NameError(
            f"Unknown loss function {self.name!r}; expected 'WCE' or 'Hybrid'"
        )

def weighted_bce_loss(preds, labels):
    """Class-balanced binary cross-entropy on logits, summed over all elements.

    Let ``neg_fraction = 1 - mean(labels)``. Each element with label 1 is
    weighted by ``neg_fraction`` and each element with label 0 by
    ``1 - neg_fraction``. With 0/1 labels this gives the rarer class the
    larger per-element weight.

    Args:
        preds: Raw logits.
        labels: Float targets; must match ``preds`` in size, as required by
            ``F.binary_cross_entropy_with_logits``.

    Returns:
        Scalar tensor: the weighted BCE summed over every element.
    """
    neg_fraction = 1 - labels.mean()
    # Linear in the label: label 1 -> neg_fraction, label 0 -> 1 - neg_fraction.
    pixel_weights = 1 - neg_fraction + (2 * neg_fraction - 1) * labels
    return F.binary_cross_entropy_with_logits(
        preds, labels, weight=pixel_weights, reduction="sum"
    )


def focal_tversky_loss_plus_plus(
        preds, labels, gamma: float = 2, beta: float = 0.7, delta: float = 0.75
):
    """Focal-Tversky-style loss, computed per sample and summed over the batch.

    For every sample (first dimension), reducing over all remaining dims::

        tp    = sum(p * y)
        fp    = sum((p * (1 - y)) ** gamma)
        fn    = sum(((1 - p) * y) ** gamma)
        ratio = (tp + (1 - beta) * fp + beta * fn + eps) / (tp + eps)
        loss  = min(ratio ** delta, 50)

    ``ratio`` is the *reciprocal* of a Tversky index. For ``p``, ``y`` and
    ``beta`` in [0, 1] it is >= 1, and it equals ~1 for a perfect prediction.
    It can blow up when ``tp`` is near zero, which is why each per-sample
    term is capped at 50 (capped terms get zero gradient). ``beta`` weights
    false negatives and ``1 - beta`` weights false positives.

    Args:
        preds: Probabilities (expected in [0, 1], e.g. sigmoid outputs),
            batch first, at least 2-D.
        labels: Targets, batch first, broadcastable against ``preds``.
        gamma: Exponent applied elementwise to the FP / FN terms.
        beta: Weight of the FN term (FP gets ``1 - beta``).
        delta: Exponent applied to the per-sample ratio.

    Returns:
        Scalar tensor: sum of the per-sample terms (0 for an empty batch).
    """
    epsilon = 1e-7

    def _per_sample_sum(t):
        # Collapse every non-batch dim, then sum each sample's elements.
        return t.flatten(start_dim=1).sum(dim=1)

    # Vectorized over the batch: no Python loop and no per-sample .item()
    # host sync.
    tp = _per_sample_sum(preds * labels)
    fp = _per_sample_sum((preds * (1 - labels)) ** gamma)
    fn = _per_sample_sum(((1 - preds) * labels) ** gamma)
    tversky = (tp + (1 - beta) * fp + beta * fn + epsilon) / (tp + epsilon)
    # Unconditional clamp matches the old "clamp only if > 50" in value and
    # gradient.
    per_sample = torch.clamp(torch.pow(tversky, delta), max=50.0)
    return per_sample.sum()


def focal_loss(
        preds, labels, alpha: float = 0.25, gamma: float = 2, reduction: str = "sum"
):
    """Binary focal loss on logits with a scalar ``alpha``.

    Per element::

        bce  = BCEWithLogits(preds, labels)
        pt   = exp(-bce)
        loss = alpha * (1 - pt) ** gamma * bce

    The result is then reduced according to ``reduction``. The modulation must
    be applied per element *before* reducing. Reducing BCE first (as an earlier
    version did) makes ``pt = exp(-total_bce)`` ~ 0 for any non-trivial input,
    which collapses the loss to ``alpha * bce`` and removes the focusing effect.

    Args:
        preds: Raw logits.
        labels: Float targets, same size as ``preds``.
        alpha: Scalar weight applied to every element.
        gamma: Focusing exponent; larger values down-weight easy elements more.
        reduction: ``"none"``, ``"mean"`` or ``"sum"``.

    Returns:
        Elementwise loss tensor for ``"none"``, otherwise a scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"{reduction} is not a valid value for reduction")

    bce = F.binary_cross_entropy_with_logits(preds, labels, reduction="none")
    # For 0/1 labels, exp(-bce) is the predicted probability of the true class.
    pt = torch.exp(-bce)
    loss = alpha * (1 - pt) ** gamma * bce

    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    return loss
