from modules.vae.transformvae import TransformVAE
from modules.vae import architectures, reconstruction_losses
from modules.latent_space import latentspace


class HypertorusTransformVAE(TransformVAE):
    """TransformVAE whose latent space is a product of ``num_circles`` circles.

    Each circle is a ``latentspace.HyperSphericalLatentSpace`` of dimension 1,
    and each circle gets its own entry in the ``encoders`` list passed to
    ``TransformVAE``. A single decoder and a Gaussian reconstruction loss are
    shared by all circles.
    """

    def __init__(self, input_shape, num_circles, separate_encoders, dist_weight, stop_gradient,
                 log_t_limit=(-10.0, -5.0), architectures="vgg"):
        """Store the configuration and build the parent ``TransformVAE``.

        Args:
            input_shape: Input shape; unpacked as ``(height, width, depth)``
                when the networks are built.
            num_circles: Number of circle latent spaces; must be at least 1.
            separate_encoders: If True, build an independent encoder per
                circle; otherwise the same encoder object is reused for every
                circle.
            dist_weight: Forwarded to every ``HyperSphericalLatentSpace``.
            stop_gradient: Forwarded to ``TransformVAE``.
            log_t_limit: Forwarded to every ``HyperSphericalLatentSpace``.
            architectures: Architecture key: "vgg", "dis_lib" or "dense".

        Raises:
            ValueError: If ``architectures`` is not a recognized key or
                ``num_circles`` is less than 1.
        """
        # NOTE: the ``architectures`` parameter shadows the imported
        # ``architectures`` module inside this method only; the module is not
        # needed here (it is used by ``architectures_function``).
        self.num_circles = num_circles
        self.log_t_limit = log_t_limit
        self.input_shape = input_shape
        self.stop_gradient = stop_gradient
        self.separate_encoders = separate_encoders
        self.dist_weight = dist_weight
        self.architectures = architectures
        # The attributes above must be set before this call, because building
        # the kwargs reads them.
        kwargs_transformvae = self.set_kwargs_transformvae()
        super().__init__(**kwargs_transformvae)

    @property
    def architectures_function(self):
        """Return the encoder/decoder builder selected by ``self.architectures``.

        The builder is called as ``builder(height, width, depth)`` and returns
        an ``(encoder, decoder)`` pair.

        Returns:
            The builder function, or None if ``self.architectures`` is not one
            of "vgg", "dis_lib" or "dense".
        """
        # Attributes are looked up lazily so that only the selected builder
        # has to exist in the ``architectures`` module.
        if self.architectures == "vgg":
            return architectures.encoder_decoder_vgglike_2d
        if self.architectures == "dis_lib":
            return architectures.encoder_decoder_dislib_2d
        if self.architectures == "dense":
            return architectures.encoder_decoder_dense
        return None

    def set_kwargs_transformvae(self):
        """Build the keyword arguments for ``TransformVAE.__init__``.

        Returns:
            dict with keys "encoders" (one per circle), "decoder",
            "latent_spaces" (one per circle), "reconstruction_loss" and
            "stop_gradient".

        Raises:
            ValueError: If ``self.architectures`` is not a recognized key or
                ``self.num_circles`` is less than 1.
        """
        build_networks = self.architectures_function
        if build_networks is None:
            # Fail with a clear message instead of "'NoneType' object is not
            # callable" further down.
            raise ValueError(
                "Unknown architecture {!r}; expected one of 'vgg', 'dis_lib', "
                "'dense'.".format(self.architectures))
        if self.num_circles < 1:
            # With zero circles the encoder and latent-space lists would
            # disagree in length (the separate-encoder path always builds one).
            raise ValueError(
                "num_circles must be at least 1, got {!r}.".format(self.num_circles))

        height, width, depth = self.input_shape

        encoder, decoder = build_networks(height, width, depth)

        if self.separate_encoders:
            # Each extra call also builds a decoder, which is discarded: only
            # the first decoder is used.
            encoders = [encoder] + [build_networks(height, width, depth)[0]
                                    for _ in range(self.num_circles - 1)]
        else:
            # The same encoder object is reused for every circle.
            encoders = [encoder] * self.num_circles

        latent_spaces = [latentspace.HyperSphericalLatentSpace(1, dist_weight=self.dist_weight,
                                                               name="circle" + str(circle),
                                                               log_t_limit=self.log_t_limit)
                         for circle in range(self.num_circles)]

        reconstruction_loss = reconstruction_losses.gaussian_loss()

        return {
            "encoders": encoders,
            "decoder": decoder,
            "latent_spaces": latent_spaces,
            "reconstruction_loss": reconstruction_loss,
            "stop_gradient": self.stop_gradient,
        }
