import torch
import torch.nn.functional as F


def kl_divergence(alpha, num_classes, device):
    """Return KL( Dir(alpha) || Dir(1, ..., 1) ) row by row.

    Args:
        alpha: Dirichlet concentration parameters with classes on dim 1
            (assumed 2-D ``(batch, num_classes)`` — TODO confirm callers).
        num_classes: Length of the all-ones reference concentration vector;
            it is broadcast against ``alpha``, so it should equal
            ``alpha.shape[1]``.
        device: Device on which the all-ones reference vector is created.

    Returns:
        Tensor with dim 1 reduced (``keepdim=True``): one KL value per row.
    """
    uniform = torch.ones([1, num_classes], dtype=torch.float32, device=device)
    strength = torch.sum(alpha, dim=1, keepdim=True)

    # Log-normaliser part, ln B(1) - ln B(alpha), spelled out with lgamma.
    # The operations are applied in the same left-to-right order as the
    # closed form: lgamma(S) - sum lgamma(alpha) + sum lgamma(1) - lgamma(K).
    log_norm = torch.lgamma(strength) - torch.lgamma(alpha).sum(dim=1, keepdim=True)
    log_norm = log_norm + torch.lgamma(uniform).sum(dim=1, keepdim=True)
    log_norm = log_norm - torch.lgamma(uniform.sum(dim=1, keepdim=True))

    # Expectation part: sum_k (alpha_k - 1) * (digamma(alpha_k) - digamma(S)).
    digamma_gap = torch.digamma(alpha) - torch.digamma(strength)
    expectation = ((alpha - uniform) * digamma_gap).sum(dim=1, keepdim=True)

    return log_norm + expectation


def loglikelihood_loss(y, alpha, device):
    """Expected squared error between ``y`` and p ~ Dir(alpha), per row.

    Decomposes as squared error against the mean ``alpha / S`` plus the
    per-class variance ``alpha * (S - alpha) / (S^2 * (S + 1))``, both
    summed over dim 1 (``keepdim=True``).

    Args:
        y: Target values compared element-wise with ``alpha / S``.
        alpha: Dirichlet concentration parameters with classes on dim 1.
        device: Device both inputs are moved to before computing.

    Returns:
        Tensor with dim 1 reduced (``keepdim=True``).
    """
    y, alpha = y.to(device), alpha.to(device)
    strength = alpha.sum(dim=1, keepdim=True)
    expected_p = alpha / strength

    squared_err = ((y - expected_p) ** 2).sum(dim=1, keepdim=True)
    variance = (
        alpha * (strength - alpha) / (strength * strength * (strength + 1))
    ).sum(dim=1, keepdim=True)
    return squared_err + variance


def mse_loss(y, alpha, epoch_num, num_classes, annealing_step, device=None, useKL=True):
    """Squared-error EDL loss with an optionally annealed KL regulariser.

    Args:
        y: Targets compared against ``alpha / S``; the KL term treats entries
            equal to 1 as the true class (one-hot targets assumed).
        alpha: Dirichlet concentration parameters with classes on dim 1.
        epoch_num: Current epoch, used for the KL warm-up.
        num_classes: Number of classes (forwarded to ``kl_divergence``).
        annealing_step: Number of epochs over which the KL weight ramps
            linearly from 0 up to 1. A non-positive value disables the
            warm-up (weight 1.0) instead of dividing by zero.
        device: Device to compute on. ``None`` means "use ``alpha``'s
            device", so the KL reference vector lands on the same device.
        useKL: If False, return only the squared-error term.

    Returns:
        Per-row loss with dim 1 reduced (``keepdim=True``).
    """
    if device is None:
        # Without this, kl_divergence would create its ones-vector on the
        # default device, which breaks when alpha lives on an accelerator.
        device = alpha.device
    y = y.to(device)
    alpha = alpha.to(device)
    loglikelihood = loglikelihood_loss(y, alpha, device=device)

    if not useKL:
        return loglikelihood

    if annealing_step > 0:
        annealing_coef = min(1.0, epoch_num / annealing_step)
    else:
        annealing_coef = 1.0

    # Where y == 1 the concentration becomes 1 (evidence removed); where
    # y == 0 alpha is kept, so the KL term only sees non-target evidence.
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes, device=device)
    return loglikelihood + kl_div


def edl_loss(func, y, alpha, epoch_num, num_classes, annealing_step, device, useKL=True):
    """Generic EDL loss ``sum_k y_k * (func(S) - func(alpha_k))`` plus annealed KL.

    ``edl_log_loss`` and ``edl_digamma_loss`` in this module call this with
    ``torch.log`` and ``torch.digamma`` respectively. With ``torch.digamma``
    the data term equals the expected cross-entropy E[-log p_k] under
    p ~ Dir(alpha).

    Args:
        func: Element-wise tensor function applied to ``S`` and ``alpha``.
        y: Targets used as per-class weights; entries equal to 1 are treated
            as the true class by the KL term (one-hot targets assumed).
        alpha: Dirichlet concentration parameters with classes on dim 1.
        epoch_num: Current epoch, used for the KL warm-up.
        num_classes: Number of classes (forwarded to ``kl_divergence``).
        annealing_step: Number of epochs over which the KL weight ramps
            linearly from 0 up to 1. A non-positive value disables the
            warm-up (weight 1.0) instead of dividing by zero.
        device: Device both inputs are moved to.
        useKL: If False, return only the data term.

    Returns:
        Per-row loss with dim 1 reduced (``keepdim=True``).
    """
    y = y.to(device)
    alpha = alpha.to(device)
    strength = torch.sum(alpha, dim=1, keepdim=True)

    data_term = torch.sum(y * (func(strength) - func(alpha)), dim=1, keepdim=True)

    if not useKL:
        return data_term

    if annealing_step > 0:
        annealing_coef = min(1.0, epoch_num / annealing_step)
    else:
        annealing_coef = 1.0

    # Where y == 1 the concentration becomes 1 (evidence removed); where
    # y == 0 alpha is kept, so the KL term only sees non-target evidence.
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes, device=device)
    return data_term + kl_div


def edl_mse_loss(alpha, target, epoch_num, num_classes, annealing_step, device):
    """Batch mean of ``mse_loss``.

    Note the argument order: this wrapper takes ``(alpha, target, ...)``
    while ``mse_loss`` takes ``(y, alpha, ...)``.
    """
    per_row = mse_loss(
        target,
        alpha,
        epoch_num,
        num_classes,
        annealing_step,
        device=device,
    )
    return per_row.mean()


def edl_log_loss(alpha, target, epoch_num, num_classes, annealing_step, device):
    """Batch mean of ``edl_loss`` using ``torch.log`` as the transform."""
    per_row = edl_loss(
        torch.log, target, alpha, epoch_num, num_classes, annealing_step, device
    )
    return per_row.mean()


def edl_digamma_loss(alpha, target, epoch_num, num_classes, annealing_step, device):
    """Batch mean of ``edl_loss`` using ``torch.digamma`` as the transform."""
    per_row = edl_loss(
        torch.digamma, target, alpha, epoch_num, num_classes, annealing_step, device
    )
    return per_row.mean()



def loss_uncertainty(alpha):
    """Uncertainty-weighted loss, averaged over all rows.

    Per row: ``u * sum_k (1 - p_k)`` with ``u = K / S`` and ``p = alpha / S``,
    where K is ``alpha.shape[1]`` and S is the row sum of ``alpha``.
    NOTE(review): since ``p`` sums to 1 over dim 1, this reduces to
    ``K * (K - 1) / S`` — confirm that is the intended penalty.
    """
    strength = alpha.sum(dim=1, keepdim=True)
    num_classes = alpha.shape[1]
    vacuity = num_classes / strength
    probs = alpha / strength
    complement_mass = (1 - probs).sum(dim=1, keepdim=True)
    return (vacuity * complement_mass).mean()


def loss_margin(alpha, margin=1.0):
    """Hinge penalty on the gap between the two largest expected probabilities.

    Per row: ``max(0, margin - (p_max - p_second))`` with ``p = alpha / S``;
    the result is the mean over all rows. Requires at least two classes on
    dim 1, since the second-largest value comes from ``torch.topk(..., 2)``.
    """
    probs = alpha / alpha.sum(dim=1, keepdim=True)
    best = probs.max(dim=1, keepdim=True).values
    runner_up = torch.topk(probs, 2, dim=1).values[:, 1:]
    shortfall = margin - (best - runner_up)
    return shortfall.clamp(min=0.0).mean()


def loss_calibration(alpha, target):
    """Mean squared gap between the target-weighted probability and 1.

    ``p = alpha / S``; per row the target-weighted sum ``sum_k target_k * p_k``
    (the true-class probability for one-hot targets) is pushed towards 1.
    """
    probs = alpha / alpha.sum(dim=1, keepdim=True)
    true_class_prob = torch.sum(probs * target, dim=1, keepdim=True)
    return ((true_class_prob - 1.0) ** 2).mean()



def get_loss(alpha, target, epoch_num, num_classes, annealing_step, device):
    """Training objective: batch-mean digamma EDL loss with annealed KL term.

    Args:
        alpha: Dirichlet concentration parameters with classes on dim 1.
        target: Class-index tensor (int64, as ``F.one_hot`` requires); it is
            one-hot encoded with ``num_classes`` columns.
        epoch_num: Current epoch, forwarded for KL annealing.
        num_classes: Number of classes.
        annealing_step: KL warm-up length in epochs, forwarded.
        device: Device the one-hot targets and ``alpha`` are moved to.

    Returns:
        Scalar tensor: ``edl_digamma_loss`` evaluated on the one-hot targets.
    """
    target = F.one_hot(target, num_classes).to(device)
    alpha = alpha.to(device)

    # Only the digamma EDL term is part of the objective. The other loss
    # helpers in this module (edl_mse_loss, loss_uncertainty, loss_margin,
    # loss_calibration) are deliberately not evaluated here: their results
    # would be discarded, so computing them only costs extra work (and
    # loss_margin's top-2 selection raises when num_classes < 2).
    return edl_digamma_loss(alpha, target, epoch_num, num_classes, annealing_step, device)

