import torch

from evaluation.images.fid_score import FidImageGenerator


class FidImageGeneratorMdvae(FidImageGenerator):
    """FID image generator for the MDVAE model (caption features -> images).

    Caption features are encoded by the ``'x2'`` VAE. Its ``g`` samples are
    paired with a fresh draw from the model's ``pz1`` prior, and the
    concatenated latent is decoded by the ``'x1'`` VAE.
    """

    def _generate_images_from_captions(self, cpt_ft):
        """Decode images conditioned on a batch of caption features.

        Args:
            cpt_ft: caption-feature tensor. It is moved to ``self.device``
                before encoding, and its first dimension is used as the
                number of prior samples to draw.

        Returns:
            The ``'samples'`` entry of ``self.model.vaes['x1'].decode(...)``.
        """
        captions = cpt_ft.to(self.device)

        # Caption-side latent from the 'x2' encoder. The original comment
        # documents it as N x K x D. NOTE(review): torch.cat below requires it
        # to have the same rank as the prior sample; confirm the actual shape.
        g_latent = self.model.vaes['x2'].encode(captions)['g']['samples']

        # The prior is rebuilt from its stored parameters on every call.
        # One sample is drawn per batch element. squeeze(-2) drops a singleton
        # axis (presumably a leading 1 in pz1_params; verify). The result is
        # then moved onto the same device as the caption latent.
        prior = self.model.pz1(*self.model.pz1_params)
        batch_size = captions.size(0)
        z1_latent = prior.sample((batch_size,)).squeeze(-2).to(g_latent.device)

        # Joint latent is ordered (z1, g) along the last axis before decoding.
        joint_latent = torch.cat((z1_latent, g_latent), dim=-1)
        return self.model.vaes['x1'].decode(joint_latent)['samples']
