import logging
import os

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms

import utils
from data.flowers.main_raw import load_flowers_data
from data.neural_net.run import get_resnet_for_flower_dataset
from hyperparams.load import get_config

# Named logger shared under 'custom'; its handlers/levels are presumably
# configured elsewhere (e.g. by the entry point) — TODO confirm. Not referenced
# elsewhere in this chunk.
logger = logging.getLogger('custom')
# Project configuration, loaded once at import time. Used by
# ImageEmbedder._load_resnet_finetuned to resolve config.dirs['experiments'].
config = get_config()


class Embedder:
    """ Cross-modal generation and embedding helper for one flowers split.

    Holds a two-VAE ``model`` (the module-level helpers use ``model.vaes[0]``
    for images and ``model.vaes[1]`` for caption features) together with the
    dataset returned by ``load_flowers_data`` for ``split``.
    """

    def __init__(self, model, device, split):
        # load_flowers_data returns a pair; only the first element is kept.
        dataset, _unused = load_flowers_data(mode=split, average=True)
        self.model, self.device = model, device
        self.dataset = dataset

    def get_img_from_cpt(self, **kwargs):
        """ Generate images from this split's caption features.

        ``kwargs`` (k, bs, mode_layer) are forwarded to the module-level
        ``get_img_from_cpt``.
        """
        return get_img_from_cpt(self.model, self.dataset.x[1], self.device, **kwargs)

    def get_cpt_from_img(self, **kwargs):
        """ Generate caption features from this split's images.

        ``kwargs`` (k, bs, mode_layer) are forwarded to the module-level
        ``get_cpt_from_img``.
        """
        images = self.dataset.get_image_tensors()
        return get_cpt_from_img(self.model, images, self.device, **kwargs)

    def get_raw_captions(self):
        """ Return caption feature tensors.

        When the features are 3-D (several captions per item), only the first
        caption (index 0 along dim 1) is returned.
        """
        feats = self.dataset.x[1]
        return feats[:, 0, :] if feats.dim() == 3 else feats

    def get_data(self, **kwargs):
        """ Collect raw, cross-generated and ResNet-embedded data for the split.

        Keys: 'x1|x2' (images generated from captions), 'y' (from
        ``dataset.s['y']``), 'x2' (raw caption features), 'x2|x1' (caption
        features generated from images), 'x1_emb' / 'x1|x2_emb' (ResNet
        embeddings of real and generated images).
        """
        # Call order is kept fixed: generation samples from the posterior, so
        # reordering would change random-number consumption.
        generated_imgs = self.get_img_from_cpt(**kwargs)
        labels = self.dataset.s['y']
        raw_captions = self.get_raw_captions()
        generated_cpts = self.get_cpt_from_img(**kwargs)
        data = {
            'x1|x2': generated_imgs,
            'y': labels,
            'x2': raw_captions,
            'x2|x1': generated_cpts,
        }

        # A fine-tuned variant is also available via
        # ImageEmbedder('resnet_ft', self.device) (keys 'x1_emb_ft',
        # 'x1|x2_emb_ft'); it is currently disabled.
        embedder = ImageEmbedder('resnet', self.device)
        real_imgs = self.dataset.get_image_tensors()
        real_emb = embedder.run(real_imgs)
        generated_emb = embedder.run(generated_imgs)
        data['x1_emb'] = real_emb
        data['x1|x2_emb'] = generated_emb
        return data


@torch.no_grad()
def get_img_from_cpt(model, cpt_ft, device, k=1, bs=1024, mode_layer=None):
    """ Generate images from caption features.

    Each caption batch is encoded with ``model.vaes[1].bottom_up`` (drawing
    ``k`` posterior samples), and the samples are decoded into images by
    ``model.vaes[0].generate``.

    Runs under ``torch.no_grad()``. This is inference only. Without it, every
    batch's autograd graph (including device-side saved activations) would be
    kept alive by the accumulated outputs.

    :param model: object exposing ``vaes[0]`` (decodes images) and ``vaes[1]``
        (encodes caption features).
    :param cpt_ft: caption features batched along dim 0. If 3-D (several
        captions per item), only the first caption (index 0 of dim 1) is used.
    :param device: device each batch is moved to before encoding.
    :param k: number of posterior samples drawn per caption.
    :param bs: batch size.
    :param mode_layer: forwarded unchanged to ``generate``.
    :return: N x ... (k == 1) or N x K x ..., where ``...`` is the per-sample
        shape produced by ``generate``; on the CPU.
    """
    if cpt_ft.dim() == 3:
        cpt_ft = cpt_ft[:, 0, :]  # Select first caption

    loader = DataLoader(TensorDataset(cpt_ft), batch_size=bs)
    imgs = []
    for (batch,) in loader:
        _, posterior = model.vaes[1].bottom_up(batch.to(device), k)
        ancestral_samples = model.vaes[0].generate(posterior['samples'], mode_layer=mode_layer)
        img = ancestral_samples[0]['samples']
        # Dim 0 of the samples is treated as the sample axis: drop it for
        # k == 1, otherwise move it behind the batch axis.
        img = img.squeeze(0) if k == 1 else img.transpose(0, 1)
        imgs.append(img.cpu())
    return torch.cat(imgs)


@torch.no_grad()
def get_cpt_from_img(model, img, device, k=1, bs=1024, mode_layer=None):
    """ Generate caption features from images.

    Each image batch is encoded with ``model.vaes[0].bottom_up`` (drawing
    ``k`` posterior samples), and the samples are decoded into caption
    features by ``model.vaes[1].generate``.

    Runs under ``torch.no_grad()``. This is inference only. Without it, every
    batch's autograd graph (including device-side saved activations) would be
    kept alive by the accumulated outputs.

    :param model: object exposing ``vaes[0]`` (encodes images) and ``vaes[1]``
        (decodes caption features).
    :param img: image tensor batched along dim 0.
    :param device: device each batch is moved to before encoding.
    :param k: number of posterior samples drawn per image.
    :param bs: batch size.
    :param mode_layer: forwarded unchanged to ``generate``.
    :return: N x D (k == 1) or N x K x D, on the CPU.
    """
    loader = DataLoader(TensorDataset(img), batch_size=bs)

    cpt_fts = []
    # Distinct loop name so the ``img`` argument is not shadowed.
    for (batch,) in loader:
        _, posterior = model.vaes[0].bottom_up(batch.to(device), k)
        ancestral_samples = model.vaes[1].generate(posterior['samples'], mode_layer=mode_layer)
        cpt_ft = ancestral_samples[0]['samples']
        # Dim 0 of the samples is treated as the sample axis: drop it for
        # k == 1, otherwise move it behind the batch axis.
        cpt_ft = cpt_ft.squeeze(0) if k == 1 else cpt_ft.transpose(0, 1)
        cpt_fts.append(cpt_ft.cpu())
    return torch.cat(cpt_fts)


class ImageEmbedder:
    """ Embed image tensors with a ResNet feature extractor.

    The network's last child module is dropped, so ``run`` returns the
    flattened output of the preceding layer for every image.
    """
    # ImageNet channel statistics, as in the PyTorch ImageNet example:
    # https://github.com/pytorch/examples/blob/97304e232807082c2e7b54c597615dc0ad8f6173/imagenet/main.py#L197-L198
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    def __init__(self, name, device):
        """
        :param name: 'resnet' (torchvision ResNet-101 with pretrained weights)
            or 'resnet_ft' (flowers model from get_resnet_for_flower_dataset,
            restored from a checkpoint under config.dirs['experiments']).
        :param device: device the model lives on and inputs are sent to.
        :raises ValueError: if ``name`` is not one of the above.
        """
        self.device = device
        loaders = {
            'resnet': self._load_resnet,
            'resnet_ft': self._load_resnet_finetuned,  # Fine-tuned resnet
        }
        if name not in loaders:
            # Previously an unknown name left ``self.model`` unset and only
            # failed later with an AttributeError.
            raise ValueError(
                f"Unknown image embedder {name!r}; expected one of {sorted(loaders)}"
            )
        self.model = loaders[name]()

    @torch.no_grad()
    def run(self, images: torch.Tensor, bs=1024):
        """ Embed ``images`` in mini-batches of size ``bs``.

        :param images: B x 3 x H x W, or B x K x 3 x H x W (K images per item).
        :return: B x F, or B x K x F, on the CPU.
        """
        loader = DataLoader(TensorDataset(images), batch_size=bs)
        return torch.cat([self._forward(batch) for (batch,) in loader])

    def _forward(self, x):
        """ Embed one mini-batch and return the features on the CPU. """
        x = self.normalize(x.to(self.device))
        lead_shape = None
        if x.dim() == 5:
            # Fold the per-item axis into the batch axis for the CNN.
            lead_shape = x.shape[:2]
            x = x.reshape(-1, *x.shape[-3:])
        # flatten(1), not squeeze(): squeeze() also dropped the batch axis
        # whenever a mini-batch held a single image, so torch.cat in ``run``
        # failed on mismatched ranks.
        x = torch.flatten(self.model(x), start_dim=1)
        if lead_shape is not None:
            x = x.view(*lead_shape, -1)
        return x.cpu()

    def _load_resnet_finetuned(self):
        """ Load the flowers ResNet checkpoint and strip its last child module. """
        model = get_resnet_for_flower_dataset()
        ckpt_path = os.path.join(
            config.dirs['experiments'], 'resnet/flower/runs',
            '2021-11-07_T_17-52-56.718994/model_epoch_50.pt'
        )
        ckpt = utils.torch_load(ckpt_path)
        model.load_state_dict(ckpt['state_dict'])
        model = nn.Sequential(*list(model.children())[:-1])
        return model.to(self.device).eval()

    def _load_resnet(self):
        """ Load pretrained torchvision ResNet-101 without its last child module. """
        # NOTE: newer torchvision deprecates ``pretrained=`` in favour of
        # ``weights=``; kept as-is for compatibility with older versions.
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet = nn.Sequential(*list(resnet.children())[:-1])
        return resnet.to(self.device).eval()
