import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def mae(input, target, weight=None):
    l1_loss = nn.L1Loss(reduce=False)
    loss = l1_loss(input, target)
    if weight is not None:
        loss = weight * loss
    return torch.mean(loss)


def mse(input, target, weight=None):
    l2_loss = nn.MSELoss(reduce=False)
    loss = l2_loss(input, target)
    if weight is not None:
        loss = weight * loss
    return torch.mean(loss)


def ce(input, target, weight=None):
    ce = nn.CrossEntropyLoss(reduce=False)
    loss = ce(input, target)
    if weight is not None:
        loss = loss * weight
    return torch.mean(loss)


def bce(input, target, weight=None):
    bce = nn.BCELoss(reduce=False)
    loss = bce(input, target)
    if weight is not None:
        loss = loss * weight
    return torch.mean(loss)


def FL(inputs, targets, alpha, gamma, weight_t=None):
    loss = F.binary_cross_entropy(inputs, targets, reduce=False)
    weight = torch.ones(inputs.shape, dtype=torch.float).to(inputs.device)
    weight[targets == 1] = float(alpha)
    loss_w = F.binary_cross_entropy(inputs, targets, weight=weight, reduce=False)
    pt = torch.exp(-loss)
    weight_gamma = (1 - pt) ** gamma
    if weight_t is not None:
        weight_gamma = weight_gamma * weight_t
    F_loss = torch.mean(weight_gamma * loss_w)
    return F_loss
