# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

# Patch for broken CTRL+C handler: see the FOR_DISABLE_CONSOLE_CTRL_HANDLER
# environment variable set after the imports.
# https://github.com/ContinuumIO/anaconda-issues/issues/905

import copy
import datetime
import os
import sys
from pathlib import Path

sys.path.append('./models/stylegan2')

import dnnlib
import dnnlib.tflib as tflib
import matplotlib.pyplot as plt
import pretrained_networks
from PIL import Image
from scipy.stats import special_ortho_group
from tqdm import trange

from estimators import get_estimator
from figure_configs import *
from utils import *

# Patch for broken CTRL+C handler (see the note at the top of the file).
# NOTE(review): this variable is presumably read when the runtime that installs
# the console handler is loaded, which may already have happened during the
# imports above (e.g. scipy). If CTRL+C still misbehaves, move this line above them.
os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1'

# Default seed for the global NumPy RNG when sampling latents
# (compute_pca / apply_pca_grid_fig5 fall back to it when no seed is given).
SEED_SAMPLING = 1
# Batch size used by compute_pca() when batch_size is falsy.
DEFAULT_BATCH_SIZE = 20

SEED_RANDOM_DIRS = 2  # seed of get_random_dirs()
SEED_LINREG = 3  # not referenced in the code shown here
SEED_VISUALIZATION = 5  # not referenced in the code shown here

# Module-level batch size; compute_pca() reassigns it through `global B`.
B = 20
n_clusters = 500  # not referenced in the code shown here


def get_random_dirs(components, dimensions):
    """Draw ``components`` reproducible random unit directions.

    Rows are sampled from a standard normal using the fixed seed
    ``SEED_RANDOM_DIRS`` and then scaled to unit L2 norm.

    Returns:
        np.ndarray: float32 array of shape ``(components, dimensions)``.
    """
    rng = np.random.RandomState(seed=SEED_RANDOM_DIRS)
    raw = rng.normal(size=(components, dimensions))
    lengths = np.sqrt((raw ** 2).sum(axis=1, keepdims=True))
    unit = raw / lengths
    return unit.astype(np.float32)


def load_network(out_class, model=2):
    """Load a pretrained StyleGAN generator and default synthesis kwargs.

    The network pickle is looked up as ``out_classes[model][out_class]``.
    Every synthesis variable whose name starts with ``'noise'`` is fixed to
    values drawn from ``np.random.RandomState(0)``, so generated images are
    deterministic.

    Returns:
        tuple: ``(Gs, Gs_kwargs)``. ``Gs_kwargs`` converts outputs to uint8
        NHWC images and disables noise randomization.
    """
    pkl = out_classes[model][out_class]
    _gen, _disc, Gs = pretrained_networks.load_networks(pkl)

    # Freeze the per-layer noise inputs with a fixed RNG so outputs are repeatable
    noise_rng = np.random.RandomState(0)
    noise_inputs = [
        tf_var
        for var_name, tf_var in Gs.components.synthesis.vars.items()
        if var_name.startswith('noise')
    ]
    frozen_noise = {tf_var: noise_rng.randn(*tf_var.shape.as_list()) for tf_var in noise_inputs}
    tflib.set_vars(frozen_noise)

    Gs_kwargs = dnnlib.EasyDict(
        output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        randomize_noise=False,
    )
    return Gs, Gs_kwargs


def pca(Gs, stylegan_version, out_class, estimator='ipca', batch_size=20, num_components=80, num_samples=1_000_000, use_w=True, force_recompute=False, seed_compute=None):
    """Return the cache path of a PCA decomposition, computing it if needed.

    The file name encodes the StyleGAN version, the output class (spaces
    replaced by underscores), the lower-cased estimator, the number of
    components and samples, ``_w`` for W-space and ``_seed<k>`` when a truthy
    ``seed_compute`` is given. When the file is missing, or when
    ``force_recompute`` is set, ``compute_pca`` is run to (re)create it.

    Returns:
        Path: ``./cache/components/<name>.npz``.
    """
    stem_parts = [
        f'stylegan{stylegan_version}-' + out_class.replace(' ', '_'),
        estimator.lower(),
        f'c{num_components}',
        f'n{num_samples}',
    ]
    suffix = ''
    if use_w:
        suffix += '_w'
    if seed_compute:
        suffix += f'_seed{seed_compute}'
    dump_path = Path('./cache/components') / ('_'.join(stem_parts) + suffix + '.npz')

    if force_recompute or not dump_path.is_file():
        dump_path.parent.mkdir(parents=True, exist_ok=True)
        compute_pca(Gs, estimator, batch_size, num_components, num_samples, use_w, seed_compute, dump_path)

    return dump_path


def compute_pca(Gs, estimator, batch_size, num_components, num_samples, use_w, seed, dump_path):
    """Fit a PCA basis to sampled W latents of ``Gs`` and save it to ``dump_path``.

    Latents ``z ~ N(0, I)`` are mapped to W with the mapping network. The
    decomposition is fitted either incrementally (estimators with
    ``batch_support``) or on all collected samples at once. The result is
    written with ``np.savez_compressed`` under the keys ``act_comp``,
    ``act_mean``, ``act_stdev``, ``lat_comp``, ``lat_mean``, ``lat_stdev``,
    ``var_ratio`` and ``random_stdevs``.

    Args:
        Gs: StyleGAN generator network (dnnlib/tflib).
        estimator: estimator name passed to ``get_estimator``.
        batch_size: latents per mapping-network call; falsy means
            ``DEFAULT_BATCH_SIZE``.
        num_components: requested number of components, capped at the feature
            dimensionality.
        num_samples: number of samples; rounded down to a multiple of the batch
            size.
        use_w: decompose W-space latents. Only ``True`` is implemented.
        seed: sampling seed; falsy means ``SEED_SAMPLING``.
        dump_path: ``Path`` of the ``.npz`` file to write.

    Raises:
        NotImplementedError: if ``use_w`` is False. Intermediate-layer features
            and the regression back to latent space are still TODO.

    Side effects:
        Sets the module-level ``B`` and re-seeds the global NumPy RNG. On
        Ctrl+C during incremental fitting, saves the partial result under a
        name carrying the number of samples seen, then calls ``sys.exit(1)``.
    """
    global B  # the module-level batch size is updated to the value used here

    if not use_w:
        # Without W there is neither a feature extractor nor a regression back
        # to latent space yet; fail before the expensive sampling phase.
        raise NotImplementedError('compute_pca currently supports only use_w=True (W-space PCA)')

    def timestamp():
        return datetime.datetime.now().strftime("%d.%m %H:%M")

    print(f'[{timestamp()}] Computing', dump_path.name)

    # Ensure reproducibility
    np.random.seed(0)
    print('Using W latent space')

    latent_shape = tuple(Gs.input_shape[1:])
    input_shape = (1, *latent_shape)
    input_dims = np.prod(input_shape)

    # Shape of one mapped (W) sample, used as the PCA feature shape
    sample_shape = Gs.components.mapping.run(np.random.randn(*input_shape), None, dlatent_broadcast=None).shape
    sample_dims = np.prod(sample_shape)
    print("Feature shape: ", sample_shape)
    print("Feature dims: ", sample_dims)

    components = min(num_components, sample_dims)
    transformer = get_estimator(estimator, components, 1.0)

    B = batch_size or DEFAULT_BATCH_SIZE
    # Divisible by B (ignored in output name)
    N = num_samples // B * B

    # Warn when a non-batching estimator would exceed the RAM + pagefile budget
    target_bytes = 20 * 1_000_000_000  # GB
    feat_size_bytes = sample_dims * np.dtype('float64').itemsize
    N_limit_RAM = np.floor_divide(target_bytes, feat_size_bytes)
    if not transformer.batch_support and N > N_limit_RAM:
        print('WARNING: estimator does not support batching, ' \
              'given config will use {:.1f} GB memory.'.format(feat_size_bytes / 1_000_000_000 * N))

    print('B={}, N={}, dims={}, N/dims={:.1f}'.format(B, N, sample_dims, N/sample_dims), flush=True)

    # Fitting chunk size; kept large and (for B <= 2000) independent of B for reproducibility
    NB = max(B, 2_000, 3 * components)  # ipca: as large as possible!

    samples = None
    if not transformer.batch_support:
        samples = np.zeros((N + NB, sample_dims), dtype=np.float32)

    np.random.seed(seed or SEED_SAMPLING)

    # Use exactly the same latents regardless of batch size.
    # Stored in main memory since N might be huge (1M+); mapped in batches of B.
    n_lat = ((N + NB - 1) // B + 1) * B
    latents = np.zeros((n_lat, *latent_shape), dtype=np.float32)
    for i in trange(n_lat // B, desc='Sampling latents'):
        # Per-batch seed drawn from the (reproducible) global RNG
        rng = np.random.RandomState(np.random.randint(np.iinfo(np.int32).max))
        z = rng.standard_normal(B * input_dims).reshape(B, *latent_shape)
        latents[i*B:(i+1)*B] = Gs.components.mapping.run(z, None, dlatent_broadcast=None)

    canceled = False
    gi = 0  # start of the current chunk; defined even if interrupted immediately
    X = np.ones((NB, sample_dims), dtype=np.float32)
    try:
        action = 'Fitting' if transformer.batch_support else 'Collecting'
        for gi in trange(0, N, NB, desc=f'{action} batches (NB={NB})', ascii=True):
            for mb in range(0, NB, B):
                # W samples are latents themselves: decompose them directly
                batch = latents[gi+mb:gi+mb+B].reshape((B, -1))
                space_left = min(B, NB - mb)
                X[mb:mb+space_left] = batch[:space_left]

            if transformer.batch_support:
                if not transformer.fit_partial(X.reshape(-1, sample_dims)):
                    break
            else:
                samples[gi:gi+NB, :] = X.copy()
    except KeyboardInterrupt:
        if not transformer.batch_support:
            sys.exit(1)  # no progress yet

        # Save the partial fit under a name that records how many samples were
        # used, so it is never mistaken for the complete cache entry.
        partial_name = dump_path.name.replace(f'_n{num_samples}', f'_n{gi}')
        if partial_name == dump_path.name:
            partial_name = f'{dump_path.stem}_partial_n{gi}{dump_path.suffix}'
        dump_path = dump_path.parent / partial_name
        print(f'Saving current state to "{dump_path.name}" before exiting')
        canceled = True

    if not transformer.batch_support:
        # Use the first N collected samples; trailing rows of the (N + NB)-row
        # buffer may never have been filled and would bias the fit with zeros.
        X = samples[:N]
        X_global_mean = X.mean(axis=0, keepdims=True, dtype=np.float32)  # TODO: activations surely multi-modal...!
        X -= X_global_mean

        print(f'[{timestamp()}] Fitting whole batch')
        t_start_fit = datetime.datetime.now()

        transformer.fit(X)

        print(f'[{timestamp()}] Done in {datetime.datetime.now() - t_start_fit}')
        assert np.all(np.abs(transformer.transformer.mean_) < 1e-3), 'Mean of normalized data should be zero'
    else:
        X_global_mean = transformer.transformer.mean_.reshape((1, sample_dims))
        X = X.reshape(-1, sample_dims)
        X -= X_global_mean

    X_comp, X_stdev, X_var_ratio = transformer.get_components()

    assert X_comp.shape[1] == sample_dims \
        and X_comp.shape[0] == components \
        and X_global_mean.shape[1] == sample_dims \
        and X_stdev.shape[0] == components, 'Invalid shape'

    # 'Activations' are W latents, so the latent basis is the activation basis.
    # Normalize rows into a new array (the estimator's own array is left untouched);
    # both act_comp and lat_comp are stored normalized.
    X_comp = X_comp / np.linalg.norm(X_comp, axis=-1, keepdims=True)
    Z_comp = X_comp
    Z_global_mean = X_global_mean

    # Random projections: expected to explain much less of the variance
    random_dirs = get_random_dirs(components, sample_dims)
    n_rand_samples = min(5000, X.shape[0])
    X_view = X[:n_rand_samples, :].T
    assert np.shares_memory(X_view, X), "Error: slice produced copy"
    X_stdev_random = np.dot(random_dirs, X_view).std(axis=1)

    # Inflate back to proper shapes (for easier broadcasting)
    X_comp = X_comp.reshape(-1, *sample_shape)
    X_global_mean = X_global_mean.reshape(sample_shape)
    Z_comp = Z_comp.reshape(-1, *input_shape)
    Z_global_mean = Z_global_mean.reshape(input_shape)

    # Spread of fresh W samples along each latent component (W is non-Gaussian)
    rng = np.random.RandomState(np.random.randint(np.iinfo(np.int32).max))  # reproducible global rand state
    n_stdev_samples = 5000
    z = rng.standard_normal(n_stdev_samples * input_dims).reshape(n_stdev_samples, *latent_shape)
    w_samples = Gs.components.mapping.run(z, None, dlatent_broadcast=None).reshape(n_stdev_samples, input_dims)
    lat_stdev = np.dot(Z_comp.reshape(-1, input_dims), w_samples.T).std(axis=1)

    np.savez_compressed(dump_path, **{
        'act_comp': X_comp.astype(np.float32),
        'act_mean': X_global_mean.astype(np.float32),
        'act_stdev': X_stdev.astype(np.float32),
        'lat_comp': Z_comp.astype(np.float32),
        'lat_mean': Z_global_mean.astype(np.float32),
        'lat_stdev': lat_stdev.astype(np.float32),
        'var_ratio': X_var_ratio.astype(np.float32),
        'random_stdevs': X_stdev_random.astype(np.float32),
    })

    if canceled:
        sys.exit(1)


def apply_pca_fig1(Gs, truncation_psi, edits, seed, dump_path):
    """Apply a sequence of named PCA edits (figure 1) to one sampled latent.

    Each edit name is looked up in ``configs_fig1`` as
    ``(idx, start, end, strength, invert)``. The layer-0 latent is projected on
    component ``idx``, with the coordinate measured in units of that
    component's stdev. Layers ``start .. end-1`` are then moved so the
    coordinate becomes ``strength``, or ``-strength`` on the opposite side of
    the mean when ``invert`` is set. Edits are cumulative.

    ``truncation_psi`` is not used here (kept for interface compatibility).

    Returns:
        list of np.ndarray: the unedited W followed by W after each edit,
        each shaped ``(1, num_layers, latent_dim)``.
    """
    with np.load(dump_path) as data:
        lat_comp = data['lat_comp']
        lat_mean = data['lat_mean']
        lat_std = data['lat_stdev']

    latent_dim = Gs.input_shape[1]
    num_layers = Gs.components.mapping.output_shape[1]

    rnd = np.random.RandomState(seed)
    z = rnd.standard_normal(latent_dim).reshape(1, latent_dim)
    # Lay out as (num_layers, 1, latent_dim) so that w[l] is the latent of layer l
    w = Gs.components.mapping.run(z, None).reshape((num_layers, 1, latent_dim))

    snapshots = [w.copy()]
    for edit in edits:
        idx, start, end, strength, invert = configs_fig1[edit]

        # Coordinate of the current layer-0 latent along the PC, in stdevs
        w_centered = w[0] - lat_mean
        w_coord = np.sum(w_centered.reshape(-1) * lat_comp[idx].reshape(-1)) / lat_std[idx]

        if invert:
            # Flip the property (e.g. rotation): go to the opposite side of the mean.
            # A coordinate of exactly 0 counts as positive instead of producing NaN.
            sign = 1.0 if w_coord >= 0 else -1.0
            target = -sign * strength
        else:
            # Otherwise reinforce / set the property
            target = strength

        delta = target - w_coord  # offset in stdev units

        for l in range(start, end):
            w[l] = w[l] + lat_comp[idx] * lat_std[idx] * delta
        snapshots.append(w.copy())

    return [s.reshape((1, num_layers, latent_dim)) for s in snapshots]


def apply_pca_grid_fig3_1(Gs, Gs_kwargs, truncation_psi, n_pcs, sigma, num_frames, layer_start, layer_end, seed, dump_path):
    """Figure 3 (part 1): edit strips for the first ``n_pcs`` principal components.

    A single latent is drawn from ``seed``. For each component index
    ``0 .. n_pcs-1`` a strip of ``num_frames`` frames is built with
    ``centre_strip_stylegan``, using edit magnitude ``sigma`` and layers
    ``layer_start`` / ``layer_end``. The strips are stacked vertically and
    saved to ``figure3_1.png``. The matplotlib figure is left open.

    ``truncation_psi`` is not used here (kept for interface compatibility).

    Raises:
        ValueError: if ``n_pcs`` is smaller than 1 (nothing to plot).
    """
    if n_pcs < 1:
        raise ValueError(f'n_pcs must be at least 1, got {n_pcs}')

    with np.load(dump_path) as data:
        lat_comp = data['lat_comp']
        lat_mean = data['lat_mean']
        lat_std = data['lat_stdev']

    rnd = np.random.RandomState(seed)
    z = rnd.randn(1, *Gs.input_shape[1:])

    strips = [
        np.hstack(centre_strip_stylegan(Gs, Gs_kwargs, z, lat_comp, lat_mean, lat_std,
                                        pc, sigma, num_frames, layer_start, layer_end))
        for pc in range(n_pcs)
    ]
    grid = np.vstack(strips)

    plt.figure(figsize=(20, 40))
    plt.imshow(grid, interpolation='bilinear')
    plt.axis('off')
    plt.savefig('figure3_1.png')


def apply_pca_grid_fig3_2(Gs, Gs_kwargs, truncation_psi, num_frames, seed, dump_path):
    """Figure 3 (part 2): edit strips for hand-tuned components and layer ranges.

    Each entry of ``hand_tuned`` is ``(component index, (layer_start,
    layer_end), sigma)``. For a single latent drawn from ``seed``, one strip of
    ``num_frames`` frames per entry is built with ``centre_strip_stylegan``.
    The strips are stacked vertically and saved to ``figure3_2.png``. The
    matplotlib figure is left open.

    ``truncation_psi`` is not used here (kept for interface compatibility).
    """
    with np.load(dump_path) as data:
        lat_comp = data['lat_comp']
        lat_mean = data['lat_mean']
        lat_std = data['lat_stdev']

    rnd = np.random.RandomState(seed)
    z = rnd.randn(1, *Gs.input_shape[1:])

    hand_tuned = [
        ( 0, (1,  7), 2.0),  # gender, keep age
        ( 1, (0,  3), 2.0),  # rotate, keep gender
        ( 2, (3,  8), 2.0),  # gender, keep geometry
        ( 3, (2,  8), 2.0),  # age, keep lighting, no hat
        ( 4, (5, 18), 2.0),  # background, keep geometry
        ( 5, (0,  4), 2.0),  # hat, keep lighting and age
        ( 6, (7, 18), 2.0),  # just lighting
        ( 7, (5,  9), 2.0),  # just lighting
        ( 8, (1,  7), 2.0),  # age, keep lighting
        ( 9, (0,  5), 2.0),  # keep lighting
        (10, (7,  9), 2.0),  # hair color, keep geom
        (11, (0,  5), 2.0),  # hair length, keep color
        (12, (8,  9), 2.0),  # light dir lr
        # (12, (4,  10), 2.0),  # light position LR
        (13, (0,  6), 2.0),  # about the same
    ]

    strips = [
        np.hstack(centre_strip_stylegan(Gs, Gs_kwargs, z, lat_comp, lat_mean, lat_std,
                                        pc, sigma, num_frames, start, end))
        for pc, (start, end), sigma in hand_tuned
    ]
    grid = np.vstack(strips)

    plt.figure(figsize=(20, 40))
    plt.imshow(grid, interpolation='bilinear')
    plt.axis('off')
    plt.savefig('figure3_2.png')


def apply_pca_grid_fig4(Gs, Gs_kwargs, use_w, class_name, seed, dump_path, N=5, use_random_basis=True):
    """Figure 4: resampling PCA coordinates vs. a random-basis baseline.

    A base latent is drawn from ``seed + B`` (B = 6). Each of B other latents
    (seeds ``seed .. seed+B-1``) is projected on a basis, and a chosen subset
    of the base latent's coordinates is replaced by its coordinates. Four
    panels are produced:

    1. keep the first ``N`` PCA coordinates;
    2. resample the first ``N`` PCA coordinates;
    3. and 4. the same two experiments in a random orthonormal basis, or in a
       shuffled PCA basis when ``use_random_basis`` is False.

    Each panel is saved with ``plt.savefig`` (as ``figure4_<title>.png``).
    The individual images go to
    ``out/figures/random_baseline/<model>_<class>/``. Latents are
    broadcast to all layers and truncated with ``Gs_kwargs.truncation_psi``.

    NOTE(review): ``range(N, K)`` indexing and reshaping the random basis to
    ``lat_comp.shape`` assume the dump holds a full basis of K (= latent dim)
    components. Confirm the ``num_components`` used for ``dump_path``.
    """
    model_name = 'StyleGAN2'
    # TODO: See what is to be done for BigGAN (GANSpace used layer 'generator.gen_z',
    # 1M samples and output class 'husky' there; StyleGAN used 'g_mapping').

    outdir = Path('out/figures/random_baseline')
    os.makedirs(outdir, exist_ok=True)

    w_avg = Gs.get_var('dlatent_avg')
    num_layers = Gs.components.mapping.output_shape[1]
    input_shape = Gs.input_shape[1]
    K = np.prod(input_shape)

    with np.load(dump_path) as data:
        lat_comp = data['lat_comp']
        lat_mean = data['lat_mean']

    B = 6  # number of resampled latents per panel
    if seed is None:
        seed = np.random.randint(np.iinfo(np.int32).max - B)

    # TODO: See what is to be done for BigGAN (set the output class)

    print(f'Seeds: {seed} - {seed+B}')

    def sample_latent(s):
        """Draw one latent from seed ``s``, mapped to W when ``use_w``."""
        rnd = np.random.RandomState(s)
        lat = rnd.randn(1, *Gs.input_shape[1:])
        if use_w:
            lat = Gs.components.mapping.run(lat, None, dlatent_broadcast=None)
        return lat

    def truncate(w):
        """Apply the truncation trick towards ``w_avg``."""
        return w_avg + (w - w_avg) * Gs_kwargs.truncation_psi

    def crop(imgs):
        """Crop image rows for the car/cat classes.

        Assumes NHWC output, as configured by load_network (nchw_to_nhwc=True).
        """
        if class_name == 'car':
            return imgs[:, 64:-64]
        if class_name == 'cat':
            return imgs[:, 18:-8]
        return imgs

    # Resampling test: base latent and its (uncropped) image
    w_base = sample_latent(seed + B)
    w_base_img = truncate(w_base)
    imgs = Gs.components.synthesis.run(np.array([w_base_img]*num_layers).reshape((1, num_layers, input_shape)), **Gs_kwargs)
    plt.imshow(imgs[0])
    plt.axis('off')
    plt.title('Original')
    plt.savefig('figure4_original.png')

    def project_ortho(X, comp):
        """Project ``X`` on the orthonormal rows of ``comp``; coords shaped (n_comp, 1, ..., 1)."""
        n_comp = comp.shape[0]
        coords = (comp.reshape(n_comp, -1) * X.reshape(-1)).sum(axis=1)
        return coords.reshape([n_comp] + [1] * X.ndim)

    def get_batch(indices, basis):
        """Return (B, K) latents: ``w_base`` with coordinates ``indices`` taken from seeds seed..seed+B-1."""
        w_batch = np.zeros((B, K))
        coord_base = project_ortho(w_base - lat_mean, basis)
        for i in range(B):
            coords = coord_base.copy()
            coords_resampled = project_ortho(sample_latent(seed + i) - lat_mean, basis)
            coords[indices] = coords_resampled[indices]
            w_batch[i, :] = (lat_mean + np.sum(coords * basis, axis=0)).reshape(-1)
        return w_batch

    def render(w):
        """Synthesize cropped images for a (B, K) batch broadcast to every layer."""
        w = np.repeat(np.expand_dims(w, axis=1), num_layers, axis=1)
        return crop(Gs.components.synthesis.run(truncate(w), **Gs_kwargs))

    def show_grid(imgs, title):
        """Draw the images side by side in the current subplot and save the figure."""
        plt.axis('off')
        plt.tight_layout()
        plt.title(title)
        plt.imshow(np.hstack(imgs), interpolation='bilinear')
        plt.savefig('figure4_{}.png'.format(title.replace('->', '-')))

    def save_imgs(imgs, prefix):
        """Write every image of the batch to ``outdir/<model>_<class>/<prefix>_<i>.png``."""
        for i, img in enumerate(imgs):
            if img.dtype != np.uint8:
                # Float output in [0, 1]; uint8 output is saved as-is
                img = np.uint8(np.clip(img, 0, 1) * 255)
            outpath = outdir / f'{model_name}_{class_name}' / f'{prefix}_{i}.png'
            os.makedirs(outpath.parent, exist_ok=True)
            Image.fromarray(img).save(outpath)

    # V = [n_comp, n_dim]
    def assert_orthonormal(V):
        M = np.dot(V, V.T)  # [n_comp, n_comp]
        det = np.linalg.det(M)
        assert np.allclose(M, np.identity(M.shape[0]), atol=1e-5), f'Basis is not orthonormal (det={det})'

    def run_panel(position, indices, basis, title, prefix):
        """Resample ``indices`` in ``basis``, draw subplot ``position`` and save the images."""
        batch_imgs = render(get_batch(indices, basis))  # synthesize once for plot and files
        plt.subplot(2, 2, position)
        show_grid(batch_imgs, title)
        save_imgs(batch_imgs, prefix)

    plt.figure(figsize=((12, 6.5) if class_name in ['car', 'cat'] else (12, 8)))

    keep_first = np.arange(N, K)      # first N kept, N -> K rerandomized
    resample_first = np.arange(0, N)  # 0 -> N rerandomized

    run_panel(1, keep_first, lat_comp, f'Keep {N} first pca -> Consistent pose', f'keep_{N}_first_{seed}')
    run_panel(2, resample_first, lat_comp, f'Randomize {N} first pca -> Consistent style', f'randomize_{N}_first_{seed}')

    if use_random_basis:
        # Isotropic random orthonormal basis, seeded so the figure is reproducible
        V = special_ortho_group.rvs(K, random_state=seed + B + 1)
        assert_orthonormal(V)

        rand_basis = np.reshape(V, lat_comp.shape)
        assert rand_basis.shape == lat_comp.shape, f'Shape mismatch: {rand_basis.shape} != {lat_comp.shape}'
    else:
        # Just use shuffled PCA basis
        rng = np.random.RandomState(seed=seed)
        perm = rng.permutation(range(K))
        rand_basis = lat_comp[perm, :]

    basis_type_str = 'random' if use_random_basis else 'pca_shfl'

    run_panel(3, keep_first, rand_basis, f'Keep {N} first {basis_type_str} -> Little consistency',
              f'keep_{N}_first_{basis_type_str}_{seed}')
    run_panel(4, resample_first, rand_basis, f'Randomize {N} first {basis_type_str} -> Little variation',
              f'randomize_{N}_first_{basis_type_str}_{seed}')


def apply_pca_grid_fig5(Gs, Gs_kwargs, use_w, lat_mean, prefix, imgclass, seeds, d_ours, l_start, l_end, scale_ours, d_sup, scale_sup, center=True):
    """Figure 5: compare our edit direction against a supervised one.

    For every seed, a base latent is sampled from the global NumPy RNG seeded
    with ``seed or SEED_SAMPLING`` and mapped to W when ``use_w``. With
    ``center``, it is first moved onto ``lat_mean`` along the normalized
    supervised direction ``d_sup``. Two 5-frame strips (offsets -1 .. 1) are
    then rendered:

    - ``'ours'``: ``d_ours`` times ``scale_ours`` on layers ``l_start .. l_end-1``;
    - ``'supervised'``: ``d_sup`` times ``scale_sup`` on every layer.

    A direction with more than one row is used per layer without
    normalization; a single-row direction is normalized and shared by all
    layers. Each strip is saved to ``'{prefix}-{name}_scale={scale}.png'``.

    NOTE(review): the file name does not contain the seed, so with several
    seeds each strip overwrites the previous seed's file.
    """
    out_root = Path('out/figures/steerability_comp')
    os.makedirs(out_root, exist_ok=True)
    os.makedirs(out_root / imgclass, exist_ok=True)

    def normalize(t):
        return t / np.sqrt(np.sum(t.reshape(-1)**2))

    w_avg = Gs.get_var('dlatent_avg')

    input_shape = Gs.input_shape[1]
    num_layers = Gs.components.mapping.output_shape[1]

    # TODO: See what is to be done for BigGAN (layer ranges)
    experiments = [
        ('ours', d_ours, scale_ours, (l_start, l_end)),
        ('supervised', d_sup, scale_sup, (0, num_layers)),
    ]

    for seed in seeds:
        print("Seed:", seed)

        # Both experiments start from the same base latent, so sample it once
        np.random.seed(seed or SEED_SAMPLING)
        lat_base = np.random.randn(1, *Gs.input_shape[1:])
        if use_w:
            lat_base = Gs.components.mapping.run(lat_base, None, dlatent_broadcast=None)

        # Shift latent to lie on mean along given direction
        if center:
            y = normalize(d_sup)  # assume ground truth
            dotp = np.sum((lat_base - lat_mean) * y, axis=-1, keepdims=True)
            lat_base = lat_base - dotp * y

        for name, delta, scale, (first_layer, end_layer) in experiments:
            # Convert single delta to per-layer delta (to support Steerability StyleGAN)
            if delta.shape[0] > 1:
                d_per_layer = list(delta)  # might have per-layer scales, don't normalize
            else:
                # TODO: See what is to be done for BigGAN
                d_per_layer = [normalize(delta)] * num_layers

            frames = []
            for a in np.linspace(-1.0, 1.0, 5):
                # TODO: See what is to be done for BigGAN
                w = [lat_base] * num_layers
                for l in range(first_layer, end_layer):
                    w[l] = w[l] + a * d_per_layer[l] * scale

                w = w_avg + (np.array(w) - w_avg) * Gs_kwargs.truncation_psi
                imgs = Gs.components.synthesis.run(w.reshape((1, num_layers, input_shape)), **Gs_kwargs)
                frames.append(imgs[0])

            strip = np.hstack(pad_frames(frames, 64))
            plt.figure(figsize=(12, 12))
            plt.imshow(strip)
            plt.axis('off')
            plt.tight_layout()
            plt.title(f'{prefix} - {name}, scale={scale}')
            plt.savefig(f'{prefix}-{name}_scale={scale}.png')


def apply_pca_grid_fig7(Gs, Gs_kwargs, edits, dump_path):
    """Figure 7: edit strips for the named directions in ``configs_fig7``.

    Configs are selected whose title (field 9) equals one of ``edits``, in the
    order of ``edits``. For every seed of every selected config, a 5-frame
    strip is built with ``centre_strip_stylegan``, using the config's
    component index, sigma and layer range (``l_end == -1`` means all
    layers). All strips are stacked vertically and saved to
    ``figure7.png``. The directory ``out/directions`` is created.

    Raises:
        ValueError: if no config matches ``edits`` or no strip is produced.
    """
    out_root = Path('out/directions')
    os.makedirs(out_root, exist_ok=True)

    num_layers = Gs.components.mapping.output_shape[1]

    configs = [c for edit in edits for c in configs_fig7 if c[9] == edit]
    print("Configs:", configs)
    if not configs:
        raise ValueError(f'No figure 7 config matches edits={edits!r}')

    # Every config uses the same decomposition, so load it once
    with np.load(dump_path) as data:
        Z_comp = data['lat_comp']
        Z_global_mean = data['lat_mean']
        Z_stdev = data['lat_stdev']

    strips = []
    for (model_name, layer, _mode, _latent_space, l_start, l_end, _classname, sigma, idx, title, seeds) in configs:
        print(f'{model_name}, {layer}, {title}')

        # TODO: See what is to be done for BigGAN (truncation, latent space, class-independent gen_z)

        # Range is exclusive, in contrast to notation in paper
        edit_start = l_start
        edit_end = num_layers if l_end == -1 else l_end

        print("Seeds:", seeds)
        for s in seeds:
            z = np.random.RandomState(s).randn(1, *Gs.input_shape[1:])
            batch_frames = centre_strip_stylegan(Gs, Gs_kwargs, z, Z_comp, Z_global_mean, Z_stdev, idx, sigma, 5, edit_start, edit_end)
            strips.append(np.hstack(batch_frames))

    if not strips:
        raise ValueError(f'No seeds to render for edits={edits!r}')

    grid = np.vstack(strips)
    plt.figure(figsize=(15, 15))
    plt.imshow(grid, interpolation='bilinear')
    plt.axis('off')
    plt.savefig('figure7.png')
