"""This version supports stochastic multi-label datasets."""
import torch
from torch import nn

import XXX.uib.information_quantities as iq
from XXX.uib.utils.constants import dbl_null_threshold


class CustomInformationLoss(torch.autograd.Function):
    """Weighted sum of batch-estimated entropy terms with a hand-written gradient.

    The forward pass evaluates (in nats, over the batch of x)::

        alpha_Y_Z * H(Y, Z) + alpha_Z * H(Z) + alpha_Z__X * H(Z|X)
            + alpha_Y * H(Y) + alpha_Y__X * H(Y|X)

    where p(Z|x) is the softmax of ``logits_X_Z`` and p(y|x) is ``prob_X_Y``.
    Probabilities at or below ``null_threshold`` contribute exactly zero
    (the ``0 * log 0 = 0`` convention).

    Only ``logits_X_Z`` receives a gradient; the backward pass returns ``None``
    for every other input.
    """

    null_threshold = dbl_null_threshold

    @staticmethod
    def forward(ctx, alpha_Y_Z, alpha_Y, alpha_Y__X, alpha_Z, alpha_Z__X, logits_X_Z, prob_X_Y):
        """
        :param alpha_Y_Z: coefficient of the joint entropy term H(Y, Z)
        :param alpha_Y: coefficient of the label entropy term H(Y)
        :param alpha_Y__X: coefficient of the conditional entropy term H(Y|X)
        :param alpha_Z: coefficient of the latent entropy term H(Z)
        :param alpha_Z__X: coefficient of the conditional entropy term H(Z|X)
        :param logits_X_Z: 2-D (batch, Z) unnormalised logits; softmax over the last dim gives p(Z|x)
        :param prob_X_Y: 2-D (batch, Y) label probabilities p(y|x)
        :return: scalar double tensor holding the weighted sum of the enabled terms

        Each coefficient may be a scalar tensor or a plain Python number; a
        coefficient equal to zero skips its term entirely.
        """
        # p(Z|x) for the different x
        p_X_Z = nn.functional.softmax(logits_X_Z, dim=-1, dtype=torch.double)
        prob_X_Y = prob_X_Y.double()

        # Unpacking also enforces that the logits are 2-D.
        batch_size, _ = p_X_Z.shape

        # TODO: replace with logsumexp somehow? (and log_softmax)
        # TODO: replace with scatter?

        # p(y, Z): average of p(y|x) p(Z|x) over the batch.
        p_Y_Z = prob_X_Y.t() @ p_X_Z / batch_size

        # J_Y_Z collects the per-(y, z) surprisal terms that backward needs.
        J_Y_Z = torch.zeros_like(p_Y_Z)
        if alpha_Y_Z != 0:
            J_Y_Z += -alpha_Y_Z * torch.log(p_Y_Z)

        if alpha_Z != 0:
            # Marginal p(Z); broadcast over the Y rows when added.
            p_1_Z = torch.sum(p_Y_Z, dim=0, keepdim=True)
            J_Y_Z += -alpha_Z * torch.log(p_1_Z)

        # Zero out (near-)null cells so log(0) = -inf never reaches the sum.
        J_Y_Z[p_Y_Z <= CustomInformationLoss.null_threshold] = 0.0

        information_quantity = torch.sum(p_Y_Z * J_Y_Z)

        if alpha_Z__X != 0:
            J_X_Z = -torch.log(p_X_Z)
            J_X_Z[p_X_Z <= CustomInformationLoss.null_threshold] = 0.0
            information_quantity += alpha_Z__X * torch.sum(p_X_Z * J_X_Z) / batch_size

        # Separating this into its own term ought to improve the gradients.
        # (It is kept out of J_Y_Z and therefore out of backward.)
        if alpha_Y != 0:
            p_Y_1 = torch.sum(p_Y_Z, dim=1, keepdim=True)
            constant_J_Y_Z = -torch.log(p_Y_1)
            constant_J_Y_Z[p_Y_1 <= CustomInformationLoss.null_threshold] = 0.0
            information_quantity += alpha_Y * torch.sum(p_Y_Z * constant_J_Y_Z)

        if alpha_Y__X != 0:
            constant_J_X_Y = -torch.log(prob_X_Y)
            constant_J_X_Y[prob_X_Y <= CustomInformationLoss.null_threshold] = 0.0
            # H(Y|X): weight each surprisal by p(y|x) and average over the batch,
            # mirroring the H(Z|X) term above. Without the weighting the value is
            # only correct for one-hot labels. Not part of backward either.
            information_quantity += alpha_Y__X * torch.sum(prob_X_Y * constant_J_X_Y) / batch_size

        ctx.save_for_backward(J_Y_Z, p_X_Z, prob_X_Y)
        # save_for_backward only accepts tensors; keeping the coefficient as a
        # plain attribute works for both tensors and Python numbers.
        ctx.alpha_Z__X = alpha_Z__X
        return information_quantity

    @staticmethod
    def backward(ctx, grad_output):
        """
        :param grad_output: gradient w.r.t. the scalar returned by ``forward``
        :return: 7-tuple matching ``forward``'s inputs; only the entry for
            ``logits_X_Z`` is a tensor, all others are ``None``
        """
        J_Y_Z, p_X_Z, prob_X_Y = ctx.saved_tensors
        alpha_Z__X = ctx.alpha_Z__X
        batch_size, _ = prob_X_Y.shape

        # appendix notation
        K_X_Z = prob_X_Y @ J_Y_Z
        # Per-sample dot product <p(Z|x), K(x, Z)>, kept as a (batch, 1) column.
        Khat_X_1 = torch.bmm(p_X_Z[:, None, :], K_X_Z[:, :, None])[:, :, 0]

        # p * (K - <p, K>) is the softmax Jacobian applied to K.
        grad_predictions_X_Z = p_X_Z / batch_size * (K_X_Z - Khat_X_1)

        if alpha_Z__X != 0:
            J_X_Z = -torch.log(p_X_Z)
            J_X_Z[p_X_Z <= CustomInformationLoss.null_threshold] = 0.0
            # Do a x-wise dot product on the Zs to compute the entropies
            K_x_1 = torch.bmm(p_X_Z[:, None, :], J_X_Z[:, :, None])[:, :, 0]
            grad_predictions_X_Z += alpha_Z__X * p_X_Z / batch_size * (J_X_Z - K_x_1)

        return None, None, None, None, None, grad_output * grad_predictions_X_Z, None


# Functional entry point: the autograd Function is invoked through ``apply``
# with the same positional arguments as ``CustomInformationLoss.forward`` (minus ``ctx``).
custom_information_loss = CustomInformationLoss.apply


def iq_loss(information_quantity: torch.Tensor):
    """Build a loss function for one information quantity.

    The five entropy coefficients are obtained by combining
    ``information_quantity`` (via ``@``) with ``iq.H_Y``, ``iq.H_Z``,
    ``iq.H_Y__X``, ``iq.H_Y_Z`` and ``iq.H_Z__X``; they are computed once here
    and captured by the returned closure.

    :param information_quantity: operand applied with ``@`` to each ``iq.H_*`` term
    :return: callable ``loss(logits_X_Z, prob_X_Y)`` that forwards to
        ``custom_information_loss`` with the precomputed coefficients
    """
    alpha_Y, alpha_Z, alpha_Y__X, alpha_Y_Z, alpha_Z__X = (
        information_quantity @ component
        for component in (iq.H_Y, iq.H_Z, iq.H_Y__X, iq.H_Y_Z, iq.H_Z__X)
    )
    # Arranged in the positional order expected by custom_information_loss.
    coefficients = (alpha_Y_Z, alpha_Y, alpha_Y__X, alpha_Z, alpha_Z__X)

    def loss(logits_X_Z, prob_X_Y):
        return custom_information_loss(*coefficients, logits_X_Z, prob_X_Y)

    return loss


# Loss closure built by ``iq_loss`` from ``iq.entropy_distance``.
edl = iq_loss(iq.entropy_distance)

# Loss closure built by ``iq_loss`` from ``iq.conditional_entropy``.
cel = iq_loss(iq.conditional_entropy)
