import torch

from evaluation.images.fid_score import FidImageGenerator


class FidImageGeneratorMhvae(FidImageGenerator):
    """Caption-to-image generation hook of ``FidImageGenerator`` for the "Mhvae" model.

    The method relies on ``self.model`` exposing ``encoder`` and ``decoder``
    mappings with the keys ``'backbone_x2'``, ``'encoder_g'``, ``'g_to_z1'``
    and ``'z1_to_x1'``, and on ``self.device``; both attributes are expected
    to be provided by the base class (not visible in this file).
    """

    def _generate_images_from_captions(self, cpt_ft):
        """Generate images from caption features only.

        The caption features are encoded with the ``'backbone_x2'`` encoder,
        padded on the left with an all-zero block of the same size, mapped to
        ``g`` by ``'encoder_g'`` and then decoded through ``'g_to_z1'`` and
        ``'z1_to_x1'``.

        Args:
            cpt_ft: Tensor of caption features; it is moved to ``self.device``
                before use. Expected layout is not visible here -- TODO
                confirm against the caller.

        Returns:
            The second element of the ``'z1_to_x1'`` decoder output, i.e. the
            generated images (N x 3 x 64 x 64 according to the original
            author's note).
        """
        cpt_ft = cpt_ft.to(self.device)
        h2 = self.model.encoder['backbone_x2'](cpt_ft)

        # Zero placeholder for the first half of the concatenated features
        # (presumably the slot of the absent image modality -- TODO confirm).
        # zeros_like allocates directly with h2's device and dtype, avoiding a
        # CPU allocation + transfer and any dtype mismatch in the cat below.
        h_zeros = torch.zeros_like(h2)
        h = torch.cat([h_zeros, h2], dim=-1)

        # Each of the following modules returns a pair; only the second
        # element is used (presumably a sample/output, the first being
        # distribution information -- TODO confirm).
        _, g = self.model.encoder['encoder_g'](h)

        _, z = self.model.decoder['g_to_z1'](g)
        _, img = self.model.decoder['z1_to_x1'](z)  # N x 3 x 64 x 64
        return img
