"""This version only supports deterministic single label setups."""
import torch
from torch import nn

import XXX.uib.information_quantities as iq
from XXX.uib.utils.constants import dbl_null_threshold


class CustomInformationLoss(torch.autograd.Function):
    """Weighted sum of entropy terms over labels Y and latents Z with a manual gradient.

    ``forward`` evaluates::

        alpha_Y_Z * H[Y, Z] + alpha_Z * H[Z] + alpha_Y * H[Y] + alpha_Z__X * H[Z | X]

    where ``p(Z | x) = softmax(logits_X_Z[x])``, every sample carries exactly one
    label, and all probabilities are empirical averages over the batch (natural
    logarithm throughout). Probabilities at or below ``null_threshold`` get a
    ``-log`` term of 0 instead of an infinite one.

    ``backward`` returns the analytic gradient with respect to ``logits_X_Z``
    only; every other input receives ``None``.
    """

    # Probabilities <= this value are treated as zero when taking logs.
    null_threshold = dbl_null_threshold

    @staticmethod
    def _masked_neg_log(p):
        """Return ``-log(p)`` with entries where ``p <= null_threshold`` set to 0."""
        neg_log = -torch.log(p)
        neg_log[p <= CustomInformationLoss.null_threshold] = 0.0
        return neg_log

    @staticmethod
    def forward(ctx, alpha_Y_Z, alpha_Y, alpha_Z, alpha_Z__X, logits_X_Z, labels_x):
        """Compute the weighted information quantity for one batch.

        Args:
            ctx: Autograd context used to pass values to ``backward``.
            alpha_Y_Z: Scalar weight of the joint entropy ``H[Y, Z]``.
            alpha_Y: Scalar weight of the label entropy ``H[Y]``.
            alpha_Z: Scalar weight of the latent entropy ``H[Z]``.
            alpha_Z__X: Scalar weight of the conditional entropy ``H[Z | X]``.
                Each weight may be a Python number or a single-element tensor;
                a weight equal to 0 skips its term entirely.
            logits_X_Z: 2-D tensor ``(batch, latent)`` of unnormalised scores;
                the softmax is taken over the last dimension in double precision.
            labels_x: 1-D integer tensor with one class index per batch row.
                The number of classes is inferred by ``one_hot`` from the
                largest label present in the batch.

        Returns:
            A 0-dim double tensor holding the weighted sum of entropy terms.
        """
        # p(Z|x) for the different x
        p_X_Z = nn.functional.softmax(logits_X_Z, dim=-1, dtype=torch.double)

        # The unpacking doubles as a check that the logits are 2-D.
        batch_size, _ = p_X_Z.shape

        # Indicator matrix [y == label(x)] with shape (classes, batch).
        i_y_x = torch.nn.functional.one_hot(labels_x).t().double()
        # TODO: replace with logsumexp somehow? (and log_softmax)
        # TODO: replace with scatter?

        # p(y, Z)
        p_Y_Z = i_y_x @ p_X_Z / batch_size

        # J_Y_Z collects the per-(y, z) "surprise" terms whose gradient does
        # not vanish; backward reuses it as-is.
        J_Y_Z = torch.zeros_like(p_Y_Z)
        if alpha_Y_Z != 0:
            J_Y_Z += -alpha_Y_Z * torch.log(p_Y_Z)

        if alpha_Z != 0:
            # p(Z), marginalised over the labels.
            p_1_Z = torch.sum(p_Y_Z, dim=0, keepdim=True)
            J_Y_Z += -alpha_Z * torch.log(p_1_Z)

        # Zero out the (possibly inf/nan) entries that belong to null probabilities.
        J_Y_Z[p_Y_Z <= CustomInformationLoss.null_threshold] = 0.0

        information_quantity = torch.sum(p_Y_Z * J_Y_Z)

        # Separating this into its own term ought to improve the gradients:
        # p(y) sums a softmax over Z, so it does not depend on the logits and
        # is deliberately left out of J_Y_Z (and therefore out of backward).
        if alpha_Y != 0:
            p_Y_1 = torch.sum(p_Y_Z, dim=1, keepdim=True)
            constant_J_Y_Z = -alpha_Y * torch.log(p_Y_1)
            constant_J_Y_Z[p_Y_1 <= CustomInformationLoss.null_threshold] = 0.0
            information_quantity += torch.sum(p_Y_Z * constant_J_Y_Z)

        if alpha_Z__X != 0:
            J_X_Z = CustomInformationLoss._masked_neg_log(p_X_Z)
            information_quantity += alpha_Z__X * torch.sum(p_X_Z * J_X_Z) / batch_size

        # save_for_backward only accepts tensors, so the scalar coefficient is
        # kept on ctx as a plain float; this also lets callers pass Python
        # numbers for the alphas.
        ctx.save_for_backward(J_Y_Z, p_X_Z, labels_x)
        ctx.alpha_Z__X = float(alpha_Z__X)
        return information_quantity

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to ``logits_X_Z``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient flowing into the scalar output of ``forward``.

        Returns:
            A 6-tuple matching the inputs of ``forward``; only the entry for
            ``logits_X_Z`` (index 4) is a tensor, all others are ``None``.
        """
        J_Y_Z, p_X_Z, labels_x = ctx.saved_tensors
        alpha_Z__X = ctx.alpha_Z__X
        (batch_size,) = labels_x.shape

        # Softmax chain rule per sample x: p(z|x) * (J[y_x, z] - E_{z~p(.|x)} J[y_x, z]).
        K_X_Y = p_X_Z @ J_Y_Z.t()
        K_x_1 = torch.gather(K_X_Y, dim=1, index=labels_x[:, None])
        J_X_Z = J_Y_Z.index_select(dim=0, index=labels_x)
        grad_predictions_X_Z = p_X_Z / batch_size * (J_X_Z - K_x_1)

        if alpha_Z__X != 0:
            J_X_Z = CustomInformationLoss._masked_neg_log(p_X_Z)
            # Do a x-wise dot product on the Zs to compute the entropies
            K_x_1 = torch.bmm(p_X_Z[:, None, :], J_X_Z[:, :, None])[:, :, 0]
            grad_predictions_X_Z += alpha_Z__X * p_X_Z / batch_size * (J_X_Z - K_x_1)

        return None, None, None, None, grad_output * grad_predictions_X_Z, None


# Functional entry point: call this instead of instantiating the Function so
# that autograd records the custom forward/backward pair.
custom_information_loss = CustomInformationLoss.apply


def iq_loss(information_quantity: torch.Tensor):
    """Create a loss callable for the given information quantity.

    The quantity is contracted (``@``) with ``iq.H_Y``, ``iq.H_Z``, ``iq.H_Y_Z``
    and ``iq.H_Z__X`` to obtain the four scalar weights that
    ``custom_information_loss`` expects.
    NOTE(review): presumably ``information_quantity`` is a coefficient vector
    over the entropy basis defined in ``information_quantities`` -- confirm there.

    Args:
        information_quantity: Tensor describing the quantity to optimise.

    Returns:
        A function ``loss(logits_X_Z, labels_x)`` that evaluates
        ``custom_information_loss`` with the weights fixed at creation time.
    """
    weight_Y = information_quantity @ iq.H_Y
    weight_Z = information_quantity @ iq.H_Z
    weight_Y_Z = information_quantity @ iq.H_Y_Z
    weight_Z__X = information_quantity @ iq.H_Z__X

    # Packed in the positional order custom_information_loss takes them.
    weights = (weight_Y_Z, weight_Y, weight_Z, weight_Z__X)

    def loss(logits_X_Z, labels_x):
        return custom_information_loss(*weights, logits_X_Z, labels_x)

    return loss


# Ready-made loss for the quantity defined by ``iq.entropy_distance``.
edl = iq_loss(iq.entropy_distance)

# Ready-made loss for the quantity defined by ``iq.conditional_entropy``.
cel = iq_loss(iq.conditional_entropy)
