# python3.7
"""Defines loss functions for EpiGRAF training (EG3D-style losses with patch-conditioned models)."""

import numpy as np
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast

from utils.dist_utils import ddp_sync
from .base_loss import BaseLoss

from third_party.stylegan3_official_ops import upfirdn2d
from third_party.stylegan3_official_ops import conv2d_gradfix

# Only the loss class is part of this module's public API.
__all__ = ['EpiGRAFLoss']


class EpiGRAFLoss(BaseLoss):
    """Contains the class to compute EpiGRAF loss.

    The adversarial terms are the non-saturating logistic GAN losses
    (`softplus`). The discriminator can additionally use a lazy R1 penalty on
    real images, and the generator a lazy path length penalty. Every
    generator/discriminator call is also conditioned on patch parameters
    (`patch_params`).
    """

    def __init__(self, runner, d_loss_kwargs=None, g_loss_kwargs=None):
        """Initializes with models and arguments for computing losses.

        Args:
            runner: Training runner. Its `device` and `running_stats` are
                used here.
            d_loss_kwargs: Optional dict with discriminator-loss settings:
                `r1_gamma` (default 10.0), `r1_interval` (default 16),
                `blur_init_sigma` (default 10), `blur_fade_kimg` (default
                200), `filter_mode` (default 'antialiased') and
                `blur_raw_target` (default True).
            g_loss_kwargs: Optional dict with generator-loss settings:
                `pl_batch_shrink` (default 2), `pl_weight` (default 2.0),
                `pl_decay` (default 0.01), `pl_interval` (default 4) and
                `rendering_resolution_patch` (default 64).
        """
        # Setting for discriminator loss.
        self.device = runner.device
        self.d_loss_kwargs = d_loss_kwargs or dict()
        self.r1_gamma = self.d_loss_kwargs.get('r1_gamma', 10.0)
        self.r1_interval = self.d_loss_kwargs.get('r1_interval', 16)
        # A missing gamma, or a missing / non-positive interval, disables the
        # R1 penalty (same handling as the path length penalty below).
        if (self.r1_gamma is None or self.r1_interval is None
                or self.r1_interval <= 0):
            self.r1_interval = 1
            self.r1_gamma = 0.0
        self.r1_interval = int(self.r1_interval)
        assert self.r1_gamma >= 0.0
        self.blur_init_sigma = self.d_loss_kwargs.get('blur_init_sigma', 10)
        self.blur_fade_kimg = self.d_loss_kwargs.get('blur_fade_kimg', 200)
        self.filter_mode = self.d_loss_kwargs.get('filter_mode', 'antialiased')
        self.blur_raw_target = self.d_loss_kwargs.get('blur_raw_target', True)
        self.resample_filter = upfirdn2d.setup_filter([1, 3, 3, 1],
                                                      device=self.device)

        runner.running_stats.add('Loss/D Fake',
                                 log_name='loss_d_fake',
                                 log_format='.3f',
                                 log_strategy='AVERAGE')
        runner.running_stats.add('Loss/D Real',
                                 log_name='loss_d_real',
                                 log_format='.3f',
                                 log_strategy='AVERAGE')
        if self.r1_gamma > 0.0:
            runner.running_stats.add('Loss/Real Gradient Penalty',
                                     log_name='loss_gp',
                                     log_format='.1e',
                                     log_strategy='AVERAGE')

        # Setting for generator loss.
        self.g_loss_kwargs = g_loss_kwargs or dict()
        self.pl_batch_shrink = int(self.g_loss_kwargs.get(
            'pl_batch_shrink', 2))
        self.pl_weight = self.g_loss_kwargs.get('pl_weight', 2.0)
        self.pl_decay = self.g_loss_kwargs.get('pl_decay', 0.01)
        self.pl_interval = self.g_loss_kwargs.get('pl_interval', 4)
        # A missing weight, or a missing / non-positive interval, disables
        # the path length penalty.
        if (self.pl_weight is None or self.pl_interval is None
                or self.pl_interval <= 0):
            self.pl_interval = 1
            self.pl_weight = 0.0
        self.pl_interval = int(self.pl_interval)
        assert self.pl_batch_shrink >= 1
        assert self.pl_weight >= 0.0
        assert 0.0 <= self.pl_decay <= 1.0
        self.rendering_resolution_patch = self.g_loss_kwargs.get(
            'rendering_resolution_patch', 64)

        runner.running_stats.add('Loss/G',
                                 log_name='loss_g',
                                 log_format='.3f',
                                 log_strategy='AVERAGE')
        if self.pl_weight > 0.0:
            runner.running_stats.add('Loss/Path Length Penalty',
                                     log_name='loss_pl',
                                     log_format='.1e',
                                     log_strategy='AVERAGE')
            # Running average of path lengths; only exists when PL is on.
            self.pl_mean = torch.zeros((), device=runner.device)

    @staticmethod
    def _is_reg_iter(iteration, interval):
        """Returns whether lazy regularization should run at `iteration`.

        Fires on iterations 1, 1 + interval, 1 + 2 * interval, ... For
        `interval > 1` this is the same as `iteration % interval == 1`; unlike
        that check, it also fires on every iteration when `interval == 1`.
        """
        return (iteration - 1) % interval == 0

    @staticmethod
    def _slice_batch(value, size):
        """Keeps the first `size` samples along dim 0.

        Tensors with at least one dimension are sliced; dicts are sliced
        value-by-value; anything else is returned unchanged.
        NOTE(review): assumes every batched tensor in `patch_params` is
        batched along dim 0 -- confirm against the generator.
        """
        if torch.is_tensor(value):
            return value[:size] if value.dim() > 0 else value
        if isinstance(value, dict):
            return {key: EpiGRAFLoss._slice_batch(val, size)
                    for key, val in value.items()}
        return value

    @staticmethod
    def run_G(runner,
              label,
              patch_params,
              batch_size=None,
              update_emas=False,
              requires_grad=False,
              sync=True):
        """Forwards generator.

        Args:
            runner: Training runner providing the models, device and config.
            label: Conditioning labels; dim 0 is the batch dimension and must
                match `batch_size`.
            patch_params: Patch parameters, forwarded as-is to the generator.
            batch_size: Number of latent codes to sample. Defaults to
                `runner.batch_size`.
            update_emas: Forwarded to the generator.
            requires_grad: Whether the sampled latent codes require grad.
            sync: Whether DDP gradient synchronization is enabled.

        Returns:
            The generator's output (callers index it with `'image'` and
            `'wp'`).
        """
        batch_size = batch_size or runner.batch_size
        latent_dim = runner.models['generator'].z_dim
        latents = torch.randn((batch_size, latent_dim),
                              device=runner.device,
                              requires_grad=requires_grad)
        G = runner.ddp_models['generator']
        # Conditioning passed as `label_swapped`: each row is selected with
        # probability `gpc_spoof_p`, and the selected rows are cyclically
        # shifted among themselves. The mask is sampled on the label's device
        # so it does not index across devices.
        label_gen_conditioning = label.clone()  # [N, 3]
        camera_spoof_idx = torch.rand(
            label_gen_conditioning.shape[0],
            device=label_gen_conditioning.device
        ) < runner.models['generator'].gpc_spoof_p  # [N]
        label_gen_conditioning[camera_spoof_idx] = label_gen_conditioning[
            camera_spoof_idx].roll(shifts=1, dims=0)

        with ddp_sync(G, sync=sync):
            results = G(latents,
                        label,
                        patch_params,
                        label_swapped=label_gen_conditioning,
                        style_mixing_prob=runner.config.style_mixing_prob,
                        update_emas=update_emas)

        return results

    @staticmethod
    def run_D(runner,
              img,
              label,
              patch_params,
              blur_sigma=0,
              update_emas=False,
              sync=True):
        """Forwards discriminator.

        When `floor(3 * blur_sigma) > 0`, `img` is first filtered with the
        1-D kernel `2 ** -((x / blur_sigma) ** 2)` for integer `x` in
        `[-floor(3 * blur_sigma), floor(3 * blur_sigma)]`, normalized to sum
        to 1 and applied through `upfirdn2d.filter2d`. ADA augmentation is
        then applied if `runner.config.use_ada` is set.

        Returns:
            The discriminator scores.
        """
        D = runner.ddp_models['discriminator']
        blur_size = np.floor(blur_sigma * 3)
        if blur_size > 0:
            f = torch.arange(
                -blur_size, blur_size + 1,
                device=img.device).div(blur_sigma).square().neg().exp2()
            img = upfirdn2d.filter2d(img, f / f.sum())
        if runner.config.use_ada:
            img = runner.augment(img, **runner.augment_kwargs)
        with ddp_sync(D, sync=sync):
            scores = D(img, label, patch_params, update_emas=update_emas)
        return scores

    @staticmethod
    def compute_grad_penalty(images, scores):
        """Computes the per-sample squared norm of d(sum(scores))/d(images).

        The squared gradient is summed over dims 1-3, so `images` is expected
        to be 4-D; the result has one entry per sample. The graph is kept
        (`create_graph=True`) so the penalty itself can be backpropagated.
        """
        with conv2d_gradfix.no_weight_gradients():
            image_grad = torch.autograd.grad(outputs=[scores.sum()],
                                             inputs=[images],
                                             create_graph=True,
                                             only_inputs=True)[0]
        grad_penalty = image_grad.square().sum([1, 2, 3])
        return grad_penalty

    def compute_pl_penalty(self, images, latents):
        """Computes perceptual path length penalty.

        Projects `images` onto random noise scaled by 1/sqrt(H * W), takes the
        gradient with respect to `latents`, and measures per-sample length as
        sqrt(mean over dim 1 of sum over dim 2 of grad**2). `latents` is
        therefore expected to be 3-D (presumably [N, num_ws, w_dim]; TODO
        confirm). `self.pl_mean` is updated in place as a running average
        (rate `self.pl_decay`).

        Returns:
            Per-sample squared deviation of the path length from the running
            mean.
        """
        res_h, res_w = images.shape[2:4]
        pl_noise = torch.randn_like(images) / np.sqrt(res_h * res_w)
        with conv2d_gradfix.no_weight_gradients():
            code_grad = torch.autograd.grad(
                outputs=[(images * pl_noise).sum()],
                inputs=[latents],
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        pl_length = code_grad.square().sum(2).mean(1).sqrt()
        pl_mean = self.pl_mean.lerp(pl_length.mean(), self.pl_decay)
        self.pl_mean.copy_(pl_mean.detach())
        pl_penalty = (pl_length - pl_mean).square()
        return pl_penalty

    def d_fake_loss(self,
                    runner,
                    fake_labels,
                    patch_params,
                    blur_sigma,
                    sync=True):
        """Computes discriminator loss on fake/generated images.

        Logs the per-sample `softplus(D(G(z)))` under 'Loss/D Fake' and
        returns its mean. The generator runs without DDP sync; the
        discriminator uses `sync`.
        """
        # Train with fake/generated samples.
        fake_imgs = self.run_G(runner,
                               fake_labels,
                               patch_params,
                               update_emas=True,
                               sync=False)['image']
        fake_scores = self.run_D(runner,
                                 fake_imgs,
                                 fake_labels,
                                 patch_params,
                                 blur_sigma,
                                 update_emas=True,
                                 sync=sync)
        d_fake_loss = F.softplus(fake_scores)
        runner.running_stats.update({'Loss/D Fake': d_fake_loss})
        return d_fake_loss.mean()

    def d_real_loss(self,
                    runner,
                    real_img,
                    real_labels,
                    patch_real,
                    blur_sigma,
                    sync=True):
        """Computes discriminator loss on real images.

        `real_img` is indexed with `'image'`. Logs the per-sample
        `softplus(-D(x))` under 'Loss/D Real', feeds the sign of the real
        scores to the augmentation probability tracker when one exists, and
        returns the mean loss.
        """
        # Train with real samples.
        real_img_tmp_image = real_img['image'].detach().requires_grad_(False)
        real_scores = self.run_D(runner,
                                 real_img_tmp_image,
                                 real_labels,
                                 patch_real,
                                 blur_sigma,
                                 sync=sync)
        d_real_loss = F.softplus(-real_scores)
        runner.running_stats.update({'Loss/D Real': d_real_loss})

        # Adjust the augmentation strength if needed.
        if hasattr(runner.augment, 'prob_tracker'):
            runner.augment.prob_tracker.update(real_scores.sign())

        return d_real_loss.mean()

    def d_reg(self,
              runner,
              real_img,
              real_labels,
              patch_real,
              blur_sigma,
              sync=True):
        """Compute the regularization loss for discriminator.

        Returns None unless R1 is enabled and this is a regularization
        iteration (every `r1_interval` iterations, starting at iteration 1;
        every iteration when `r1_interval == 1`). Otherwise returns the mean
        R1 penalty scaled by `r1_gamma / 2 * r1_interval` (lazy
        regularization).
        """
        if (not self._is_reg_iter(runner.iter, self.r1_interval)
                or self.r1_gamma == 0.0):
            return None

        real_img_tmp_image = real_img['image'].detach().requires_grad_(True)
        real_scores = self.run_D(runner,
                                 real_img_tmp_image,
                                 real_labels,
                                 patch_real,
                                 blur_sigma,
                                 sync=sync)
        r1_penalty = self.compute_grad_penalty(real_img_tmp_image, real_scores)
        runner.running_stats.update({'Loss/Real Gradient Penalty': r1_penalty})
        r1_penalty = r1_penalty * (self.r1_gamma * 0.5) * self.r1_interval

        # Adjust the augmentation strength if needed.
        if hasattr(runner.augment, 'prob_tracker'):
            runner.augment.prob_tracker.update(real_scores.sign())

        # `real_scores * 0` adds nothing to the value; presumably it keeps the
        # discriminator output in the graph (e.g. for DDP) -- TODO confirm.
        # If `real_scores` is [N, 1] the sum broadcasts to [N, N], but its
        # mean still equals `r1_penalty.mean()`.
        return (real_scores * 0 + r1_penalty).mean()

    def g_loss(self,
               runner,
               fake_labels,
               patch_params,
               blur_sigma,
               sync=True):
        """Computes loss for generator.

        Logs the per-sample non-saturating loss `softplus(-D(G(z)))` under
        'Loss/G' and returns its mean. The generator uses `sync`; the
        discriminator runs without DDP sync.
        """
        fake_imgs = self.run_G(runner,
                               fake_labels,
                               patch_params,
                               sync=sync)['image']
        fake_scores = self.run_D(runner,
                                 fake_imgs,
                                 fake_labels,
                                 patch_params,
                                 blur_sigma,
                                 sync=False)
        g_loss = F.softplus(-fake_scores)
        runner.running_stats.update({'Loss/G': g_loss})

        return g_loss.mean()

    def g_reg(self,
              runner,
              _data,
              sync=True,
              fake_labels=None,
              patch_params=None):
        """Computes the regularization loss for generator.

        Returns None unless the path length penalty is enabled and this is a
        regularization iteration (every `pl_interval` iterations, starting at
        iteration 1; every iteration when `pl_interval == 1`).

        Args:
            runner: Training runner.
            _data: Unused; kept for interface compatibility.
            sync: Whether DDP gradient synchronization is enabled for G.
            fake_labels: Conditioning labels for the generator. Required
                whenever the penalty is computed, because `run_G` needs them.
            patch_params: Patch parameters for the generator. Required
                whenever the penalty is computed.

        Returns:
            None, or the mean path length penalty scaled by
            `pl_weight * pl_interval` (lazy regularization).

        Raises:
            ValueError: If the penalty is due but `fake_labels` or
                `patch_params` was not provided.
        """
        if (not self._is_reg_iter(runner.iter, self.pl_interval)
                or self.pl_weight == 0.0):
            return None
        if fake_labels is None or patch_params is None:
            raise ValueError('`g_reg()` requires `fake_labels` and '
                             '`patch_params` to run the generator for the '
                             'path length penalty.')

        # Use a smaller batch for the penalty, never more samples than the
        # provided labels.
        batch_size = max(runner.batch_size // self.pl_batch_shrink, 1)
        batch_size = min(batch_size, fake_labels.shape[0])
        fake_results = self.run_G(runner,
                                  self._slice_batch(fake_labels, batch_size),
                                  self._slice_batch(patch_params, batch_size),
                                  batch_size=batch_size,
                                  sync=sync,
                                  requires_grad=True)
        pl_penalty = self.compute_pl_penalty(images=fake_results['image'],
                                             latents=fake_results['wp'])
        runner.running_stats.update({'Loss/Path Length Penalty': pl_penalty})
        pl_penalty = pl_penalty * self.pl_weight * self.pl_interval

        # The zero-weighted image term adds nothing to the value but ties the
        # generator output into the returned loss.
        return (fake_results['image'][:, 0, 0, 0] * 0 + pl_penalty).mean()