#!/usr/bin/env python
# encoding: utf-8

# filename: loss_utils

import torch
import torch.nn.functional as F


def softmax_cross_entropy_loss(logits: torch.Tensor, target: torch.Tensor):
    """Soft-target cross-entropy between ``logits`` and ``target``.

    Log-probabilities are taken with a softmax over dim 1 (the class axis).
    The per-element terms ``-target * log_softmax(logits)`` are then averaged
    over *all* elements, classes included. For (N, K) inputs the result is
    therefore the usual batch-mean cross-entropy divided by K.

    Args:
        logits: unnormalised scores; classes along dim 1.
        target: target weights broadcastable with ``logits`` (e.g. one-hot or
            soft labels).

    Returns:
        A 0-dim tensor.
    """
    log_probs = F.log_softmax(logits, dim=1)
    per_element = -target * log_probs
    return torch.mean(per_element)


def focal_edl_loss(func, y, alpha, gamma):
    """Evidential (EDL) loss re-weighted with a focal term.

    The base term is ``sum_k y_k * (func(S) - func(alpha_k))`` with
    ``S = sum_k alpha_k``. Each row is then scaled by ``(1 - p_gt) ** gamma``,
    where ``p_gt = sum_k y_k * alpha_k / S`` is the expected probability
    assigned to the target.

    Args:
        func: elementwise function applied to ``S`` and ``alpha`` (typically
            ``torch.digamma`` or ``torch.log`` — depends on the caller).
        y: targets broadcastable with ``alpha``; presumably one-hot (N, K) —
            confirm with callers.
        alpha: Dirichlet parameters; classes along dim 1.
        gamma: focusing exponent; ``gamma == 0`` reduces to the plain EDL loss.

    Returns:
        Tensor with dim 1 reduced to size 1 (keepdim=True), one loss per row.
    """
    S = torch.sum(alpha, dim=1, keepdim=True)
    A = torch.sum(y * (func(S) - func(alpha)), dim=1, keepdim=True)
    prob = alpha / S  # expected class probabilities under the Dirichlet
    # keepdim=True is required: without it gt_cls_prob has shape (N,), and
    # multiplying by A of shape (N, 1) would broadcast to an (N, N) matrix.
    gt_cls_prob = torch.sum(y * prob, dim=1, keepdim=True)
    weight = (1 - gt_cls_prob) ** gamma
    return A * weight


def mse_loss(y, alpha, with_bias=True):
    """Alias for ``loglikelihood_loss`` (sum-of-squares loss for Dirichlet outputs).

    Args:
        y: targets, forwarded unchanged.
        alpha: Dirichlet parameters, forwarded unchanged.
        with_bias: forwarded unchanged; see ``loglikelihood_loss``.

    Returns:
        Whatever ``loglikelihood_loss`` returns for the same arguments.
    """
    return loglikelihood_loss(y, alpha, with_bias=with_bias)


def loglikelihood_loss(y, alpha, with_bias=True):
    """Sum-of-squares (expected squared error) loss for a Dirichlet output.

    With ``S = sum_k alpha_k`` and ``p = alpha / S`` the per-row loss is::

        sum_k (y_k - p_k)^2                              # squared-error term
      + sum_k alpha_k (S - alpha_k) / (S^2 (S + 1))      # variance term

    Args:
        y: targets broadcastable with ``alpha``; presumably one-hot (N, K) —
            confirm with callers.
        alpha: Dirichlet parameters; classes along dim 1.
        with_bias: despite its name, this flag controls whether the
            *variance* term is added. When False, only the squared-error term
            is returned.

    Returns:
        Tensor with dim 1 reduced to size 1 (keepdim=True).
    """
    S = torch.sum(alpha, dim=1, keepdim=True)
    loglikelihood_err = torch.sum((y - (alpha / S)) ** 2, dim=1, keepdim=True)
    if not with_bias:
        # The variance term is not needed; skip computing it.
        return loglikelihood_err
    loglikelihood_var = torch.sum(
        alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True
    )
    return loglikelihood_err + loglikelihood_var

def compute_fisher_inverse(alpha):
    """Fisher information matrix of a Dirichlet w.r.t. ``alpha`` and its inverse.

    For each row, ``F = diag(psi'(alpha)) - psi'(alpha_0) * 1 1^T``. Here
    ``psi'`` is the trigamma function and ``alpha_0 = sum_k alpha_k``.

    Args:
        alpha: Dirichlet parameters; assumed to be (N, K) with classes along
            dim 1 — TODO confirm for other ranks.

    Returns:
        A tuple ``(fisher, fisher_inv)``, each of shape (N, K, K).
    """
    # Per-class trigamma values go on the diagonal.
    diagonal = torch.diag_embed(torch.polygamma(1, alpha))
    # Trigamma of the total concentration, shaped (N, 1, 1) so it is
    # subtracted from every (i, j) entry of the row's matrix.
    total_trigamma = torch.polygamma(1, alpha.sum(dim=1))[:, None, None]
    fisher = diagonal - total_trigamma
    return fisher, torch.linalg.inv(fisher)


def loglikelihood_variance(alpha):
    """Variance term of the Dirichlet sum-of-squares loss.

    Computes ``sum_k alpha_k (S - alpha_k) / (S^2 (S + 1))`` with
    ``S = sum_k alpha_k``, i.e. the summed per-class variance of the
    Dirichlet.

    Args:
        alpha: Dirichlet parameters; classes along dim 1.

    Returns:
        Tensor with dim 1 reduced to size 1 (keepdim=True).
    """
    strength = alpha.sum(dim=1, keepdim=True)
    denominator = strength * strength * (strength + 1)
    per_class = alpha * (strength - alpha) / denominator
    return per_class.sum(dim=1, keepdim=True)


def loglikelihood_bias(y, alpha):
    """Squared-error term of the Dirichlet sum-of-squares loss.

    Computes ``sum_k (y_k - alpha_k / S)^2`` with ``S = sum_k alpha_k``.

    Args:
        y: targets broadcastable with ``alpha``; presumably one-hot —
            confirm with callers.
        alpha: Dirichlet parameters; classes along dim 1.

    Returns:
        Tensor with dim 1 reduced to size 1 (keepdim=True).
    """
    expected_prob = alpha / alpha.sum(dim=1, keepdim=True)
    residual = y - expected_prob
    return residual.pow(2).sum(dim=1, keepdim=True)


def edl_loss(func, y, alpha):
    """Generic evidential (EDL) loss ``sum_k y_k * (func(S) - func(alpha_k))``.

    ``S = sum_k alpha_k``. ``func`` selects the variant (e.g. ``torch.digamma``
    or ``torch.log`` — whichever the caller passes).

    Args:
        func: elementwise function applied to ``S`` and ``alpha``.
        y: targets broadcastable with ``alpha``; presumably one-hot —
            confirm with callers.
        alpha: Dirichlet parameters; classes along dim 1.

    Returns:
        Tensor with dim 1 reduced to size 1 (keepdim=True).
    """
    strength = alpha.sum(dim=1, keepdim=True)
    gap = func(strength) - func(alpha)
    return (y * gap).sum(dim=1, keepdim=True)


def kl_regularization(y, alpha, num_classes, epoch_num, annealing_step):
    """Annealed KL penalty towards the uniform Dirichlet on non-target evidence.

    The target-class entries of ``alpha`` are reset to 1 (only non-target
    evidence is penalised). The KL divergence to ``Dir(1, ..., 1)`` is then
    computed and scaled by ``min(1, epoch_num / annealing_step)``.

    Args:
        y: targets; presumably one-hot with the same shape as ``alpha`` —
            confirm with callers.
        alpha: Dirichlet parameters; classes along dim 1.
        num_classes: number of classes, forwarded to ``kl_divergence``.
        epoch_num: current epoch, used for the linear warm-up.
        annealing_step: number of epochs over which the coefficient ramps to 1.

    Returns:
        The annealed KL term, shaped as returned by ``kl_divergence``.
    """
    ceiling = torch.tensor(1.0, dtype=torch.float32)
    progress = torch.tensor(epoch_num / annealing_step, dtype=torch.float32)
    annealing_coef = torch.min(ceiling, progress)

    # Where y == 1 this yields 1; elsewhere alpha is kept unchanged.
    non_target_alpha = (alpha - 1) * (1 - y) + 1
    return annealing_coef * kl_divergence(non_target_alpha, num_classes)


def kl_divergence(alpha, num_classes):
    """KL( Dir(alpha) || Dir(1, ..., 1) ), computed per row.

    Uses the closed form between two Dirichlets with the reference
    parameters fixed to ones::

        lgamma(S) - sum_k lgamma(alpha_k) + sum_k lgamma(1) - lgamma(K)
        + sum_k (alpha_k - 1) * (digamma(alpha_k) - digamma(S))

    Args:
        alpha: Dirichlet parameters; classes along dim 1 (size ``num_classes``).
        num_classes: K, the number of classes.

    Returns:
        Tensor with dim 1 reduced to size 1 (keepdim=True).
    """
    # Allocate the reference parameters on alpha's device. Otherwise
    # ``alpha - ones`` fails when alpha is not on the CPU (e.g. CUDA).
    ones = torch.ones([1, num_classes], dtype=torch.float32, device=alpha.device)
    sum_alpha = torch.sum(alpha, dim=1, keepdim=True)
    first_term = (
            torch.lgamma(sum_alpha)
            - torch.lgamma(alpha).sum(dim=1, keepdim=True)
            # lgamma(1) == 0, so this term is identically zero; it is kept
            # only so the expression mirrors the general closed form.
            + torch.lgamma(ones).sum(dim=1, keepdim=True)
            - torch.lgamma(ones.sum(dim=1, keepdim=True))
    )
    second_term = (
        (alpha - ones)
            .mul(torch.digamma(alpha) - torch.digamma(sum_alpha))
            .sum(dim=1, keepdim=True)
    )
    kl = first_term + second_term
    return kl
