"""
Figure 7: A selection of interpretable edits discovered by selectively applying
latent edits across the layers of several pretrained GAN models.
"""

import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from figure_configs import configs_fig7 as configs
from decomposition import load_network, pca
from utils import centre_strip_stylegan

# Output directory for every Figure 7 grid. It is created at import time so that
# the later savefig calls never hit a missing folder.
SAVE_PATH = './results'
Path(SAVE_PATH).mkdir(parents=True, exist_ok=True)


def _select_fig7_configs(edits):
    """Return the ``configs`` entries whose title matches one of *edits*.

    Entries are grouped in the order of *edits*. Within one edit they keep the
    order in which they appear in ``configs``.
    """
    # Field 9 of each config tuple is the edit title (see the tuple unpacking
    # in apply_pca_fig7).
    return [c for edit in edits for c in configs if c[9] == edit]


def _load_latent_pca(dump_path):
    """Load ``(lat_comp, lat_mean, lat_stdev)`` from the ``.npz`` PCA dump.

    The ``act_*`` arrays stored in the same dump are not needed here and are
    not read.
    """
    with np.load(dump_path) as data:
        return data['lat_comp'], data['lat_mean'], data['lat_stdev']


def _save_strip_grid(strips, filename):
    """Stack *strips* vertically into one image and save it as SAVE_PATH/filename."""
    grid = np.vstack(strips)
    fig = plt.figure(figsize=(15, 15))
    plt.imshow(grid, interpolation='bilinear')
    plt.axis('off')
    plt.savefig(os.path.join(SAVE_PATH, filename))
    # This runs once per config. Without an explicit close, pyplot keeps every
    # figure alive for the rest of the process.
    plt.close(fig)


def apply_pca_fig7(Gs, Gs_kwargs, edits, dump_path):
    """Render and save one Figure 7 strip grid per config matching *edits*.

    For every entry of ``configs`` whose title is in *edits*:

    - Each listed seed is turned into a latent ``z``.
    - ``z`` is passed to ``centre_strip_stylegan`` together with latent PCA
      component ``idx``.
    - The returned frames are joined horizontally into one strip per seed.
    - The strips are stacked vertically and saved as
      ``SAVE_PATH/figure7_<model>_<class>_<title>.png``. Spaces in the title
      become dashes.

    Args:
        Gs: Generator network, as returned by ``load_network``.
        Gs_kwargs: Generator keyword arguments, forwarded to
            ``centre_strip_stylegan``.
        edits: Edit titles to render, matched against each config's title.
        dump_path: Path to the ``.npz`` PCA dump produced by ``pca``.
    """
    configurations = _select_fig7_configs(edits)
    print("Configs:", configurations)
    if not configurations:
        return

    # The dump is the same for every config, so load it only once.
    Z_comp, Z_global_mean, Z_stdev = _load_latent_pca(dump_path)
    num_layers = Gs.components.mapping.output_shape[1]

    # TODO: BigGAN support is not implemented yet. It still needs truncation
    # settings, w/z latent-space switching and the class-independent
    # 'generator.gen_z' special case for BigGAN-512.

    for (model_name, layer, mode, latent_space, l_start, l_end,
         classname, sigma, idx, title, seeds) in configurations:
        print(f'{model_name}, {layer}, {title}')

        # The layer range is exclusive, in contrast to the notation in the
        # paper. l_end == -1 means "up to the last layer".
        edit_start = l_start
        edit_end = num_layers if l_end == -1 else l_end

        print("Seeds:", seeds)
        # Use a fresh list per config so each saved figure contains only this
        # edit's strips.
        strips = []
        for s in seeds:
            rnd = np.random.RandomState(s)
            z = rnd.randn(1, *Gs.input_shape[1:])
            # NOTE(review): the literal 5 is presumably the number of frames
            # per strip; confirm against centre_strip_stylegan's signature.
            batch_frames = centre_strip_stylegan(
                Gs, Gs_kwargs, z, Z_comp, Z_global_mean, Z_stdev,
                idx, sigma, 5, edit_start, edit_end)
            strips.append(np.hstack(batch_frames))

        if not strips:
            # np.vstack([]) would raise, so skip configs that have no seeds.
            print(f'No seeds for {title}; skipping figure.')
            continue

        _save_strip_grid(strips, 'figure7_{}_{}_{}.png'.format(
            model_name, classname, title.replace(' ', '-')))


num_samples = 1_000_000
batch_size = 20

# One entry per model for Figure 7, rendered in this order:
# (StyleGAN version, output class, truncation psi, edit titles).
# The comment above each entry is the equivalent run_stylegan.py invocation.
FIG7_RUNS = [
    # python3 run_stylegan.py -s 1 -b 20 -n 1_000_000 --truncation_psi=1.0 --class wikiart
    (1, 'wikiart', 1.0, ['Head rotation', 'Simple strokes']),
    # python3 run_stylegan.py -b 20 -n 1_000_000 --truncation_psi=0.7 --class cars
    (2, 'cars', 0.7, ['Reflections']),
    # python3 run_stylegan.py -b 20 -n 1_000_000 --truncation_psi=0.7 --class horse
    (2, 'horse', 0.7, ['Add rider']),
    # python3 run_stylegan.py -b 20 -n 1_000_000 --truncation_psi=0.7 --class cats
    (2, 'cats', 0.7, ['Fluffiness']),
    # python3 run_stylegan.py -b 20 -n 1_000_000 --truncation_psi=0.7 --class ffhq
    (2, 'ffhq', 0.7, ['Makeup']),
]


def main():
    """Run PCA and render the Figure 7 edits for every entry in FIG7_RUNS."""
    for version, out_class, truncation_psi, edits in FIG7_RUNS:
        # StyleGAN1 is loaded with an explicit version argument. StyleGAN2
        # networks are loaded via load_network's default second argument, as
        # in the original per-model stanzas.
        if version == 1:
            Gs, Gs_kwargs = load_network(out_class, version)
        else:
            Gs, Gs_kwargs = load_network(out_class)
        Gs_kwargs.truncation_psi = truncation_psi
        dump_path = pca(Gs, version, out_class,
                        batch_size=batch_size, num_samples=num_samples)
        apply_pca_fig7(Gs, Gs_kwargs, edits, dump_path)


if __name__ == '__main__':
    main()
