from modules.vae.transformvae import TransformVAE
from modules.vae import architectures, reconstruction_losses
from modules.latent_space import latentspace2


class HypercylinderTransformVAE(TransformVAE):
    """TransformVAE with a hypercylindrical latent space (circle x Gaussian).

    The model is assembled from two latent spaces: a hyperspherical latent
    space named ``"circle"`` and a Gaussian latent space of dimension
    ``latent_dim``. A single decoder is used. Either one encoder object is
    shared by every latent space, or (``separate_encoders=True``) each latent
    space gets its own independently built encoder.

    Args:
        input_shape: Input shape, unpacked elsewhere as ``(height, width,
            depth)``; passed to ``architectures.get_encoder_decoder``.
        separate_encoders: If True, build one encoder per latent space;
            otherwise reuse the same encoder object for all latent spaces.
        dist_weight: Passed as ``dist_weight`` to the circular latent space.
        stop_gradient: Forwarded to ``TransformVAE`` as ``stop_gradient``.
        latent_dim: Dimension of the Gaussian latent space.
        log_t_limit: Passed as ``log_t_limit`` to the circular latent space.
        kl_weight: Passed as ``kl_weight`` to the Gaussian latent space.
        architecture: Architecture name (e.g. ``"vgg"``, ``"dis_lib"``,
            ``"dense"``); stored on the instance as ``self.architectures``.
    """

    # Number of latent spaces built in ``set_kwargs_transformvae``
    # (circle + Gaussian). Used to build exactly one encoder per latent space.
    _NUM_LATENT_SPACES = 2

    def __init__(self, input_shape, separate_encoders, dist_weight, stop_gradient, latent_dim,
                 log_t_limit=(-10.0, -5.0), kl_weight=1.0, architecture="vgg"):
        # These attributes must be set before ``super().__init__`` because
        # ``set_kwargs_transformvae`` reads them to build the base-class kwargs.
        self.log_t_limit = log_t_limit
        self.input_shape = input_shape
        self.stop_gradient = stop_gradient
        self.separate_encoders = separate_encoders
        self.dist_weight = dist_weight
        self.architectures = architecture
        self.kl_weight = kl_weight
        self.latent_dim = latent_dim
        kwargs_transformvae = self.set_kwargs_transformvae()
        super().__init__(**kwargs_transformvae)

    @property
    def architectures_function(self):
        """Return the encoder/decoder factory for ``self.architectures``.

        Returns:
            One of ``architectures.encoder_decoder_vgglike_2d`` (``"vgg"``),
            ``architectures.encoder_decoder_dislib_2d`` (``"dis_lib"``) or
            ``architectures.encoder_decoder_dense`` (``"dense"``); ``None`` for
            any other architecture name.
        """
        if self.architectures == "vgg":
            return architectures.encoder_decoder_vgglike_2d
        if self.architectures == "dis_lib":
            return architectures.encoder_decoder_dislib_2d
        if self.architectures == "dense":
            return architectures.encoder_decoder_dense
        return None

    def _build_encoder_decoder(self):
        """Build one (encoder, decoder) pair for the configured architecture."""
        return architectures.get_encoder_decoder(self.architectures, self.input_shape)

    def set_kwargs_transformvae(self):
        """Build the keyword arguments passed to ``TransformVAE.__init__``.

        Returns:
            dict with keys ``"encoders"`` (one entry per latent space),
            ``"decoder"``, ``"latent_spaces"``, ``"reconstruction_loss"`` and
            ``"stop_gradient"``.
        """
        encoder, decoder = self._build_encoder_decoder()

        if self.separate_encoders:
            # Build every additional encoder through the same factory as the
            # first, so all encoders share one architecture and code path.
            # Exactly one encoder per latent space (previously one too many).
            encoders = [encoder] + [self._build_encoder_decoder()[0]
                                    for _ in range(self._NUM_LATENT_SPACES - 1)]
        else:
            # Same encoder object reused for every latent space.
            encoders = [encoder] * self._NUM_LATENT_SPACES

        latent_spaces = [latentspace2.HyperSphericalLatentSpace(1, dist_weight=self.dist_weight,
                                                                name="circle",
                                                                log_t_limit=self.log_t_limit),
                         latentspace2.GaussianLatentSpace(dim=self.latent_dim,
                                                          kl_weight=self.kl_weight)]

        reconstruction_loss = reconstruction_losses.gaussian_loss()

        dictionary = {
            "encoders": encoders,
            "decoder": decoder,
            "latent_spaces": latent_spaces,
            "reconstruction_loss": reconstruction_loss,
            "stop_gradient": self.stop_gradient,
        }
        return dictionary
