"""Multimodal Distribution Translator"""
import torch

from utils.names import ReconstructionLoss, TrainingStrategy
from utils.metrics.clip import calculate_clip_loss


class ModalityTranslationBridge(torch.nn.Module):
    def __init__(self,
                 bridge_model,
                 encoder_x,
                 encoder_y,
                 decoder_x,
                 rec_loss_type,
                 clip_loss_w,
                 training_strategy,
                 decoder_y=None):
        super().__init__()
        self.bridge_model = bridge_model
        self.encoder_x = encoder_x
        self.encoder_y = encoder_y
        self.decoder_x = decoder_x
        self.decoder_y = decoder_y
        self.rec_loss_type = rec_loss_type
        self.clip_loss_w = clip_loss_w
        self.training_strategy = training_strategy

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        z_x = self.encoder_x(x)
        z_y = self.encoder_y(y)

        z_x_bridge = self.bridge_model(z_x, z_y)

        x_hat = self.decoder_x(z_x)
        x_bridge_hat = self.decoder_x(z_x_bridge)
        y_hat = self.decoder_y(z_y) if self.decoder_y is not None else None

        return {'z_x': z_x, 'z_y': z_y, 'z_x_bridge': z_x_bridge, 'x_hat': x_hat, 'x_bridge_hat': x_bridge_hat, 'y': y,
                'y_hat': y_hat}

    def sample(self, y: torch.Tensor, sampling_steps=40):
        z_y = self.encoder_y(y)
        z_x_hat = self.bridge_model.sample(z_y, sampling_steps)
        x_hat = self.decoder_x(z_x_hat)

        return x_hat

    def loss(self, loss_comp, step):

        # --- bridge loss --- #
        losses = {'diffusion_loss': ((loss_comp['z_x'] - loss_comp['z_x_bridge']) ** 2).mean()}

        if self.iterative_training_coordinator(step):
            # --- reconstruction loss --- #
            if self.rec_loss_type == ReconstructionLoss.Reconstruction.value:
                losses['reconstruction_loss'] = ((loss_comp['x_hat'] - loss_comp['x']) ** 2).mean()

            elif self.rec_loss_type == ReconstructionLoss.Predictive.value:
                losses['reconstruction_loss'] = ((loss_comp['x_bridge_hat'] - loss_comp['x']) ** 2).mean()

            # --- clip loss --- #
            if self.clip_loss_w > 0:
                losses['clip_loss'] = self.clip_loss_w * calculate_clip_loss(loss_comp['z_x'], loss_comp['z_y'])

        # --- final loss --- #
        total_loss = 0
        for v in losses.values():
            total_loss = total_loss + v

        losses['total_loss'] = total_loss

        return total_loss, losses

    def iterative_training_coordinator(self, step):
        if self.training_strategy == TrainingStrategy.WholeSystemTraining.value:
            return True # if its regular training it always should enter
        elif self.training_strategy == TrainingStrategy.IterativeTraining.value:
            return step % 4 == 0
        else:
            raise NotImplementedError('No such training strategy')