# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Project given image to the latent space of pretrained network pickle."""

import copy
import os
from time import perf_counter

import click
import imageio
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F

import dnnlib
import legacy
import pandas as pd
import torchvision
def _downsample_for_vgg(images, max_size=256):
    """Area-downsample `images` to max_size x max_size if taller than that.

    Images whose height (dim 2) is already <= max_size are returned unchanged,
    so the result is always defined regardless of the generator resolution.
    """
    if images.shape[2] > max_size:
        return F.interpolate(images, size=(max_size, max_size), mode='area')
    return images


def _per_layer_ws(G, z, attr):
    """Map every row of `z` with the same attribute row and keep one W per row.

    `attr` (a single row) is repeated to the batch size of `z`; from each
    mapped sample only index 0 along dim 1 is kept, and the kept vectors are
    stacked and given a leading batch axis of size 1.
    """
    ws = G.mapping(z, attr.repeat(z.shape[0], 1))
    return ws[:, 0].unsqueeze(0)


def project(
    G,
    target_attr,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps                  = 1000,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.1,
    initial_noise_factor       = 0.05,
    lr_rampdown_length         = 0.25,
    lr_rampup_length           = 0.05,
    noise_ramp_length          = 0.75,
    regularize_noise_weight    = 1e5,
    verbose                    = False,
    lambda_l1                  = None,
    lambda_lpips               = None,
    lambda_kl                 = None,
    mse                        = None,
    device: torch.device
):
    """Optimize a set of Z latents (and the noise buffers) to reproduce `target`.

    One Z vector is optimized per synthesis layer; each is mapped with the
    conditioning row `target_attr` and contributes a single W vector.

    Args:
        G: Generator exposing `mapping`, `synthesis`, `img_channels` and
            `img_resolution`. NOTE(review): `G.mapping.num_ws` and
            `G.mapping.w_dim` are assumed to exist (StyleGAN2-ADA layout).
        target_attr: Conditioning row of shape [1, num_attributes]; it must have
            more than 87 columns because columns 21, 87 and 68 are edited below.
        target: Target image, [C,H,W], dynamic range [0,255].
        num_steps: Number of optimizer steps.
        w_avg_samples: Unused; kept for interface compatibility.
        initial_learning_rate: Peak Adam learning rate.
        initial_noise_factor: Initial scale of the Gaussian noise added to Z.
        lr_rampdown_length: Fraction of the run used for the cosine ramp-down.
        lr_rampup_length: Fraction of the run used for the linear ramp-up.
        noise_ramp_length: Fraction of the run after which the Z noise is zero.
        regularize_noise_weight: Weight of the noise-buffer regularizer.
        verbose: Print a progress line for every step.
        lambda_l1: Weight of the pixel loss (None is treated as 0).
        lambda_lpips: Weight of the VGG feature distance (None is treated as 1).
        lambda_kl: Weight of the penalty on Z (None is treated as 0).
        mse: If truthy the pixel loss is mean squared error, otherwise mean
            absolute error.
        device: Device on which the optimization runs.

    Returns:
        Tuple `(w_out, latest_ws, z_opt, latest_ws_must, latest_ws_goat)`:
        `w_out` holds the W stack used at every step ([num_steps, num_ws, w_dim]);
        `z_opt` is the optimized Z tensor; the three `latest_ws*` tensors are the
        final Z mapped with attribute column 21, 87 and 68 respectively set to 1
        (each [1, num_ws, w_dim]).
    """
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    # The weights default to None in the signature; multiplying a tensor by
    # None would raise, so fall back to the values the CLI uses by default.
    lambda_l1 = 0.0 if lambda_l1 is None else lambda_l1
    lambda_lpips = 1.0 if lambda_lpips is None else lambda_lpips
    lambda_kl = 0.0 if lambda_kl is None else lambda_kl

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    target_features = vgg16(_downsample_for_vgg(target_images), resize_images=False, return_lpips=True)

    # One Z per synthesis layer. The sample is drawn on the CPU and then moved,
    # so the random stream matches a plain torch.randn() call.
    num_ws = G.mapping.num_ws
    z_opt = torch.randn([num_ws, G.mapping.z_dim], dtype=torch.float32).to(device).requires_grad_(True)
    w_out = torch.zeros([num_steps, num_ws, G.mapping.w_dim], dtype=torch.float32, device=device)
    optimizer = torch.optim.Adam([z_opt]+list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):
        # Learning rate schedule: linear ramp-up, cosine ramp-down.
        t = step / num_steps
        z_noise_scale = 1.0 * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from the noise-perturbed Z.
        z_in = z_opt + torch.randn_like(z_opt) * z_noise_scale
        ws = _per_layer_ws(G, z_in, target_attr)
        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255/2)

        # Features for synth images.
        synth_features = vgg16(_downsample_for_vgg(synth_images), resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Pixel loss on images rescaled to [0,1]; named l1_loss in both modes.
        pixel_diff = synth_images/255 - target_images/255
        if mse:
            l1_loss = torch.mean(pixel_diff**2)
        else:
            l1_loss = torch.mean(torch.abs(pixel_diff))

        # Noise regularization: penalize correlation between each noise map and
        # its one-pixel shifts, at every scale down to 8x8.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)

        # Penalty on Z: -0.5 * sum(1 - z^2 + log(z^2 + eps)) over all entries.
        loss_kl = (-0.5 * torch.sum(1 - z_opt.pow(2) + torch.log(1e-8 + z_opt.pow(2)), dim=1)).sum()

        loss = lambda_lpips * dist + lambda_l1 * l1_loss + reg_loss * regularize_noise_weight + lambda_kl * loss_kl

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if verbose:
            # Only format (and thereby sync with the device) when the line is printed.
            print(f'step {step+1:>4d}/{num_steps}: dist {float(dist):<4.2f} reg {float(reg_loss):<5.2f} l1 {float(l1_loss):<5.2f} kl {float(loss_kl):<5.2f} loss {float(loss):<5.2f} lr {lr:.3f}')

        # Save projected W for each optimization step.
        w_out[step] = ws.detach()[0]

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    def ws_with_attr_set(index):
        # W stack of the final (noise-free) Z with one attribute column forced to 1.
        intervene = target_attr.clone()
        intervene[0][index] = 1
        return _per_layer_ws(G, z_opt, intervene)

    with torch.no_grad():
        # NOTE(review): the meaning of columns 21 / 87 / 68 depends on the label
        # file; the "must" / "goat" names come from the original code.
        latest_ws = ws_with_attr_set(21)
        latest_ws_must = ws_with_attr_set(87)
        latest_ws_goat = ws_with_attr_set(68)

    return w_out, latest_ws, z_opt, latest_ws_must, latest_ws_goat

#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
@click.option('--num-steps',              help='Number of optimization steps', type=int, default=1000, show_default=True)
@click.option('--seed',                   help='Random seed', type=int, default=303, show_default=True)
@click.option('--save-video',             help='Save an mp4 video of optimization progress', type=bool, default=False, show_default=True)
@click.option('--outdir',                 help='Where to save the output images', default='out', metavar='DIR')
@click.option('--lambda_l1',              help='Weight of the pixel (L1 or MSE) loss', default=0, type=float, show_default=True)
@click.option('--lambda_lpips',           help='Weight of the VGG16 feature-distance loss', default=1, type=float, show_default=True)
@click.option('--lambda_kl',              help='Weight of the penalty on the optimized Z', default=0, type=float, show_default=True)
@click.option('--mse',                    help='Pixel loss type: non-zero for MSE, 0 for L1', default=1, type=int, show_default=True)
@click.option('--labels', 'labels_fname', help='CSV with the image id in the first column and attributes in the rest', default='ffhq_aug_labels_boss300.csv', metavar='FILE', show_default=True)
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    num_steps: int,
    lambda_l1: float,
    lambda_lpips: float,
    lambda_kl: float,
    mse: int,
    labels_fname: str = 'ffhq_aug_labels_boss300.csv',
):
    """Project given image to the latent space of pretrained network pickle.

    Examples:

    \b
    python projector.py --outdir=out --target=~/mytargetimg.png \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # Look up the attribute row of the target: the CSV's first column is matched
    # against the file name (including its extension).
    target_name = os.path.basename(target_fname)
    labels = pd.read_csv(labels_fname).to_numpy()
    ids, values = labels[:, 0], labels[:, 1:]
    target_attr = values[ids == target_name]
    print('>?>>>>>>>>> YESS   ', target_attr.shape)
    if len(target_attr) == 0:
        raise ValueError('Target not found in the attribute file')
    target_attr = torch.tensor(target_attr.astype(np.float32)).to(device).view(1,-1)

    # Load target image: center-crop to a square, resize to the generator resolution.
    target_pil = PIL.Image.open(target_fname).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)
    target_image = torch.tensor(target_uint8.transpose([2, 0, 1]), device=device) # pylint: disable=not-callable

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps, latest_ws, z_opt, latest_ws_must, latest_ws_goat = project(
        G,
        target_attr=target_attr,
        target=target_image,
        num_steps=num_steps,
        device=device,
        lambda_l1=lambda_l1,
        lambda_lpips=lambda_lpips,
        lambda_kl=lambda_kl,
        mse=mse,
        verbose=True
    )
    print (f'Elapsed: {(perf_counter()-start_time):.1f} s')

    # Render debug output: optional video and projected image and W vector.
    os.makedirs(outdir, exist_ok=True)
    postfix = f'_mse{mse}_l{lambda_l1}_lpips{lambda_lpips}_kl{lambda_kl}'
    if save_video:
        video_path = f'{outdir}/proj_{postfix}.mp4'
        video = imageio.get_writer(video_path, mode='I', fps=10, codec='libx264', bitrate='16M')
        print (f'Saving optimization progress video "{video_path}"')
        try:
            with torch.no_grad():
                for step_w in projected_w_steps:
                    frame = G.synthesis(step_w.unsqueeze(0), noise_mode='const')
                    frame = (frame + 1) * (255/2)
                    frame = frame.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
                    # Target on the left, current reconstruction on the right.
                    video.append_data(np.concatenate([target_uint8, frame], axis=1))
        finally:
            # Close the writer even if rendering a frame fails.
            video.close()

    # Save final projected frame and W vector.
    target_pil.save(f'{outdir}/target.png')
    projected_w = projected_w_steps[-1]

    def render(ws):
        # Synthesize and rescale from [-1,1] to [0,1] for save_image().
        return (G.synthesis(ws, noise_mode='const') + 1) / 2

    with torch.no_grad():
        synth_image = render(projected_w.unsqueeze(0))
        inter_image = render(latest_ws)
        inter2_image = render(latest_ws_must)
        inter3_image = render(latest_ws_goat)

    # One row: target, reconstruction, then the three attribute interventions.
    images = torch.stack([target_image/255, synth_image[0], inter_image[0], inter2_image[0], inter3_image[0]], dim=0)
    torchvision.utils.save_image(images, f'{outdir}/proj_{target_name}_{postfix}.png', nrow=5, normalize=False)

    # The Z file is keyed by the file name up to its first dot.
    target_stem = target_name.split('.')[0]
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
    np.savez(f'{outdir}/z_opt_{target_stem}_{mse}_{lambda_l1}_{lambda_lpips}_{lambda_kl}.npz', z=z_opt.unsqueeze(0).detach().cpu().numpy())

#----------------------------------------------------------------------------

if __name__ == '__main__':
    # Script entry point: click reads the command line and supplies every
    # parameter of the command, hence the call without arguments.
    run_projection()  # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
