import argparse
import math
import os
import torch
from torchvision import utils

from basicsr.archs.stylegan2_arch import StyleGAN2Generator
from basicsr.utils import set_random_seed


def generate(args, g_ema, device, mean_latent, randomize_noise, save_dir='samples'):
    """Sample images from a StyleGAN2 generator and save them as PNG grids.

    For each of ``args.pics`` iterations, draw ``args.sample`` latent codes
    from a standard normal distribution and run them through ``g_ema``.
    Write the resulting batch as one image grid to
    ``<save_dir>/<iteration index zero-padded to 6 digits>.png``.

    Args:
        args (argparse.Namespace): Reads ``pics`` (number of grids to save),
            ``sample`` (images per grid), ``latent`` (latent code size) and
            ``truncation`` (forwarded to the generator).
        g_ema (nn.Module): Generator, called as
            ``g_ema([z], truncation=..., randomize_noise=..., truncation_latent=...)``.
            It must return a 2-tuple whose first item is the image batch.
        device (torch.device): Device on which the latent codes are created.
        mean_latent (Tensor | None): Forwarded as ``truncation_latent``;
            ``None`` when no truncation is applied.
        randomize_noise (bool): Forwarded to the generator.
        save_dir (str): Directory for the PNG files; created if missing.
            Defaults to ``'samples'``.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Images per grid row: floor(sqrt(sample)), i.e. a square grid whenever
    # args.sample is a perfect square.
    nrow = int(math.sqrt(args.sample))

    with torch.no_grad():
        g_ema.eval()
        for i in range(args.pics):
            sample_z = torch.randn(args.sample, args.latent, device=device)

            sample, _ = g_ema([sample_z],
                              truncation=args.truncation,
                              randomize_noise=randomize_noise,
                              truncation_latent=mean_latent)

            utils.save_image(
                sample,
                os.path.join(save_dir, f'{i:06d}.png'),
                nrow=nrow,
                normalize=True,
                # Rescale [-1, 1] to [0, 1] before writing (presumably the
                # generator's output range). torchvision renamed this keyword
                # from `range` to `value_range`; the old name is no longer accepted.
                value_range=(-1, 1),
            )


def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` converts every non-empty string (including
    ``"False"``) to ``True``, so boolean options need an explicit parser.

    Args:
        value (str | bool): Raw option value, e.g. ``"true"``, ``"0"``, ``"no"``.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()

    parser.add_argument('--size', type=int, default=1024, help='Generator output size.')
    parser.add_argument('--sample', type=int, default=1, help='Number of images per saved grid.')
    parser.add_argument('--pics', type=int, default=1, help='Number of grids to save.')
    parser.add_argument('--truncation', type=float, default=1, help='Truncation factor; values < 1 enable truncation.')
    parser.add_argument(
        '--truncation_mean',
        type=int,
        default=4096,
        help='Value passed to mean_latent() when truncation is enabled.')
    parser.add_argument(
        '--ckpt',
        type=str,
        default='experiments/pretrained_models/StyleGAN/stylegan2_ffhq_config_f_1024_official-3ab41b38.pth',  # noqa: E501
        help="Checkpoint path; must contain a 'params_ema' state dict.")
    parser.add_argument('--channel_multiplier', type=int, default=2)
    # type=bool would treat any non-empty string (even "False") as True.
    parser.add_argument('--randomize_noise', type=str2bool, default=True)

    args = parser.parse_args()

    # Fixed architecture hyper-parameters; they must match the checkpoint
    # (presumably the official config-f layout -- confirm for other checkpoints).
    args.latent = 512
    args.n_mlp = 8
    os.makedirs('samples', exist_ok=True)
    set_random_seed(2020)

    g_ema = StyleGAN2Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device)
    # map_location lets a checkpoint saved with CUDA tensors load on a
    # CPU-only machine, which the device fallback above explicitly supports.
    checkpoint = torch.load(args.ckpt, map_location=device)['params_ema']

    g_ema.load_state_dict(checkpoint)

    # The mean latent is only needed as the truncation centre.
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent, args.randomize_noise)
