from dataclasses import dataclass

import torch
import math


@dataclass
class CrossEntropy:
    """Accumulated cross-entropy value together with the count it is averaged over.

    Two instances can be added; the result carries the added sums and the added counts.
    """

    # Accumulated cross-entropy value.
    _sum: torch.Tensor
    # Divisor used to turn `_sum` into `mean`; a fresh instance counts as one.
    _count: int = 1

    def __add__(self, other: "CrossEntropy"):
        """Return a new `CrossEntropy` combining the sums and counts of both operands."""
        combined_sum = self._sum + other._sum
        combined_count = self._count + other._count
        return CrossEntropy(combined_sum, combined_count)

    @property
    def sum(self):
        """The accumulated (un-averaged) value."""
        return self._sum

    @property
    def mean(self):
        """The accumulated value divided by the count."""
        return self._sum / self._count

    @property
    def expanded(self):
        """The mean expanded to shape `(count, 1)`."""
        # TODO: file a bug against Ignite (they ought not to require this extra dimension to sum over the batch!)
        average = self._sum / self._count
        return average.expand(self._count, 1)


@dataclass
class CrossEntropies:
    """Prediction and decoder cross-entropy values that share a single count.

    Each value can be viewed as a `CrossEntropy` through the matching property, and two
    instances can be added field by field.
    """

    # Value wrapped by the `prediction` property.
    prediction_value: torch.Tensor
    # Value wrapped by the `decoder` property.
    decoder_value: torch.Tensor
    # Count shared by both values.
    count: int

    def __add__(self, other: "CrossEntropies"):
        """Return a new `CrossEntropies` whose fields are the pairwise sums of both operands."""
        return CrossEntropies(
            prediction_value=self.prediction_value + other.prediction_value,
            decoder_value=self.decoder_value + other.decoder_value,
            count=self.count + other.count,
        )

    @property
    def prediction(self):
        """The prediction value and the shared count as a `CrossEntropy`."""
        return CrossEntropy(_sum=self.prediction_value, _count=self.count)

    @property
    def decoder(self):
        """The decoder value and the shared count as a `CrossEntropy`."""
        return CrossEntropy(_sum=self.decoder_value, _count=self.count)


def deterministic_prediction_cross_entropy(prediction_logits_x_y, labels_x, reduction="mean"):
    """Cross-entropy for a deterministic encoder model that outputs prediction logits.

    For such a model the decoder cross-entropy and the prediction cross-entropy coincide,
    so this one value serves as both.

    Args:
        prediction_logits_x_y: Logits passed as the input of `torch.nn.functional.cross_entropy`.
        labels_x: Targets passed to `torch.nn.functional.cross_entropy`.
        reduction: Forwarded unchanged to `torch.nn.functional.cross_entropy`.

    Returns:
        The tensor produced by `torch.nn.functional.cross_entropy`.
    """
    return torch.nn.functional.cross_entropy(
        input=prediction_logits_x_y,
        target=labels_x,
        reduction=reduction,
    )


def deterministic_continuous_encoder_cross_entropies(prediction_logits_x_y, labels_x):
    """Bundle the summed cross-entropy of a deterministic continuous encoder as `CrossEntropies`.

    The same summed value is used for both the prediction and the decoder entry, and the
    count is `len(labels_x)`.
    """
    summed_cross_entropy = deterministic_prediction_cross_entropy(
        prediction_logits_x_y, labels_x, reduction="sum"
    )
    return CrossEntropies(
        prediction_value=summed_cross_entropy,
        decoder_value=summed_cross_entropy,
        count=len(labels_x),
    )


def deterministic_categorical_encoder_loss(latent_logits_x_z, prediction_logits_z_y, labels_x):
    """Unimplemented placeholder; the arguments are ignored.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError


def deterministic_categorical_encoder_decoder_cross_entropy(
    latent_logits_x_z, prediction_prob_z_y, labels_x, reduction="mean"
):
    """
    Decoder cross-entropy for a deterministic model that outputs prediction logits and latent logits.

    This requires a special dynamics model, where we take each z one-hot and feed it into the decoder separately.

    The latent logits are softmaxed over dim 1, the log of the prediction probabilities is averaged
    under that latent distribution (log first, then expectation), and the entry selected by each
    label is negated.

    Args:
        latent_logits_x_z: Latent logits; softmax is applied along dim 1.
        prediction_prob_z_y: Prediction probabilities, matrix-multiplied from the left by the latent
            probabilities.
        labels_x: Integer index tensor used to gather along dim 1 of the result.
        reduction: `"mean"`, `"sum"`, or `None` / `"none"` for the unreduced per-row terms. The
            string `"none"` is accepted to match the spelling used by `torch.nn.functional`.

    Returns:
        A double-precision tensor: one value per row when unreduced, otherwise a scalar.

    Raises:
        ValueError: If `reduction` is not one of the supported modes.
    """
    # noinspection PyTypeChecker
    latent_prob_x_z = torch.nn.functional.softmax(latent_logits_x_z, dim=1, dtype=torch.double)

    # We need normalized logits: expectation over z of log p(y|z).
    log_prediction_prob_z_y = torch.log(prediction_prob_z_y.double())
    expected_log_prob_x_y = latent_prob_x_z @ log_prediction_prob_z_y

    # Pick, for every row, the entry belonging to its label.
    label_log_prob_x = torch.gather(expected_log_prob_x_y, dim=1, index=labels_x[:, None])[:, 0]

    if reduction is None or reduction == "none":
        return -label_log_prob_x
    if reduction == "sum":
        return -torch.sum(label_log_prob_x)
    if reduction == "mean":
        return -torch.mean(label_log_prob_x)
    raise ValueError(f"Unknown reduction mode {reduction}!")


def deterministic_categorical_encoder_prediction_cross_entropy(
    latent_logits_x_z, prediction_prob_z_y, labels_x, reduction="mean"
):
    r"""
    Prediction cross-entropy for a deterministic model that outputs prediction logits and latent logits.

    The dynamics model stops the gradients, and here we connect them again.

    We perform the step `p_\Theta(\hat y|x) = E_{p_\Theta(z|x)} p_\Theta(\hat y|z)` and then compute the
    cross-entropy with the `labels_x` (expectation first, then log).

    Args:
        latent_logits_x_z: Latent logits; softmax is applied along dim 1.
        prediction_prob_z_y: Prediction probabilities, matrix-multiplied from the left by the latent
            probabilities.
        labels_x: Integer index tensor used to gather along dim 1 of the result.
        reduction: `"mean"`, `"sum"`, or `None` / `"none"` for the unreduced per-row terms. The
            string `"none"` is accepted to match the spelling used by `torch.nn.functional`.

    Returns:
        A double-precision tensor: one value per row when unreduced, otherwise a scalar.

    Raises:
        ValueError: If `reduction` is not one of the supported modes.
    """
    # noinspection PyTypeChecker
    latent_prob_x_z = torch.nn.functional.softmax(latent_logits_x_z, dim=1, dtype=torch.double)

    # Marginalize the latent out first, then move to log space.
    marginal_prob_x_y = latent_prob_x_z @ prediction_prob_z_y.double()
    marginal_log_prob_x_y = torch.log(marginal_prob_x_y)

    # Pick, for every row, the entry belonging to its label.
    label_log_prob_x = torch.gather(marginal_log_prob_x_y, dim=1, index=labels_x[:, None])[:, 0]

    if reduction is None or reduction == "none":
        return -label_log_prob_x
    if reduction == "sum":
        return -torch.sum(label_log_prob_x)
    if reduction == "mean":
        return -torch.mean(label_log_prob_x)
    raise ValueError(f"Unknown reduction mode {reduction}!")


def deterministic_categorical_encoder_cross_entropies(latent_logits_x_z, prediction_prob_z_y, labels_x):
    """Bundle the summed prediction and decoder cross-entropies of a deterministic categorical encoder.

    Both values are computed with `reduction="sum"` from the same inputs, and the count is
    `len(labels_x)`.
    """
    shared_inputs = (latent_logits_x_z, prediction_prob_z_y, labels_x)

    prediction_total = deterministic_categorical_encoder_prediction_cross_entropy(*shared_inputs, reduction="sum")
    decoder_total = deterministic_categorical_encoder_decoder_cross_entropy(*shared_inputs, reduction="sum")

    return CrossEntropies(
        prediction_value=prediction_total,
        decoder_value=decoder_total,
        count=len(labels_x),
    )


def stochastic_continuous_encoder_cross_entropies(prediction_logits_x_k_y, labels_x):
    """Bundle the summed prediction and decoder cross-entropies of a stochastic continuous encoder.

    Both values are computed with `reduction="sum"` from the same logits and labels, and the
    count is `len(labels_x)`.
    """
    shared_inputs = (prediction_logits_x_k_y, labels_x)

    prediction_total = stochastic_continuous_encoder_prediction_cross_entropy(*shared_inputs, reduction="sum")
    decoder_total = stochastic_continuous_encoder_decoder_cross_entropy(*shared_inputs, reduction="sum")

    return CrossEntropies(
        prediction_value=prediction_total,
        decoder_value=decoder_total,
        count=len(labels_x),
    )


def stochastic_continuous_encoder_decoder_cross_entropy(logits_x_k_y, labels_x, reduction="mean"):
    """Decoder cross-entropy where every sample along dim 1 of the logits is scored against its row's label.

    Dims 0 and 1 of the logits are merged, each label is repeated once per sample, and the
    result is passed to `torch.nn.functional.cross_entropy`.

    Args:
        logits_x_k_y: Logits whose dim 1 indexes the samples.
        labels_x: Labels, one per entry along dim 0; moved to the logits' device first.
        reduction: Forwarded to `torch.nn.functional.cross_entropy`. For `"sum"` the total is
            additionally divided by the number of samples.

    Returns:
        The (possibly rescaled) cross-entropy tensor. NOTE(review): with `reduction="none"` this
        has one entry per (row, sample) pair rather than per row — confirm callers expect that.
    """
    sample_count = logits_x_k_y.shape[1]
    device_labels = labels_x.to(device=logits_x_k_y.device, non_blocking=True)

    # Treat every (row, sample) pair as its own row, repeating the row's label for each sample.
    per_sample_logits = logits_x_k_y.flatten(0, 1)
    per_sample_labels = device_labels[:, None].expand(-1, sample_count).flatten(0, 1)

    loss = torch.nn.functional.cross_entropy(per_sample_logits, per_sample_labels, reduction=reduction)
    if reduction != "sum":
        return loss

    # A plain sum counts every sample separately; divide so samples are averaged instead.
    loss /= sample_count
    return loss


def stochastic_continuous_encoder_prediction_cross_entropy(logits_x_k_y, labels_x, reduction="mean"):
    """Prediction cross-entropy computed from the sample-averaged predictive distribution.

    The logits are log-softmaxed over dim 2, the resulting probabilities are averaged over
    dim 1 in log space (logsumexp minus the log of the sample count), and the negative
    log-likelihood of the labels is taken.

    Args:
        logits_x_k_y: Logits whose dim 1 indexes the samples and dim 2 the classes.
        labels_x: Labels; moved to the logits' device first.
        reduction: Forwarded unchanged to `torch.nn.functional.nll_loss`.

    Returns:
        The tensor produced by `torch.nn.functional.nll_loss`.
    """
    device_labels = labels_x.to(device=logits_x_k_y.device, non_blocking=True)
    sample_count = logits_x_k_y.shape[1]

    log_prob_x_k_y = torch.nn.functional.log_softmax(logits_x_k_y, dim=2)
    # log of the mean probability over samples: logsumexp over dim 1 minus log(sample count).
    log_mean_prob_x_y = log_prob_x_k_y.logsumexp(dim=1, keepdim=False) - math.log(sample_count)

    return torch.nn.functional.nll_loss(log_mean_prob_x_y, device_labels, reduction=reduction)
