from dataclasses import dataclass
from typing import Optional

import torch

from XXX.uib.losses import very_approx_regularizers
from XXX.uib.losses.cross_entropies import CrossEntropies


@dataclass
class IgniteOutput:
    """Per-batch bundle of tensors plus accessor / transform helpers.

    The ``get_*`` instance methods each extract one field (or a pair of
    fields) from an instance. The ``get_*`` static methods build a
    one-argument callable that feeds ``(z, y)`` from an instance into an
    estimator.

    NOTE(review): the class name suggests these are meant to be used as
    pytorch-ignite ``output_transform`` callables — confirm against callers.
    """

    # Field order defines the generated ``__init__`` signature; keep it stable.
    y: torch.Tensor
    y_pred: torch.Tensor

    z: Optional[torch.Tensor]
    loss: Optional[float]

    cross_entropies: CrossEntropies

    def get_loss(self):
        """Return the stored ``loss`` field unchanged."""
        stored_loss = self.loss
        return stored_loss

    def get_y_pred_y(self):
        """Return the pair ``(y_pred, y)``, prediction first."""
        pair = (self.y_pred, self.y)
        return pair

    def get_y(self):
        """Return the ``z`` field.

        NOTE(review): despite its name this returns ``self.z``, not
        ``self.y``. The original behavior is preserved deliberately; confirm
        with callers whether that is intended before changing it.
        """
        z_value = self.z
        return z_value

    def get_prediction_cross_entropy(self):
        """Return ``cross_entropies.prediction.expanded``."""
        prediction_terms = self.cross_entropies.prediction
        return prediction_terms.expanded

    def get_decoder_cross_entropy(self):
        """Return ``cross_entropies.decoder.expanded``."""
        decoder_terms = self.cross_entropies.decoder
        return decoder_terms.expanded

    @staticmethod
    def call_z_y(func):
        """Adapt a two-argument ``func(z, y)`` into a one-argument callable.

        The returned function reads ``z`` and then ``y`` from its single
        argument and forwards them positionally to ``func``, returning
        whatever ``func`` returns.
        """

        def wrapper(self):
            z, y = self.z, self.y
            return func(z, y)

        return wrapper

    @staticmethod
    def get_mean_squared_z(stochastic=False):
        """Wrap ``very_approx_regularizers.squared_sum`` as a ``(z, y)`` transform.

        NOTE(review): the name says "mean squared" while the helper is named
        ``squared_sum`` — confirm the helper's reduction matches the name.
        """
        estimator = very_approx_regularizers.squared_sum(stochastic=stochastic)
        return IgniteOutput.call_z_y(estimator)

    @staticmethod
    def get_mean_squared_mean_z_given_x(stochastic=False):
        """Wrap ``very_approx_regularizers.squared_mean_by_x`` as a ``(z, y)`` transform."""
        estimator = very_approx_regularizers.squared_mean_by_x(stochastic=stochastic)
        return IgniteOutput.call_z_y(estimator)

    @staticmethod
    def get_covariance_trace(covariance_trace_estimator, stochastic=False):
        """Build a ``(z, y)`` transform from a covariance-trace estimator factory.

        ``covariance_trace_estimator`` is called immediately with ``stochastic``
        as a keyword argument. Its result is later invoked as
        ``estimator(z, y)`` each time the returned transform is applied.
        """
        estimator = covariance_trace_estimator(stochastic=stochastic)
        return IgniteOutput.call_z_y(estimator)

    @staticmethod
    def get_entropy_estimate(covariance_trace_estimator, stochastic=False):
        """Wrap ``very_approx_regularizers.estimate_entropy`` as a ``(z, y)`` transform.

        ``covariance_trace_estimator`` is passed through unchanged to
        ``estimate_entropy`` together with ``stochastic``.
        """
        estimator = very_approx_regularizers.estimate_entropy(
            covariance_trace_estimator, stochastic=stochastic
        )
        return IgniteOutput.call_z_y(estimator)
