﻿# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch
from torch_utils import training_stats
from torch_utils import misc
from torch_utils.ops import conv2d_gradfix

#----------------------------------------------------------------------------

class Loss:
    """Interface for training-loss objects.

    A concrete loss implements `accumulate_gradients`, which performs the
    forward pass(es) for one training phase and calls `.backward()` so that
    gradients accumulate on the networks involved.
    """

    # NOTE(review): StyleGAN2Loss overrides this with a different parameter
    # list (phase, real_img, gen_ec, gen_es, sync, gain, lambda_sparse, gstep).
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain, lambda_sparse): # to be overridden by subclass
        """Accumulate gradients for `phase`; must be provided by a subclass."""
        raise NotImplementedError

#----------------------------------------------------------------------------

class StyleGAN2Loss(Loss):
    """StyleGAN2-style loss for clip generation with image and video discriminators.

    The generator (`G_mapping` followed by `G_synthesis`) produces clips of
    `vid_length` frames. `D` scores individual frames; `D_v` scores whole clips
    whose frames are concatenated along the channel axis. The adversarial terms
    are the non-saturating softplus losses. Optional regularizers are R1 on both
    discriminators and path-length regularization on the generator. A sparsity
    term returned by `G_mapping` is added to the generator's video loss.
    """

    def __init__(self, device, G_mapping, G_synthesis, D, D_v, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2, vid_length=16, img_channels=3, image_loss_weight=0.3):
        """
        Args:
            device:            Device on which the running path-length mean is allocated.
            G_mapping:         Mapping network, called as G_mapping(ec, es[, sparse_loss=True]).
            G_synthesis:       Synthesis network, called as G_synthesis(ws).
            D:                 Frame-level (image) discriminator.
            D_v:               Clip-level (video) discriminator.
            augment_pipe:      Optional augmentation module applied per clip in run_D.
            style_mixing_prob: Stored only; run_G does not perform style mixing.
            r1_gamma:          R1 weight; 0 disables R1 for both discriminators.
            pl_batch_shrink:   Batch divisor for the path-length regularization pass.
            pl_decay:          Interpolation rate for the running path-length mean.
            pl_weight:         Path-length regularization weight; 0 disables it.
            vid_length:        Number of frames per generated clip.
            img_channels:      Number of channels per frame.
            image_loss_weight: Multiplier on the image-discriminator terms (G's frame
                               loss, D's fake loss and D's real/R1 loss). Default 0.3.
        """
        super().__init__()
        self.device = device
        self.G_mapping = G_mapping
        self.G_synthesis = G_synthesis
        self.D = D
        self.D_v = D_v
        self.augment_pipe = augment_pipe
        self.style_mixing_prob = style_mixing_prob  # not used by run_G
        self.r1_gamma = r1_gamma
        self.pl_batch_shrink = pl_batch_shrink
        self.pl_decay = pl_decay
        self.pl_weight = pl_weight
        self.pl_mean = torch.zeros([], device=device)  # running mean of path lengths (0-dim tensor)
        self.vid_length = vid_length
        self.img_channels = img_channels
        self.image_loss_weight = image_loss_weight
        # Parameters of the optional step-based schedule in get_decay_value().
        self.decay_rate = 0.99
        self.decay_step = 100

    def get_decay_value(self, global_step):
        """Return decay_rate ** (global_step // decay_step), floored at 0.1.

        Not called by accumulate_gradients(), which uses the constant
        `image_loss_weight` instead.
        """
        return max(0.1, self.decay_rate ** (global_step // self.decay_step))

    def run_G(self, ec, es, sync, sparse_loss=False):
        """Generate clips from generator inputs `ec` and `es`.

        Args:
            ec, es:      Inputs forwarded unchanged to G_mapping.
            sync:        Whether DDP gradient sync is enabled for G_mapping/G_synthesis.
            sparse_loss: If True, also request and return G_mapping's sparsity loss.

        Returns:
            (img, ws) or (img, ws, loss_sparse). `img` is reshaped to
            (-1, vid_length, img_channels, H, W), with H and W taken from dims 3
            and 4 of the synthesis output.
        """
        with misc.ddp_sync(self.G_mapping, sync):
            if sparse_loss:
                ws, loss_sparse = self.G_mapping(ec, es, sparse_loss=True)
            else:
                ws = self.G_mapping(ec, es)
        with misc.ddp_sync(self.G_synthesis, sync):
            img = self.G_synthesis(ws)
            img = img.view(-1, self.vid_length, self.img_channels, img.shape[3], img.shape[4])
        if sparse_loss:
            return img, ws, loss_sparse
        return img, ws

    def run_D(self, img, sync, mode):
        """Score clips with the frame discriminator, the clip discriminator, or both.

        Args:
            img:  Clips of shape (B, T, C, H, W).
            sync: Whether DDP gradient sync is enabled for the discriminator(s) used.
            mode: 'image' -> D on frames reshaped to (B*T, C, H, W);
                  'video' -> D_v on clips reshaped to (B, T*C, H, W);
                  'both'  -> both.

        Returns:
            logits, logits_v, or (logits, logits_v), depending on `mode`.

        Raises:
            ValueError: if `mode` is not one of the three values above.
        """
        if mode not in ('image', 'video', 'both'):
            raise ValueError(f"run_D: unknown mode {mode!r}; expected 'image', 'video' or 'both'")

        # Unpack unconditionally so the reshapes below also work without augmentation.
        B, T, C, H, W = img.shape
        if self.augment_pipe is not None:
            # Augment each clip on its own; batchsame=True presumably applies one
            # augmentation to all frames of the clip -- confirm against augment_pipe.
            img = torch.stack([self.augment_pipe(img[b], batchsame=True) for b in range(B)], dim=0)

        logits = logits_v = None
        if mode in ('image', 'both'):
            # Each discriminator gets its own ddp_sync context so sync=False is honoured for both.
            with misc.ddp_sync(self.D, sync):
                logits = self.D(img.reshape(-1, C, H, W))
        if mode in ('video', 'both'):
            with misc.ddp_sync(self.D_v, sync):
                logits_v = self.D_v(img.reshape(B, -1, H, W))

        if mode == 'image':
            return logits
        if mode == 'video':
            return logits_v
        return logits, logits_v

    def accumulate_gradients(self, phase, real_img, gen_ec, gen_es, sync, gain, lambda_sparse, gstep):
        """Run the forward and backward passes for one training phase.

        Gradients accumulate via .backward(); nothing is returned.

        Args:
            phase:          One of 'Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth',
                            'D_vmain', 'D_vreg', 'D_vboth'. The '*main' phases apply the
                            adversarial loss. The '*reg' phases apply the regularizer
                            (path length for G, R1 for D/D_v). The '*both' phases do both.
            real_img:       Real clips, (B, T, C, H, W) as consumed by run_D.
            gen_ec, gen_es: Generator inputs forwarded to G_mapping; sliced along dim 0
                            for the path-length pass.
            sync:           Whether DDP gradient sync is enabled for the final pass.
            gain:           Multiplier applied to every loss before backward.
            lambda_sparse:  Weight of G_mapping's sparsity loss, added to G's video loss.
            gstep:          Global training step; currently unused.
        """
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth', 'D_vmain', 'D_vreg', 'D_vboth']
        do_Gmain   = (phase in ['Gmain', 'Gboth'])
        do_Dmain   = (phase in ['Dmain', 'Dboth'])
        do_D_vmain = (phase in ['D_vmain', 'D_vboth'])
        do_Gpl     = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
        do_Dr1     = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
        do_D_vr1   = (phase in ['D_vreg', 'D_vboth']) and (self.r1_gamma != 0)
        # Weight of the image-discriminator terms. For a step schedule, use
        # self.get_decay_value(gstep) here instead.
        decay_value = self.image_loss_weight

        # Gmain: Maximize logits for generated images (frame- and clip-level).
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws, loss_sparse = self.run_G(gen_ec, gen_es, sync=(sync and not do_Gpl), sparse_loss=True) # May get synced by Gpl.
                gen_logits, gen_logits_v = self.run_D(gen_img, sync=False, mode='both')
                training_stats.report('Loss/ratio/decay_value', decay_value)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/scores/fake_v', gen_logits_v)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                training_stats.report('Loss/signs/fake_v', gen_logits_v.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                loss_Gmain_v = torch.nn.functional.softplus(-gen_logits_v) # -log(sigmoid(gen_logits_v))
                training_stats.report('Loss/G/sparse', loss_sparse)
                # The sparsity term is attached to the video loss only.
                loss_Gmain_v = loss_Gmain_v + lambda_sparse * loss_sparse
                training_stats.report('Loss/G/loss', loss_Gmain)
                training_stats.report('Loss/G/loss_v', loss_Gmain_v)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                (
                    loss_Gmain.mean().mul(gain*decay_value).add(
                    loss_Gmain_v.mean().mul(gain))
                ).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                # Clamp to 1 so a batch smaller than pl_batch_shrink does not yield an empty pass.
                batch_size = max(1, gen_ec.shape[0] // self.pl_batch_shrink)
                gen_img, gen_ws = self.run_G(gen_ec[:batch_size], gen_es[:batch_size], sync=sync)
                # Flatten clips into frames: (B*T, C, H, W).
                gen_img = gen_img.view(-1, gen_img.shape[2], gen_img.shape[3], gen_img.shape[4])
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                # `gen_img[...] * 0` keeps the synthesis output in the graph for backward.
                (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize frame logits for generated clips.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_ec, gen_es, sync=False)
                gen_logits = self.run_D(gen_img, sync=False, mode='image') # Gets synced by loss_Dreal.
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                (loss_Dgen.mean().mul(gain*decay_value)).backward()

        # D_vmain: Minimize clip logits for generated clips.
        loss_Dgen_v = 0
        if do_D_vmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_ec, gen_es, sync=False)
                gen_logits_v = self.run_D(gen_img, sync=False, mode='video') # Gets synced by loss_Dreal.
                training_stats.report('Loss/scores/fake_v', gen_logits_v)
                training_stats.report('Loss/signs/fake_v', gen_logits_v.sign())
                loss_Dgen_v = torch.nn.functional.softplus(gen_logits_v) # -log(1 - sigmoid(gen_logits_v))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                (loss_Dgen_v.mean().mul(gain)).backward()

        # Dmain: Maximize frame logits for real clips.
        # Dr1: Apply R1 regularization to D.
        if do_Dmain or do_Dr1:
            name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, sync=sync, mode='image')
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                    # Per-frame penalty: sum over (C, H, W), flattened to B*T entries.
                    r1_penalty = r1_grads.square().sum([2,3,4]).reshape(-1)
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                # `real_logits * 0` keeps D's output in the graph even when only one term is active.
                ((real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain*decay_value)).backward()

        # D_vmain: Maximize clip logits for real clips.
        # D_vr1: Apply R1 regularization to D_v.
        if do_D_vmain or do_D_vr1:
            name = 'Dreal_D_vr1' if do_D_vmain and do_D_vr1 else 'D_vreal' if do_D_vmain else 'D_vr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp = real_img.detach().requires_grad_(do_D_vr1)
                real_logits_v = self.run_D(real_img_tmp, sync=sync, mode='video')
                training_stats.report('Loss/scores/real_v', real_logits_v)
                training_stats.report('Loss/signs/real_v', real_logits_v.sign())

                loss_Dreal_v = 0
                if do_D_vmain:
                    loss_Dreal_v = torch.nn.functional.softplus(-real_logits_v) # -log(sigmoid(real_logits_v))
                    training_stats.report('Loss/D/loss_v', loss_Dgen_v + loss_Dreal_v)

                loss_Dr1_v = 0
                if do_D_vr1:
                    with torch.autograd.profiler.record_function('r1_grads_v'), conv2d_gradfix.no_weight_gradients():
                        r1_grads_v = torch.autograd.grad(outputs=[real_logits_v.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                    # NOTE(review): also summed per frame (not over T), so the mean below is
                    # over B*T frame-level terms rather than B clip-level terms.
                    r1_penalty_v = r1_grads_v.square().sum([2,3,4]).reshape(-1)
                    loss_Dr1_v = r1_penalty_v * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty_v', r1_penalty_v)
                    training_stats.report('Loss/D/reg_v', loss_Dr1_v)

            with torch.autograd.profiler.record_function(name + '_backward'):
                ((real_logits_v * 0 + loss_Dreal_v + loss_Dr1_v).mean().mul(gain)).backward()
                

#----------------------------------------------------------------------------
