"""
StarGAN v2
Copyright (c) 2020-present NAVER Corp.

This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

import os
from os.path import join as ospj
import json
import glob
from shutil import copyfile

from tqdm import tqdm
import ffmpeg

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.utils as vutils
import pdb

from metrics.lpips import calculate_lpips_given_images
from core.data_loader import get_eval_loader


def save_json(json_file, filename):
    """Write the JSON-serializable object ``json_file`` to ``filename``.

    The output is pretty-printed with 4-space indentation and keys keep
    their insertion order (no sorting). An existing file is overwritten.
    """
    serialized = json.dumps(json_file, indent=4, sort_keys=False)
    with open(filename, 'w') as handle:
        handle.write(serialized)


def print_network(network, name):
    """Print the total element count of all parameters of ``network``.

    Args:
        network: any object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).
        name: label used in the printed message.
    """
    total = sum(param.numel() for param in network.parameters())
    print("Number of parameters of %s: %i" % (name, total))


def he_init(module):
    """He (Kaiming-normal, fan_in, ReLU gain) initialization for a single module.

    Only ``nn.Conv2d`` and ``nn.Linear`` are touched: their weights are
    redrawn in place and their biases (if present) are zeroed. Any other
    module type is left unchanged.
    """
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)


def denormalize(x):
    """Map values from [-1, 1] to [0, 1], clamping anything outside that range.

    A new tensor is returned; the input ``x`` is not modified.
    """
    return x.add(1).div(2).clamp_(0, 1)


def save_image(x, ncol, filename):
    """Denormalize ``x`` to [0, 1] and save it as an image grid.

    Args:
        x: image tensor in the [-1, 1] range (a batch or a single image).
        ncol: number of images per grid row.
        filename: destination path; the format follows torchvision's
            ``save_image`` handling of the extension.
    """
    grid_source = denormalize(x).cpu()
    vutils.save_image(grid_source, filename, nrow=ncol, padding=0)


@torch.no_grad()
def translate_and_reconstruct(nets, args, x_src, y_src, x_ref, y_ref, filename):
    """Save a cycle-consistency grid: source, reference, translation, reconstruction.

    ``x_src`` is translated with the style extracted from ``(x_ref, y_ref)``;
    the result is translated back using the style of ``(x_src, y_src)``.
    The four batches are stacked row-wise and written with ``N`` images per
    row, where ``N`` is the batch size of ``x_src`` (assumes ``x_ref`` has
    the same batch size so rows line up — TODO confirm at call sites).
    Facial-landmark masks are only computed when ``args.w_hpf > 0``.
    """
    batch, _, _, _ = x_src.size()
    use_masks = args.w_hpf > 0

    # Forward: source -> reference style.
    style_ref = nets.style_encoder(x_ref, y_ref)
    fwd_masks = nets.fan.get_heatmap(x_src) if use_masks else None
    x_fake = nets.generator(x_src, style_ref, masks=fwd_masks)

    # Backward: fake -> original source style.
    style_src = nets.style_encoder(x_src, y_src)
    back_masks = nets.fan.get_heatmap(x_fake) if use_masks else None
    x_rec = nets.generator(x_fake, style_src, masks=back_masks)

    save_image(torch.cat([x_src, x_ref, x_fake, x_rec], dim=0), batch, filename)

@torch.no_grad()
def translate_using_latent_yaxing_interpolation(nets, args, x_src, y_trg_list, z_trg_list,  index_cate, path, result_dir,
                                                src_offset=103, trg_offset=147, num_steps=10):
    """Save style-interpolation sequences for every (label, latent code) pair.

    For each label tensor ``y_trg`` in ``y_trg_list`` and each latent ``z_trg``
    in ``z_trg_list``, two styles are produced by the mapping network from
    the labels ``y_trg + src_offset`` and ``y_trg + trg_offset``. The generator
    is then run on ``x_src`` with ``num_steps`` linear blends of these styles.
    Every generated image is written to
    ``<result_dir>/<index_cate[y_trg[0]]>/<stem>_<step><ext>``, where
    ``stem``/``ext`` come from the basename of the matching entry in ``path``.

    Note that the blend weight toward the ``trg_offset`` style is
    ``step / num_steps`` for ``step in range(num_steps)``, so the last frame
    does not reach the pure target style (this matches the original loop).
    Later latent codes overwrite files of earlier ones with the same name.

    Args:
        nets: object exposing ``mapping_network``, ``generator`` and ``fan``.
        args: config; only ``args.w_hpf`` is read.
        x_src: batch of source images; its first dimension is the batch size.
        y_trg_list: label tensors; ``y_trg[0]`` indexes ``index_cate``.
        z_trg_list: latent codes fed to the mapping network.
        index_cate: maps a domain index to a category name (output sub-dir).
        path: per-image source file paths; only basenames are used.
        result_dir: root output directory.
        src_offset, trg_offset: label offsets for the two end-point styles.
            NOTE(review): the meaning of the defaults 103/147 is not visible
            here — presumably they select specific domains; confirm.
        num_steps: number of interpolation steps per latent code.
    """
    batch = x_src.size(0)
    masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None

    for y_trg in y_trg_list:
        category = index_cate[int(y_trg[0].cpu().numpy())]
        print('Category: %s' % category)
        out_dir = os.path.join(result_dir, category)
        os.makedirs(out_dir, exist_ok=True)

        for z_trg in z_trg_list:
            s_trg1 = nets.mapping_network(z_trg, y_trg + src_offset)
            s_trg2 = nets.mapping_network(z_trg, y_trg + trg_offset)
            for step in range(num_steps):
                weight = step * 1. / num_steps
                s_trg = s_trg2 * weight + s_trg1 * (1.0 - weight)
                x_fake = nets.generator(x_src, s_trg, masks=masks)
                for img_index in range(batch):
                    # splitext handles names with several dots or no extension,
                    # which the former split('.')[0]/[1] approach did not.
                    stem, ext = os.path.splitext(os.path.basename(path[img_index]))
                    img_name = '%s_%d%s' % (stem, step, ext)
                    save_image(x_fake[img_index], 1, os.path.join(out_dir, img_name))

@torch.no_grad()
def translate_using_latent_yaxing(nets, args, x_src, y_trg_list, z_trg_list,  index_cate, path, result_dir, sampling_index, LPIPS_latent=None, LPIPS_refer=None):
    """Generate evaluation images and accumulate LPIPS diversity scores.

    The mode is selected by ``LPIPS_refer``:

    * Reference-guided (``LPIPS_refer is not None``): for every domain index
      ``0 .. args.num_domains - 1``, ``args.num_outs_per_domain`` reference
      images are drawn one at a time (cycling when exhausted) from
      ``<args.ref_dir>/<category>``. Each reference is encoded with the style
      encoder and applied to the whole ``x_src`` batch. Outputs go to
      ``<result_dir>/ref_output/<category>/``; when ``sampling_index == 0``,
      up to 32 inputs per reference are also written to ``ref_diversity``.
    * Latent-guided (otherwise): every ``z_trg`` in ``z_trg_list`` is mapped
      to a style for each ``y_trg`` in ``y_trg_list``. Only the first latent
      code's outputs are saved to ``latent_output`` (plus the inputs to
      ``<result_dir>/input`` for the first label). When ``sampling_index == 0``,
      all latent codes' outputs (up to 32 inputs) are saved to
      ``latent_diversity``.

    Output file names are derived from the basename of ``path[i]`` with its
    last 4 characters (assumed to be a 3-letter extension) removed.

    Returns:
        The updated ``LPIPS_refer`` or ``LPIPS_latent`` dict. One
        ``calculate_lpips_given_images`` result is appended per category per
        call, so the dict must already hold a list for every category name.
        In latent mode ``LPIPS_latent`` must therefore not be ``None``.
    """
    # Use the inputs' device so labels and references live next to x_src
    # (and next to the networks that must already process x_src).
    device = x_src.device
    N = x_src.size(0)
    masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None

    if LPIPS_refer is not None:
        for class_index in range(args.num_domains):
            y_trg = torch.tensor(class_index).repeat(N).to(device)
            target_cat = index_cate[class_index]
            path_ref = os.path.join(args.ref_dir, target_cat)
            # Few references exist per category, so they are loaded with
            # batch size 1 and replicated across the source batch.
            loader_ref = get_eval_loader(root=path_ref,
                                         img_size=args.img_size,
                                         batch_size=1,
                                         imagenet_normalize=False,
                                         drop_last=False)
            # A fresh iterator per domain: reusing the previous domain's
            # iterator would feed references from the wrong category.
            iter_ref = iter(loader_ref)

            out_dir = os.path.join(result_dir, 'ref_output', target_cat)
            os.makedirs(out_dir, exist_ok=True)

            group_of_images = []
            for j in range(args.num_outs_per_domain):
                x_ref = next(iter_ref, None)
                if x_ref is None:
                    # Wrap around when the references run out.
                    iter_ref = iter(loader_ref)
                    x_ref = next(iter_ref)
                x_ref = torch.cat(N * [x_ref.to(device)])
                s_trg = nets.style_encoder(x_ref, y_trg)
                x_fake = nets.generator(x_src, s_trg, masks=masks)
                group_of_images.append(x_fake)

                # One output per input image per reference (index j in the name).
                for img_index in range(N):
                    img_name = path[img_index].split('/')[-1]
                    filename = os.path.join(out_dir, img_name[:-4] + '_To_' + target_cat + '_%d_%d.jpg' % (sampling_index, j))
                    save_image(x_fake[img_index], 1, filename)

            # Diversity images are only written for the first sampling round.
            if sampling_index == 0:
                div_dir = os.path.join(result_dir, 'ref_diversity', target_cat)
                os.makedirs(div_dir, exist_ok=True)
                for noise_index, x_fake in enumerate(group_of_images):
                    for img_index in range(min(x_fake.size(0), 32)):
                        img_name = path[img_index].split('/')[-1]
                        filename = os.path.join(div_dir, img_name[:-4] + '_To_' + target_cat + '_%d_.jpg' % (noise_index))
                        save_image(x_fake[img_index], 1, filename)

            LPIPS_refer[target_cat] += [calculate_lpips_given_images(group_of_images)]
        return LPIPS_refer

    for i, y_trg in enumerate(y_trg_list):
        target_cat = index_cate[int(y_trg[0].cpu().numpy())]
        print('Category: %s' % target_cat)
        group_of_images = []
        for z_index, z_trg in enumerate(z_trg_list):
            s_trg = nets.mapping_network(z_trg, y_trg)
            x_fake = nets.generator(x_src, s_trg, masks=masks)
            group_of_images.append(x_fake)
            if z_index != 0:
                continue

            # Only the first latent code's output is saved per input (for FID).
            out_dir = os.path.join(result_dir, 'latent_output', target_cat)
            os.makedirs(out_dir, exist_ok=True)
            for img_index in range(N):
                img_name = path[img_index].split('/')[-1]
                filename = os.path.join(out_dir, img_name[:-4] + '_To_' + target_cat + '_%d_.jpg' % (sampling_index))
                save_image(x_fake[img_index], 1, filename)

            # The source images are saved once, alongside the first label.
            if i == 0:
                save_input_path = os.path.join(result_dir, 'input')
                os.makedirs(save_input_path, exist_ok=True)
                for img_index in range(N):
                    img_name = path[img_index].split('/')[-1]
                    save_image(x_src[img_index], 1, os.path.join(save_input_path, img_name))

        # Diversity images are only written for the first sampling round.
        if sampling_index == 0:
            div_dir = os.path.join(result_dir, 'latent_diversity', target_cat)
            os.makedirs(div_dir, exist_ok=True)
            for noise_index, x_fake in enumerate(group_of_images):
                for img_index in range(min(N, 32)):
                    img_name = path[img_index].split('/')[-1]
                    filename = os.path.join(div_dir, img_name[:-4] + '_To_' + target_cat + '_%d_.jpg' % (noise_index))
                    save_image(x_fake[img_index], 1, filename)

        LPIPS_latent[target_cat] += [calculate_lpips_given_images(group_of_images)]
    return LPIPS_latent

                



@torch.no_grad()
def translate_using_latent(nets, args, x_src, y_trg_list, z_trg_list, psi, filename):
    """Save a latent-guided synthesis grid with style truncation ``psi``.

    For each label tensor in ``y_trg_list``, an average style is estimated by
    mapping 10000 random latent codes. Every ``z_trg`` in ``z_trg_list`` is
    then mapped to a style, and that style is interpolated toward the average
    with ``torch.lerp(s_avg, s_trg, psi)`` (``psi = 1`` keeps the raw style).
    The first grid row is ``x_src``; each (label, latent) pair adds one row
    of translations. Rows hold ``N = x_src.size(0)`` images.
    """
    batch, _, _, _ = x_src.size()
    latent_dim = z_trg_list[0].size(1)
    device = x_src.device
    masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None

    rows = [x_src]
    for y_trg in y_trg_list:
        z_many = torch.randn(10000, latent_dim).to(device)
        y_many = torch.LongTensor(10000).to(device).fill_(y_trg[0])
        s_avg = nets.mapping_network(z_many, y_many).mean(dim=0, keepdim=True).repeat(batch, 1)

        for z_trg in z_trg_list:
            s_trg = torch.lerp(s_avg, nets.mapping_network(z_trg, y_trg), psi)
            rows.append(nets.generator(x_src, s_trg, masks=masks))

    save_image(torch.cat(rows, dim=0), batch, filename)


@torch.no_grad()
def translate_using_reference(nets, args, x_src, x_ref, y_ref, filename):
    """Save a reference-guided synthesis grid with ``N + 1`` columns.

    The first row is an all-ones placeholder tile followed by the ``N``
    source images. Each following row starts with one reference image and
    continues with all sources translated using that reference's style.
    """
    batch, channels, height, width = x_src.size()
    # All-ones tile for the top-left corner (white after denormalization).
    corner = torch.ones(1, channels, height, width).to(x_src.device)

    masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None
    # One style vector per reference, replicated across the source batch.
    styles = nets.style_encoder(x_ref, y_ref).unsqueeze(1).repeat(1, batch, 1)

    grid = [torch.cat([corner, x_src], dim=0)]
    for idx in range(styles.size(0)):
        fakes = nets.generator(x_src, styles[idx], masks=masks)
        grid.append(torch.cat([x_ref[idx:idx + 1], fakes], dim=0))

    save_image(torch.cat(grid, dim=0), batch + 1, filename)


@torch.no_grad()
def debug_image(nets, args, inputs, step):
    """Write the sample grids for training step ``step`` into ``args.sample_dir``.

    Files produced:
      * ``%06d_cycle_consistency.jpg``: reference-guided translation plus
        reconstruction;
      * ``%06d_latent_psi_<psi>.jpg`` for psi in 0.5, 0.7 and 1.0:
        latent-guided synthesis into the first ``min(num_domains, 5)``
        domains, reusing the same latent codes for every psi;
      * ``%06d_reference.jpg``: reference-guided synthesis grid.
    """
    x_src, y_src = inputs.x_src, inputs.y_src
    x_ref, y_ref = inputs.x_ref, inputs.y_ref
    device = x_src.device
    batch = x_src.size(0)

    # translate and reconstruct (reference-guided)
    cycle_path = ospj(args.sample_dir, '%06d_cycle_consistency.jpg' % (step))
    translate_and_reconstruct(nets, args, x_src, y_src, x_ref, y_ref, cycle_path)

    # latent-guided image synthesis
    domains = range(min(args.num_domains, 5))
    y_trg_list = [torch.tensor(d).repeat(batch).to(device) for d in domains]
    z_trg_list = (torch.randn(args.num_outs_per_domain, 1, args.latent_dim)
                  .repeat(1, batch, 1)
                  .to(device))
    for psi in (0.5, 0.7, 1.0):
        latent_path = ospj(args.sample_dir, '%06d_latent_psi_%.1f.jpg' % (step, psi))
        translate_using_latent(nets, args, x_src, y_trg_list, z_trg_list, psi, latent_path)

    # reference-guided image synthesis
    reference_path = ospj(args.sample_dir, '%06d_reference.jpg' % (step))
    translate_using_reference(nets, args, x_src, x_ref, y_ref, reference_path)


# ======================= #
# Video-related functions #
# ======================= #


def sigmoid(x, w=1):
    """Logistic function ``1 / (1 + exp(-w * x))``; accepts scalars or NumPy arrays."""
    decay = np.exp(-w * x)
    return 1. / (1 + decay)


def get_alphas(start=-5, end=5, step=0.5, len_tail=10):
    """Build the per-frame blend weights used for video transitions.

    The list starts with ``0``, continues with ``sigmoid`` evaluated on
    ``np.arange(start, end, step)`` (an S-shaped ramp), and ends with
    ``len_tail`` copies of ``1`` so the final state is held for a while.
    """
    ramp = [sigmoid(point) for point in np.arange(start, end, step)]
    return [0, *ramp, *([1] * len_tail)]


def interpolate(nets, args, x_src, s_prev, s_next):
    """Render one frame per blend weight from ``get_alphas()``.

    For each alpha, the style ``torch.lerp(s_prev, s_next, alpha)`` is applied
    to ``x_src``. Each source image is stacked above its translation (along
    height), and the ``B`` pairs are laid out side by side with ``make_grid``.

    Returns:
        A CPU tensor of shape ``(T, C, H', W')``, where ``T = len(get_alphas())``.
    """
    batch = x_src.size(0)
    masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None

    frame_list = []
    for alpha in get_alphas():
        blended_style = torch.lerp(s_prev, s_next, alpha)
        x_fake = nets.generator(x_src, blended_style, masks=masks)
        pairs = torch.cat([x_src.cpu(), x_fake.cpu()], dim=2)
        grid = torchvision.utils.make_grid(pairs, nrow=batch, padding=0, pad_value=-1)
        frame_list.append(grid.unsqueeze(0))
    return torch.cat(frame_list)


def slide(entries, margin=32):
    """Return the sliding reference-image animation shown next to the video.

    Args:
        entries: a list of two reference images, ``x_prev`` and ``x_next``,
            each of shape ``(1, C, H, W)`` (e.g. ``(1, 3, 256, 256)``).
        margin: extra columns on the right of the canvas, left at -1.

    Returns:
        canvas: tensor of shape ``(T, C, 2*H, W + margin)`` initialised to -1,
        where ``T = len(get_alphas())``. At alpha 0, ``x_prev`` fills the lower
        half. As alpha grows, the stacked pair shifts upward, and at alpha 1
        ``x_prev`` sits on top with ``x_next`` below it.
    """
    _, channels, height, width = entries[0].shape
    alphas = get_alphas()

    canvas = -torch.ones((len(alphas), channels, height * 2, width + margin))
    stacked = torch.cat(entries, dim=2)  # x_prev above x_next: (1, C, 2H, W)
    for frame_idx, alpha in enumerate(alphas):
        offset = int(height * (1 - alpha))  # first canvas row covered this frame
        canvas[frame_idx, :, offset:, :width] = stacked[:, :, :2 * height - offset, :]
    return canvas


@torch.no_grad()
def video_ref(nets, args, x_src, x_ref, y_ref, fname):
    """Render a reference-guided style-interpolation video and save it to ``fname``.

    The references are walked in order. For each consecutive pair that
    shares a domain label, the styles are interpolated (``interpolate``) and
    the frames are placed next to a sliding view of the two references
    (``slide``). A change of label only resets the starting reference. The
    last frame is repeated 10 extra times.

    Raises:
        ValueError: if no two consecutive references share a domain label,
            because then there are no frames to render.
    """
    video = []
    s_ref = nets.style_encoder(x_ref, y_ref)
    s_prev = None
    for data_next in tqdm(zip(x_ref, y_ref, s_ref), 'video_ref', len(x_ref)):
        x_next, y_next, s_next = [d.unsqueeze(0) for d in data_next]
        # The first reference, or one from a new domain, only becomes the
        # starting point; no transition is rendered across domains.
        if s_prev is None or y_prev != y_next:
            x_prev, y_prev, s_prev = x_next, y_next, s_next
            continue

        interpolated = interpolate(nets, args, x_src, s_prev, s_next)
        entries = [x_prev, x_next]
        slided = slide(entries)  # (T, C, 256*2, 256)
        frames = torch.cat([slided, interpolated], dim=3).cpu()  # (T, C, 256*2, 256*(batch+1))
        video.append(frames)
        x_prev, y_prev, s_prev = x_next, y_next, s_next

    if not video:
        # Previously this surfaced as an UnboundLocalError on `frames`.
        raise ValueError('video_ref: no two consecutive references share a domain label; nothing to render')

    # append last frame 10 time
    for _ in range(10):
        video.append(frames[-1:])
    video = tensor2ndarray255(torch.cat(video))
    save_video(fname, video)


@torch.no_grad()
def video_latent(nets, args, x_src, y_list, z_list, psi, fname):
    """Render a latent-guided style-interpolation video and save it to ``fname``.

    For each label tensor in ``y_list``, an average style is estimated from
    10000 random latent codes. Each ``z`` in ``z_list`` is mapped to a style,
    which is pulled toward that average with ``torch.lerp(s_avg, s, psi)``.
    The video interpolates between consecutive styles of the same domain, and
    no transition is rendered across a domain boundary. The last frame is
    repeated 10 extra times.

    Raises:
        ValueError: if no frames can be produced, i.e. ``y_list`` is empty
            or ``z_list`` holds fewer than two latent codes.
    """
    latent_dim = z_list[0].size(1)
    batch = x_src.size(0)
    s_list = []
    for y_trg in y_list:
        z_many = torch.randn(10000, latent_dim).to(x_src.device)
        y_many = torch.LongTensor(10000).to(x_src.device).fill_(y_trg[0])
        s_many = nets.mapping_network(z_many, y_many)
        s_avg = torch.mean(s_many, dim=0, keepdim=True).repeat(batch, 1)

        for z_trg in z_list:
            s_trg = nets.mapping_network(z_trg, y_trg)
            s_list.append(torch.lerp(s_avg, s_trg, psi))

    s_prev = None
    video = []
    for idx_ref, s_next in enumerate(tqdm(s_list, 'video_latent', len(s_list))):
        # s_list holds len(z_list) styles per domain. The first style of each
        # domain only becomes the new starting point.
        if s_prev is None or idx_ref % len(z_list) == 0:
            s_prev = s_next
            continue
        frames = interpolate(nets, args, x_src, s_prev, s_next).cpu()
        video.append(frames)
        s_prev = s_next

    if not video:
        # Previously this surfaced as an UnboundLocalError on `frames`.
        raise ValueError('video_latent needs a non-empty y_list and at least two latent codes in z_list')

    for _ in range(10):
        video.append(frames[-1:])
    video = tensor2ndarray255(torch.cat(video))
    save_video(fname, video)


def save_video(fname, images, output_fps=30, vcodec='libx264', filters=''):
    """Encode a stack of frames to a video file by piping raw bytes into ffmpeg.

    Args:
        fname: output path; an existing file is overwritten.
        images: ``np.ndarray`` of shape ``(N, H, W, C)``. Each frame is cast
            to ``uint8`` and fed to ffmpeg as ``rgb24`` raw video.
        output_fps: output frame rate passed to ffmpeg as ``r``.
        vcodec: ffmpeg video codec name.
        filters: unused; kept for interface compatibility.

    A ``setpts=2*PTS`` filter doubles every timestamp, which slows playback
    down; the output uses ``pix_fmt='yuv420p'``.
    """
    assert isinstance(images, np.ndarray), "images should be np.array: NHWC"
    _, height, width, _ = images.shape
    process = (
        ffmpeg
        .input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width, height))
        .filter('setpts', '2*PTS')  # 2*PTS is for slower playback
        .output(fname, pix_fmt='yuv420p', vcodec=vcodec, r=output_fps)
        .overwrite_output()
        .run_async(pipe_stdin=True)
    )
    pipe = process.stdin
    for frame in tqdm(images, desc='writing video to %s' % fname):
        pipe.write(frame.astype(np.uint8).tobytes())
    pipe.close()
    process.wait()


def tensor2ndarray255(images):
    """Convert an ``(N, C, H, W)`` tensor in [-1, 1] to an ``(N, H, W, C)`` array in [0, 255].

    Values are mapped to [0, 1] and clamped before scaling. The result is a
    float NumPy array (no integer cast happens here).
    """
    unit_range = (images * 0.5 + 0.5).clamp(0, 1)
    channels_last = unit_range.cpu().numpy().transpose(0, 2, 3, 1)
    return channels_last * 255
