import functools
import math

import torch

from XXX.uib.modules.categorical_decoder import CategoricalDecoder
from XXX.uib.modules.decoder_interface import DecoderInterface
from XXX.uib.utils.isolated_module import IsolatedModule


class EncoderDecoder(IsolatedModule):
    """Pair an encoder with a decoder that is fitted on the encoder's outputs.

    ``forward`` chains both modules. The helper methods make it possible to
    train only the encoder with an external training loop while the decoder
    is kept up to date (``loss_wrapper``/``fit_decoder``) and while metrics
    are still computed on decoded predictions (``encoder_metric``).
    """

    encoder: torch.nn.Module
    decoder: DecoderInterface

    def __init__(self, encoder, decoder):
        """Store the two sub-modules.

        Args:
            encoder: Module that maps the raw inputs to an encoding.
            decoder: Decoder that is called on, and fitted to, the encodings.
        """
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, *inputs):
        """Encode ``inputs`` and return the decoder's output for the encoding.

        Args:
            *inputs: Positional arguments forwarded unchanged to the encoder.
        """
        return self.decoder(self.encoder(*inputs))

    def decode(self, encoding):
        """Apply the decoder to an already computed ``encoding``."""
        return self.decoder(encoding)

    def encoder_metric(self, inner_metric):
        """Wrap a metric so that predictions are decoded before it is applied.

        So one can train on the encoder in for example FastAI transparently:
        the predictions handed to the metric are then encoder outputs, which
        are passed through the decoder first.

        Args:
            inner_metric: Callable invoked as
                ``inner_metric(decoded_predictions, targets, *args, **kwargs)``.

        Returns:
            A function with the call signature
            ``(predictions, targets, *args, **kwargs)`` that carries the
            metadata of ``inner_metric`` (via ``functools.wraps``).
        """

        # NOTE(review): this calls ``self.decoder`` directly rather than
        # ``self.decode``, so a subclass overriding ``decode`` is bypassed
        # here - confirm that this is intended.
        @functools.wraps(inner_metric)
        def wrapped_metric(predictions, targets, *args, **kwargs):
            return inner_metric(self.decoder(predictions), targets, *args, **kwargs)

        return wrapped_metric

    def loss_wrapper(self, loss_function):
        """Wrap a loss so that training-mode calls also fit the decoder.

        Each call of the returned function first aligns this module's
        training flag with the encoder's, then - in training mode only -
        calls ``self.decoder.fit(predictions, target)``, and finally returns
        ``loss_function(predictions, target, *args, **kwargs)``.

        Unlike ``encoder_metric``, the returned function does not copy the
        metadata of ``loss_function``.

        Args:
            loss_function: Callable invoked as
                ``loss_function(predictions, target, *args, **kwargs)``.

        Returns:
            The wrapping function described above.
        """

        def wrapped_loss(predictions, target, *args, **kwargs):
            # Follow the encoder's train/eval mode - presumably because the
            # outer training loop only toggles the encoder; TODO confirm.
            if self.encoder.training != self.training:
                self.train(self.encoder.training)

            if self.training:
                # NOTE(review): predictions are handed to ``fit`` as-is, i.e.
                # without the ``no_grad``/``convert_tensor`` handling used in
                # ``fit_decoder`` - confirm that ``fit`` copes with tensors
                # that are still attached to the autograd graph.
                self.decoder.fit(predictions, target)

            return loss_function(predictions, target, *args, **kwargs)

        return wrapped_loss

    def fit_decoder(self, images, labels):
        """Fit the decoder on the encodings of one batch.

        Args:
            images: Batch that is passed to the encoder.
            labels: Targets for the batch; converted with
                ``self.decoder.convert_tensor(labels, device_only=True)``.
        """
        labels = self.decoder.convert_tensor(labels, device_only=True)

        # Only the encoder pass runs without gradient tracking; ``fit`` itself
        # is called with autograd in its ambient state.
        with torch.no_grad():
            latents = self.encoder(images)
        latents = self.decoder.convert_tensor(latents)
        self.decoder.fit(latents, labels)


class CategoricalEncoderDecoder(EncoderDecoder):
    """Encoder/decoder whose output is the log of a softmax-weighted matrix product.

    The encoder output is softmaxed over dim 1, multiplied with a decoder
    matrix obtained from ``get_decoder_prob_z_y`` and the element-wise log of
    that product is returned.

    Naming used below (taken from the identifiers; assumed, TODO confirm):
    ``x`` indexes samples, ``z`` latent categories and ``y`` output classes.
    """

    # Weight put on the diagonal of the probe matrix that is fed to a
    # non-categorical decoder; the remaining ``1 - almost_one`` is spread
    # evenly over the ``z`` entries of every row.
    almost_one: float = 1

    def forward(self, *inputs_x_):
        """Run the encoder on the inputs and decode the resulting latents."""
        return self.decode(self.encoder(*inputs_x_))

    def decode(self, latent_x_z):
        """Convert encoder outputs into the model's final (log) predictions."""
        per_latent_output = self.get_decoder_prob_z_y(latent_x_z)
        return self.combine(latent_x_z, per_latent_output)

    def combine(self, latent_x_z, decoder_prob_z_y):
        """Return ``log(softmax(latent_x_z, dim=1) @ decoder_prob_z_y)``.

        Both arguments are passed through ``self.convert_tensor`` first.
        """
        latents = self.convert_tensor(latent_x_z)
        decoder_matrix = self.convert_tensor(decoder_prob_z_y)

        latent_probs = torch.softmax(latents, dim=1)
        mixture = torch.matmul(latent_probs, decoder_matrix)
        # The overall model is expected to return logits, hence the log.
        return mixture.log()

    def get_decoder_prob_z_y(self, latent_x_z):
        """Return the decoder matrix that ``combine`` multiplies the latents with.

        A ``CategoricalDecoder`` supplies it directly as the transpose of
        ``get_p_Y__Z()``. Any other decoder is instead evaluated on a
        ``z x z`` probe matrix, where ``z = latent_x_z.shape[1]``: the
        identity scaled by ``almost_one`` plus ``(1 - almost_one) / z``
        added to every entry.
        """
        if not isinstance(self.decoder, CategoricalDecoder):
            num_latents = latent_x_z.shape[1]
            like_latents = {"dtype": latent_x_z.dtype, "device": latent_x_z.device}

            on_diagonal = torch.eye(num_latents, **like_latents) * self.almost_one
            # 1-D vector; broadcasting adds the same offset to every row.
            everywhere = (
                torch.ones(num_latents, **like_latents) * (1.0 - self.almost_one) / num_latents
            )
            return self.decoder(on_diagonal + everywhere)

        return self.decoder.get_p_Y__Z().t()
