from torch import nn

from mhvae_vasco.model import MHVAE
from mhvae_vasco.model.images.decoder import LowerDecoderX1, UpperDecoderX1, LowerDecoderX2, UpperDecoderX2
from mhvae_vasco.model.images.encoder import BackboneX1, BackboneX2, EncoderG, EncoderZ


class MhvaeImages(MHVAE):
    """MHVAE variant assembled from the ``mhvae_vasco.model.images`` sub-networks.

    Both factory methods size their latent layers from ``args.stoc_dim``,
    which is indexed with the keys ``'g'``, ``'z1'`` and ``'z2'``.
    """

    @staticmethod
    def build_encoder(args):
        """Assemble the encoder sub-networks into an ``nn.ModuleDict``.

        Args:
            args: configuration object exposing ``stoc_dim``, a mapping that is
                indexed by ``'g'``, ``'z1'`` and ``'z2'`` to get latent sizes.

        Returns:
            ``nn.ModuleDict`` whose entries, in insertion order, are
            ``backbone_x1``, ``backbone_x2``, ``encoder_g``, ``encoder_z1``
            and ``encoder_z2``.
        """
        dims = args.stoc_dim
        # Modules are constructed in key order. Keep that order if you edit
        # this list, in case the sub-modules draw random initial weights
        # (their constructors are not visible here).
        named_encoders = [
            ('backbone_x1', BackboneX1()),
            ('backbone_x2', BackboneX2()),
            ('encoder_g', EncoderG(latent_size=dims['g'])),
            # The same EncoderZ class serves both z1 and z2; only the
            # latent size differs.
            ('encoder_z1', EncoderZ(latent_size=dims['z1'])),
            ('encoder_z2', EncoderZ(latent_size=dims['z2'])),
        ]
        return nn.ModuleDict(named_encoders)

    @staticmethod
    def build_decoder(args):
        """Assemble the decoder sub-networks into an ``nn.ModuleDict``.

        Each of the two branches (suffix ``1`` and ``2``) has a lower decoder
        that takes ``z``-sized input and an upper decoder parameterised by
        both the ``g`` and ``z`` sizes. Per the key names, these are the
        ``z -> x`` and ``g -> z`` mappings.

        Args:
            args: configuration object exposing ``stoc_dim``, a mapping that is
                indexed by ``'g'``, ``'z1'`` and ``'z2'`` to get latent sizes.

        Returns:
            ``nn.ModuleDict`` whose entries, in insertion order, are
            ``z1_to_x1``, ``g_to_z1``, ``z2_to_x2`` and ``g_to_z2``.
        """
        dims = args.stoc_dim
        named_decoders = [
            ('z1_to_x1', LowerDecoderX1(latent_size=dims['z1'])),
            ('g_to_z1', UpperDecoderX1(g_size=dims['g'], z_size=dims['z1'])),
            ('z2_to_x2', LowerDecoderX2(latent_size=dims['z2'])),
            ('g_to_z2', UpperDecoderX2(g_size=dims['g'], z_size=dims['z2'])),
        ]
        return nn.ModuleDict(named_decoders)
