import logging
import os
import torch
from pathlib import Path
from torchvision.utils import save_image

from pytorch_fid.fid_score import calculate_fid_given_paths
from utils import chunks

# Module logger. It uses the fixed name 'custom' rather than __name__,
# presumably so it shares handlers/level configured elsewhere — verify.
logger = logging.getLogger('custom')


class FidImageGenerator:
    """Generate images from caption features and save them to disk for FID.

    For every dataset item, the caption features (``dataset.x[1]``) are passed
    through ``model.vaes[1].bottom_up`` to obtain a latent ``g``. ``g`` is then
    decoded with ``model.vaes[0].generate`` into an image. Each image is written
    to ``save_dir`` under the basename of the matching entry of
    ``dataset.x[0]``.
    """

    def __init__(self, model, dataset, device, save_dir, batch_size=256):
        """
        Args:
            model: object exposing ``vaes``. ``vaes[1]`` must provide
                ``bottom_up`` and ``vaes[0]`` must provide ``generate``.
            dataset: object exposing ``x`` (``x[0]``: source image paths,
                ``x[1]``: caption features) and ``s['y']``, a tensor whose
                first dimension is the number of items.
            device: device the caption features are moved to before encoding.
            save_dir: directory the generated images are written into.
            batch_size: number of items generated per call to the model.
        """
        self.model = model
        self.dataset = dataset
        self.device = device
        self.save_dir = save_dir
        self.batch_size = batch_size

    def run(self, **kwargs):
        """Generate and save one image per dataset item, batch by batch.

        Keyword arguments (e.g. ``mode_layer``) are forwarded to
        ``_generate_images_from_captions``.
        """
        n_items = self.dataset.s['y'].size(0)
        for batch_idx in chunks(list(range(n_items)), self.batch_size):
            cpt_fts = self.dataset.x[1][batch_idx]
            gen_imgs = self._generate_images_from_captions(cpt_fts, **kwargs)
            ys = self.dataset.s['y'][batch_idx]
            paths = self.dataset.x[0][batch_idx]
            self._save_images(gen_imgs, ys, paths)

    def _generate_images_from_captions(self, cpt_ft, mode_layer=None):
        """Map a batch of caption features to generated images.

        Runs under ``torch.no_grad()``. This path only produces samples, so
        no autograd graph needs to be kept alive.

        Args:
            cpt_ft: caption features for a batch, indexed along dim 0.
            mode_layer: forwarded unchanged to ``vaes[0].generate``.

        Returns:
            ``generate(...)[0]['samples']`` from the image VAE.
        """
        with torch.no_grad():
            # x2 -> g
            _, posterior = self.model.vaes[1].bottom_up(cpt_ft.to(self.device))
            # Expected shape after squeeze: captions x D.
            g = posterior['samples'].squeeze()
            if g.dim() == 1 and cpt_ft.size(0) == 1:
                # squeeze() also drops the batch dim when the batch holds a
                # single caption (e.g. the last chunk). Restore it so g stays
                # (captions, D) like every other batch.
                g = g.unsqueeze(0)

            # g -> x1
            ancestral_samples = self.model.vaes[0].generate(g, mode_layer=mode_layer)
            return ancestral_samples[0]['samples']

    def _save_images(self, gen_imgs, ys, paths):
        """Write each generated image to ``save_dir`` under its source basename.

        ``ys`` is accepted for interface compatibility but is not used.
        """
        os.makedirs(self.save_dir, exist_ok=True)
        for gen_img, path in zip(gen_imgs, paths):
            # NOTE(review): only the basename is kept, so sources sharing a
            # filename in different directories overwrite each other.
            save_image(gen_img, os.path.join(self.save_dir, Path(path).name))


def calc_fid(dirs, device, batch_size=256, dims=2048):
    """Compute and log the FID between the image sets ``dirs['p']`` and ``dirs['q']``.

    Args:
        dirs: mapping with keys ``'p'`` and ``'q'``. Both values are passed as
            ``paths`` to ``pytorch_fid.fid_score.calculate_fid_given_paths``.
        device: forwarded to ``calculate_fid_given_paths``.
        batch_size: forwarded to ``calculate_fid_given_paths``.
        dims: feature dimensionality forwarded to ``calculate_fid_given_paths``.
            2048 is the pytorch-fid default, see
            https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py#L62

    Returns:
        The FID value returned by ``calculate_fid_given_paths``.
    """
    fid = calculate_fid_given_paths(
        paths=[dirs['p'], dirs['q']],
        batch_size=batch_size,
        device=device,
        dims=dims,
    )
    logger.info('FID: %.1f', fid)
    return fid
