import sys

sys.path.append('./models/stylegan2')

import argparse

import dnnlib.tflib as tflib
import numpy as np

from decomposition import apply_pca_grid_fig7, load_network, pca
from utils import out_classes

def _str2bool(value):
    """Convert a command-line token into a bool for ``argparse``.

    ``type=bool`` is a trap: argparse calls ``bool(token)``, and every
    non-empty string (including ``"False"`` and ``"0"``) is truthy, so
    ``--use_w False`` silently meant True. This converter accepts the common
    spellings of true/false instead.

    Args:
        value: the raw token from the command line (or an already-bool value,
            e.g. a ``const`` passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling; argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


# Command-line interface of the script; parsed once at import/run time.
parser = argparse.ArgumentParser(description='Run stylegan1/2 (Default is 2)')
parser.add_argument('-s', '--stylegan_version', type=int, default=2,
                    help='StyleGAN version')
parser.add_argument('--class', dest='output_class', type=str, default=None,
                    help='Output class to generate (BigGAN: Imagenet, ProGAN: LSUN)')
parser.add_argument('--est', dest='estimator', type=str, default='ipca',
                    help='The algorithm to use [ipca, pca, fbpca, cupca, spca, ica] (default: ipca)')
parser.add_argument('-b', dest='batch_size', type=int, default=None,
                    help='Minibatch size, leave empty for automatic detection')
parser.add_argument('-c', dest='components', type=int, default=80,
                    help='Number of components to keep')
parser.add_argument('-n', '--samples', type=int, default=10_000,
                    help='Number of examples to use in decomposition')
# Boolean options accept an explicit value (e.g. ``--use_w false``) or may be
# given bare (``--use_w``), in which case ``const`` sets them to True.
parser.add_argument('--use_w', type=_str2bool, nargs='?', const=True, default=True,
                    help='Use W latent space (StyleGAN(2))')
# NOTE(review): --seed_compute is parsed but not read anywhere in this script;
# confirm whether it should be forwarded to the decomposition step.
parser.add_argument('--seed_compute', type=int, default=None,
                    help='Seed used in decomposition')
parser.add_argument('--seed', type=int, default=None,
                    help='Seed used to generate the images')
parser.add_argument('-t', '--truncation_psi', type=float, default=0.7,
                    help='Truncation-psi value used in the Truncation trick')
parser.add_argument('-f', '--force_recompute', type=_str2bool, nargs='?', const=True, default=False,
                    help='Boolean to force re-computation of the PCA components')
args = parser.parse_args()

# Initialise TensorFlow through dnnlib's tflib before any network is loaded.
tflib.init_tf()

# Unpack CLI options. The pretrained model is selected by
# (stylegan_version, output_class), e.g. StyleGAN2 "cars" or StyleGAN1 "wikiart".
stylegan_version = args.stylegan_version
output_class = args.output_class
# NOTE(review): `network` is not used below; presumably this lookup mostly serves
# to fail early on an unsupported (version, class) pair -- confirm intent.
network = out_classes[stylegan_version][output_class]
estimator = args.estimator
batch_size = args.batch_size
num_components = args.components
num_samples = args.samples
# NOTE(review): use_w, seed and force_recompute are read from the CLI but not
# passed to pca() or apply_pca_grid_fig7() below, so they currently have no effect.
use_w = args.use_w
seed = args.seed
force_recompute = args.force_recompute

# Load the generator and its run kwargs, then apply the requested truncation psi.
Gs, Gs_kwargs = load_network(output_class, stylegan_version)
Gs_kwargs.truncation_psi = args.truncation_psi

# TODO: Use this to get the name of the layers for the --layers functionality
# Gs.print_layers()

# Noise: collect the synthesis network's variables whose names start with 'noise'
noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]

# Overwrite every noise variable with values from a fixed RandomState(0), so the
# noise inputs are identical on every run. This seed is independent of --seed.
rnd = np.random.RandomState(0)
tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})  # [height, width]

# Disabled alternative: draw the noise with torch instead (seed 0, CUDA), presumably
# to reproduce the exact noise of a PyTorch implementation. Requires torch + a GPU.
# import torch
# torch.manual_seed(0)
# tflib.set_vars({var: torch.randn(*var.shape.as_list(), device="cuda").cpu().numpy() for var in noise_vars})  # [height, width]

# Run the decomposition; the returned path is handed to the figure code below
# (presumably where the computed components are stored -- see decomposition.pca).
dump_path = pca(Gs, stylegan_version, output_class, estimator, batch_size, num_components, num_samples)
# NOTE(review): w_avg (the 'dlatent_avg' variable) is fetched but not used below.
w_avg = Gs.get_var('dlatent_avg')

# Custom interpretable edits discovered by selective application of latent edits
# across the layers of several pretrained GAN models

# StyleGAN2
# python3 run_stylegan.py -b 20 -n 1_000_000 --truncation_psi=0.7 --class cars
# StyleGAN1
# python3 run_stylegan.py -s 1 -b 20 -n 1_000_000 --truncation_psi=1.0 --class wikiart
# Names of the edits rendered by apply_pca_grid_fig7 (defined in decomposition).
edits = ['Head rotation', 'Simple strokes']
apply_pca_grid_fig7(Gs, Gs_kwargs, edits, dump_path)

# Partial forward
# w_map_temp = Gs.components.mapping.clone(mapping_layers = 5)
