from modules.vae.transformvae import TransformVAE
from modules.vae import architectures, reconstruction_losses
from modules.latent_space import latentspace


class StandardTransformVAE(TransformVAE):
    """TransformVAE with one VGG-like encoder and one Gaussian latent space.

    The constructor arguments for ``TransformVAE`` are assembled in
    :meth:`set_kwargs_transformvae` from:

    * one encoder and a decoder built by
      ``architectures.encoder_decoder_vgglike_2d``,
    * a single ``latentspace.GaussianLatentSpace`` named ``"standard"``,
    * the reconstruction loss returned by
      ``reconstruction_losses.gaussian_loss()``.

    Args:
        input_shape: Sequence unpacked into exactly three values
            ``(height, width, depth)``, which are passed to the
            encoder/decoder builder.
        dim: Passed as the first positional argument to
            ``GaussianLatentSpace`` (presumably the latent dimensionality;
            confirm against ``latentspace``).
        stop_gradient: Forwarded unchanged to ``TransformVAE`` as the
            ``stop_gradient`` keyword.
        kl_weight: Passed to ``GaussianLatentSpace`` as ``kl_weight``.
        dist_weight: Passed to ``GaussianLatentSpace`` as ``dist_weight``.
            Defaults to ``1``.
    """

    def __init__(self, input_shape, dim, stop_gradient=False, kl_weight=1.0,
                 dist_weight=1):
        # These attributes must be set before super().__init__ runs:
        # set_kwargs_transformvae() reads them to build the parent's kwargs.
        self.dim = dim
        self.input_shape = input_shape
        self.stop_gradient = stop_gradient
        self.kl_weight = kl_weight
        self.dist_weight = dist_weight
        kwargs_transformvae = self.set_kwargs_transformvae()
        super().__init__(**kwargs_transformvae)

    def set_kwargs_transformvae(self):
        """Build the keyword arguments passed to ``TransformVAE.__init__``.

        Reads ``self.input_shape``, ``self.dim``, ``self.dist_weight``,
        ``self.kl_weight`` and ``self.stop_gradient``.

        Returns:
            dict with keys ``"encoders"`` (a one-element list),
            ``"decoder"``, ``"latent_spaces"`` (a one-element list),
            ``"reconstruction_loss"`` and ``"stop_gradient"``.
        """
        height, width, depth = self.input_shape
        encoder, decoder = architectures.encoder_decoder_vgglike_2d(height, width, depth)

        latent_space = latentspace.GaussianLatentSpace(
            self.dim,
            dist_weight=self.dist_weight,
            name="standard",
            kl_weight=self.kl_weight,
        )

        return {
            "encoders": [encoder],
            "decoder": decoder,
            "latent_spaces": [latent_space],
            "reconstruction_loss": reconstruction_losses.gaussian_loss(),
            "stop_gradient": self.stop_gradient,
        }
