import torch
from torch.utils.data import DataLoader, TensorDataset

from evaluation.images.embedder import Embedder


class EmbedderMdvae(Embedder):
    """Embedder that generates one modality from the other with ``self.model``.

    Both methods forward ``self.model`` and ``self.device``, together with
    data taken from ``self.dataset``, to the module-level helpers of the same
    name. Any extra keyword arguments (e.g. ``k``, ``bs``) are passed through
    unchanged.
    """

    def get_img_from_cpt(self, **kwargs):
        """Generate images from ``self.dataset.x[1]``.

        NOTE(review): ``x[1]`` is presumably the caption-feature modality;
        confirm against the dataset class.
        """
        captions = self.dataset.x[1]
        return get_img_from_cpt(
            model=self.model, cpt_ft=captions, device=self.device, **kwargs
        )

    def get_cpt_from_img(self, **kwargs):
        """Generate caption features from ``self.dataset.get_image_tensors()``."""
        images = self.dataset.get_image_tensors()
        return get_cpt_from_img(
            model=self.model, img=images, device=self.device, **kwargs
        )


def get_img_from_cpt(model, cpt_ft, device, k=1, bs=1024):
    """ Generate images from caption features.

    Each caption feature is encoded with ``model.vaes['x2']`` to obtain ``k``
    samples of ``g``. A second latent ``z1`` is sampled from the prior
    ``model.pz1(*model.pz1_params)``, and ``[z1, g]`` (concatenated on the
    last axis) is decoded with ``model.vaes['x1']``.

    :param model: model exposing ``vaes['x1']``, ``vaes['x2']``, ``pz1`` and
        ``pz1_params``.
    :param cpt_ft: caption features. If 3-D, only index 0 of axis 1 (the
        first caption of each item) is used.
    :param device: device the batches are moved to before encoding.
    :param k: number of samples drawn per caption.
    :param bs: batch size used to iterate over ``cpt_ft``.
    :return: N x D or N x K x D (on CPU; assumes the decoder keeps the
        leading K axis — for image decoders the trailing dims may be more
        than one)
    """
    if cpt_ft.dim() == 3:
        cpt_ft = cpt_ft[:, 0, :]  # Select first caption

    loader = DataLoader(TensorDataset(cpt_ft), batch_size=bs)
    imgs = []
    # Inference only: without no_grad, any parameter that requires grad makes
    # each batch build an autograd graph that the stored outputs keep alive.
    with torch.no_grad():
        pz1 = model.pz1(*model.pz1_params)  # prior does not depend on the batch
        for (batch,) in loader:
            batch = batch.to(device)
            g = model.vaes['x2'].encode(batch, k=k)['g']['samples']  # K x N x D
            z1 = pz1.sample(list(g.size())[:-1]).squeeze(-2)  # K x N x D
            z = torch.cat((z1, g), dim=-1)
            img = model.vaes['x1'].decode(z)['samples']
            # Drop the singleton sample axis for k == 1, else put N first.
            img = img.squeeze(0) if k == 1 else img.transpose(0, 1)
            imgs.append(img.cpu())
    return torch.cat(imgs)


def get_cpt_from_img(model, img, device, k=1, bs=1024):
    """ Generate caption features from images.

    Each image is encoded with ``model.vaes['x1']`` to obtain ``k`` samples
    of ``g``. A second latent ``z2`` is sampled from the prior
    ``model.pz2(*model.pz2_params)``, and ``[g, z2]`` (concatenated on the
    last axis) is decoded with ``model.vaes['x2']``.

    :param model: model exposing ``vaes['x1']``, ``vaes['x2']``, ``pz2`` and
        ``pz2_params``.
    :param img: image tensor; axis 0 indexes the items.
    :param device: device the batches are moved to before encoding.
    :param k: number of samples drawn per image.
    :param bs: batch size used to iterate over ``img``.
    :return: N x D or N x K x D (on CPU; assumes the decoder keeps the
        leading K axis)
    """
    loader = DataLoader(TensorDataset(img), batch_size=bs)

    cpt_fts = []
    # Inference only: without no_grad, any parameter that requires grad makes
    # each batch build an autograd graph that the stored outputs keep alive.
    with torch.no_grad():
        pz2 = model.pz2(*model.pz2_params)  # prior does not depend on the batch
        for (batch,) in loader:
            batch = batch.to(device)
            g = model.vaes['x1'].encode(batch, k=k)['g']['samples']  # K x N x D
            z2 = pz2.sample(list(g.size())[:-1]).squeeze(-2)  # K x N x D
            z = torch.cat((g, z2), dim=-1)
            cpt_ft = model.vaes['x2'].decode(z)['samples']
            # Drop the singleton sample axis for k == 1 (mirroring
            # get_img_from_cpt) so batches concatenate into N x D rather
            # than being stacked along the sample axis.
            cpt_ft = cpt_ft.squeeze(0) if k == 1 else cpt_ft.transpose(0, 1)
            cpt_fts.append(cpt_ft.cpu())
    return torch.cat(cpt_fts)
