import torch
from torch.utils.data import DataLoader, TensorDataset

from evaluation.images.embedder import Embedder


class EmbedderVasco(Embedder):
    """Embedder that generates cross-modal features with ``self.model``.

    Both methods are thin adapters. They pass ``self.model``, the relevant
    part of ``self.dataset`` and ``self.device`` to the module-level
    functions of the same name, and forward any keyword arguments
    (e.g. ``k``, ``bs``) unchanged.
    """

    def get_img_from_cpt(self, **kwargs):
        """Generate images from the dataset's caption features."""
        # NOTE(review): element 1 of ``dataset.x`` is presumably the caption
        # features — confirm against the dataset class.
        captions = self.dataset.x[1]
        # Method names are not in scope inside the body, so this call
        # resolves to the module-level function, not to this method.
        return get_img_from_cpt(self.model, captions, self.device, **kwargs)

    def get_cpt_from_img(self, **kwargs):
        """Generate caption features from the dataset's images."""
        images = self.dataset.get_image_tensors()
        # Resolves to the module-level function of the same name.
        return get_cpt_from_img(self.model, images, self.device, **kwargs)


def get_img_from_cpt(model, cpt_ft, device, k=1, bs=1024):
    """Generate images (x1) from caption features (x2).

    Each batch of caption features goes through four steps:

    1. Encode it with ``model.encoder['backbone_x2']``.
    2. Concatenate an all-zero placeholder for the image branch in front of it.
    3. Pass the result through ``model.encoder['encoder_g']``.
    4. Decode with ``model.decoder['g_to_z1']`` and then
       ``model.decoder['z1_to_x1']``.

    :param model: object whose ``encoder`` / ``decoder`` mappings provide the
        sub-modules named above.
    :param cpt_ft: caption features, N x D. If 3-D (N x C x D), only the
        first caption (index 0 along dim 1) of each sample is used.
    :param device: device the model's sub-modules run on.
    :param k: forwarded to ``encoder_g`` as ``k=k``.
    :param bs: batch size used to iterate over ``cpt_ft``.
    :return: CPU tensor, N x D if ``k == 1`` else N x K x D.
    """
    if cpt_ft.dim() == 3:
        cpt_ft = cpt_ft[:, 0, :]  # Select first caption

    loader = DataLoader(TensorDataset(cpt_ft), batch_size=bs)
    imgs = []
    # Inference only. Without no_grad, every CPU copy appended below keeps a
    # grad_fn referencing its batch's device-side graph, so all batches'
    # activations would stay alive until the result is freed.
    with torch.no_grad():
        for (batch,) in loader:
            batch = batch.to(device)
            h2 = model.encoder['backbone_x2'](batch)
            # The zero placeholder stands in for the missing image branch.
            # zeros_like matches h2's device and dtype and avoids a
            # host-to-device copy.
            h = torch.cat([torch.zeros_like(h2), h2], dim=-1)
            _, g = model.encoder['encoder_g'](h, k=k)
            _, z = model.decoder['g_to_z1'](g)
            _, img = model.decoder['z1_to_x1'](z)
            # The output is treated as K-first: drop K when k == 1,
            # otherwise move the batch dimension to the front.
            img = img.squeeze(0) if k == 1 else img.transpose(0, 1)
            imgs.append(img.cpu())

    return torch.cat(imgs)


def get_cpt_from_img(model, img, device, k=1, bs=1024):
    """Generate caption features (x2) from images (x1).

    Each batch of images goes through four steps:

    1. Encode it with ``model.encoder['backbone_x1']``.
    2. Concatenate an all-zero placeholder for the caption branch after it.
    3. Pass the result through ``model.encoder['encoder_g']``.
    4. Decode with ``model.decoder['g_to_z2']`` and then
       ``model.decoder['z2_to_x2']``.

    :param model: object whose ``encoder`` / ``decoder`` mappings provide the
        sub-modules named above.
    :param img: image tensor, batched along dim 0 (N first).
    :param device: device the model's sub-modules run on.
    :param k: forwarded to ``encoder_g`` as ``k=k``.
    :param bs: batch size used to iterate over ``img``.
    :return: CPU tensor, N x D if ``k == 1`` else N x K x D.
    """
    loader = DataLoader(TensorDataset(img), batch_size=bs)

    cpt_fts = []
    # Inference only. Without no_grad, every CPU copy appended below keeps a
    # grad_fn referencing its batch's device-side graph, so all batches'
    # activations would stay alive until the result is freed.
    with torch.no_grad():
        for (batch,) in loader:
            batch = batch.to(device)
            h1 = model.encoder['backbone_x1'](batch)
            # The zero placeholder stands in for the missing caption branch.
            # zeros_like matches h1's device and dtype and avoids a
            # host-to-device copy.
            h = torch.cat([h1, torch.zeros_like(h1)], dim=-1)
            _, g = model.encoder['encoder_g'](h, k=k)
            _, z = model.decoder['g_to_z2'](g)
            _, cpt_ft = model.decoder['z2_to_x2'](z)
            # The output is treated as K-first: drop K when k == 1,
            # otherwise move the batch dimension to the front.
            cpt_ft = cpt_ft.squeeze(0) if k == 1 else cpt_ft.transpose(0, 1)
            cpt_fts.append(cpt_ft.cpu())

    return torch.cat(cpt_fts)
