import os

from stylegan2_pytooch import noise, image_noise, slerp

import torch
from torchvision.utils import save_image


def generate_random(trainer, n_rows):
    """Yield freshly sampled images, each paired with its index as a string.

    Latents and the accompanying noise input are obtained from
    ``trainer.sample_new_latents(n_rows)``. They are rendered through
    ``generate_from_latents`` with its default ``do_ema=True``, i.e. the
    ``GAN.GE`` / ``GAN.SE`` networks.

    Yields:
        ``(image, name)`` tuples where ``name`` is the position of the image
        in the generated output ("0", "1", ...). Iteration presumably walks
        dim 0 of a batch tensor (one image per sample); TODO confirm.
    """
    print('generating random')
    latents, noise_input = trainer.sample_new_latents(n_rows)
    batch = generate_from_latents(trainer, latents, noise_input)
    yield from ((image, str(idx)) for idx, image in enumerate(batch))

def interpolate(trainer, n_rows):
    """Yield images along a spherical interpolation between two latent batches.

    Two independent latent batches are drawn with ``noise`` and blended with
    ``slerp`` at weights 0.0, 0.05, ..., 1.0. ``torch.arange`` excludes its
    end, so the 1.05 bound makes 1.0 the last weight. A single image-noise
    tensor is drawn once and reused for every step, so only the latents
    vary across the sweep.

    Yields:
        ``(image, name)`` tuples where ``name`` is ``"{index}_{weight}"``,
        with the weight rounded to two decimals (e.g. ``"3_0.25"``).
    """
    print('interpolating')
    latent_dim = trainer.GAN.G.latent_dim
    start = noise(n_rows, latent_dim=latent_dim)
    end = noise(n_rows, latent_dim=latent_dim)
    shared_noise = image_noise(n_rows, trainer.image_size)
    num_layers = trainer.GAN.G.num_layers
    for weight in torch.arange(0, 1.05, 0.05):
        blended = slerp(weight, start, end)
        # Same latent for every layer of the generator.
        style = [(blended, num_layers)]
        batch = generate_from_latents(trainer, style, shared_noise)
        suffix = round(weight.item(), 2)
        for idx, image in enumerate(batch):
            yield image, f"{idx}_{suffix}"

def save_eval_images(folder, base_name, trainer, n_rows, interp=False):
    """Generate evaluation images and write each one to ``folder`` as a PNG.

    Args:
        folder: Output directory. It is created (with parents) if it does not
            exist yet. An empty string means the current directory.
        base_name: Prefix for every file; files are named
            ``{base_name}_{name}.png``, where ``name`` comes from the
            generator.
        trainer: Trainer object forwarded to the image generator.
        n_rows: Forwarded to ``interpolate`` / ``generate_random``.
        interp: If true, save an interpolation sweep (``interpolate``);
            otherwise save one batch of random samples (``generate_random``).
    """
    # save_image will not create the target directory for us; without this
    # the first write into a missing folder fails. os.makedirs('') raises,
    # so skip it when writing to the current directory.
    if folder:
        os.makedirs(folder, exist_ok=True)
    producer = interpolate if interp else generate_random
    for image, name in producer(trainer, n_rows):
        path = os.path.join(folder, f"{base_name}_{name}.png")
        save_image(image, path)

def generate_from_latents(trainer, style, n, do_ema=True):
    """Render images via ``trainer.generate_truncated`` using the trainer's psi.

    Args:
        trainer: Object exposing ``GAN`` (with ``G``/``S`` and ``GE``/``SE``
            networks), ``trunc_psi`` and ``generate_truncated``.
        style: Latent specification, passed through unchanged.
        n: Noise input, passed through unchanged.
        do_ema: When true (the default) use ``GAN.GE`` / ``GAN.SE``;
            otherwise use ``GAN.G`` / ``GAN.S``.

    Returns:
        Whatever ``trainer.generate_truncated`` returns.
    """
    gan = trainer.GAN
    generator, style_net = (gan.GE, gan.SE) if do_ema else (gan.G, gan.S)
    return trainer.generate_truncated(
        style_net, generator, style, n, trunc_psi=trainer.trunc_psi
    )
