# Variant that learns from the mask discriminator only, e.g. for target domains such as SUVs, trucks, and racing cars.

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import numpy as np
import random
import torch
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix


# ----------------------------------------------------------------------------
class Loss:
    """Abstract interface for training losses.

    Concrete losses implement ``accumulate_gradients`` to run the forward and
    backward pass for one training phase.
    """

    def accumulate_gradients(
            self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        """Accumulate gradients for ``phase``; every subclass must override this."""
        raise NotImplementedError


# ----------------------------------------------------------------------------
# Regularization loss for DMTet
def sdf_reg_loss_batch(sdf, all_edges):
    """Cross-entropy regularizer on SDF values across sign-changing edges.

    ``all_edges`` is flattened and regrouped into vertex-index pairs. For every
    pair whose endpoint SDF values differ in ``torch.sign``, each endpoint value
    is treated as a logit. It is pushed towards the occupancy (``> 0``) of the
    other endpoint with binary cross-entropy.

    Args:
        sdf: Tensor indexed as ``sdf[:, vertex_ids]``; presumably
            ``[batch, num_vertices]`` -- TODO confirm against the caller.
        all_edges: Integer tensor of vertex indices. It is flattened and
            regrouped into pairs, so it must contain an even number of entries.

    Returns:
        Scalar tensor holding the sum of the two mean BCE terms over all
        sign-changing edges in the batch. When no edge changes sign it returns
        a zero that stays connected to the autograd graph of ``sdf``.
    """
    edge_sdf = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
    crosses = torch.sign(edge_sdf[..., 0]) != torch.sign(edge_sdf[..., 1])
    edge_sdf = edge_sdf[crosses]
    if edge_sdf.shape[0] == 0:
        # BCE with 'mean' reduction over zero elements is NaN, which would
        # poison the whole generator loss; contribute nothing instead.
        return edge_sdf.sum() * 0.0
    bce = torch.nn.functional.binary_cross_entropy_with_logits
    return (bce(edge_sdf[..., 0], (edge_sdf[..., 1] > 0).float())
            + bce(edge_sdf[..., 1], (edge_sdf[..., 0] > 0).float()))


class StyleGAN2Loss(Loss):
    """Losses for adapting a 3D generator ``G`` using a source model (``G_source``, ``D_source``).

    Generator phase (``Gmain``/``Gboth``):
      * ``G_source`` samples the latents ``ws``/``ws_geo`` and the camera. These
        are fed unchanged to ``G.synthesis``, so both generators render the same
        inputs.
      * A non-saturating adversarial loss is applied to the mask logits of ``D``.
      * SDF regularizers are added: ``sdf_reg_loss_batch`` and the
        ``sdf_reg_loss`` returned by ``G.synthesis``.
      * Relational-consistency terms are added. For each sample, the
        softmax-normalised cosine similarities to every other sample in the
        batch are computed for the outputs of ``G_source`` and of ``G``, and the
        two distributions are matched with a KL divergence. This is done
        separately for geometry features, texture features, RGB, and alpha.

    Discriminator phases (``Dmain``/``Dreg``/``Dboth``) use the mask logits of
    ``D`` for the real/fake loss. They add R1 penalties on ``D_source``'s first
    logit output (``learn_rgb``) and on ``D``'s mask logits (``learn_mask``).
    """

    # Values of ``G.synthesis.data_camera_mode`` for which the discriminator is
    # conditioned on the last two entries of the generated camera.
    _CONDITIONED_CAMERA_MODES = (
        'shapenet_car', 'shapenet_chair', 'shapenet_motorbike', 'renderpeople',
        'shapenet_plant', 'shapenet_vase', 'ts_house', 'ts_animal', 'all_shapenet',
    )

    def __init__(
            self, device, G, D, G_source, D_source, r1_gamma=10, style_mixing_prob=0, pl_weight=0,
            gamma_mask=10, ):
        super().__init__()
        self.device = device
        self.G = G
        self.D = D
        self.G_source = G_source
        self.D_source = D_source
        self.r1_gamma = r1_gamma
        self.style_mixing_prob = style_mixing_prob
        self.pl_weight = pl_weight
        self.gamma_mask = gamma_mask
        self.learn_mask = True
        self.learn_rgb = True

    def run_G_source(
            self, z, c, update_emas=False, return_shape=False, return_feats=False,
    ):
        """Sample latents with ``G_source`` and render them with ``G_source.synthesis``.

        ``z`` drives ``G_source.mapping``. A fresh ``torch.randn_like(z)`` drives
        ``G_source.mapping_geo``. ``return_feats`` is only honoured together
        with ``return_shape``.

        Returns a tuple whose layout depends on the flags. It is kept as-is for
        callers; note that ``ws``/``ws_geo`` appear twice:
          * shape + feats (19): ``img, sdf, ws, syn_camera, deformation,
            v_deformed, mesh_v, mesh_f, mask_pyramid, ws_geo, sdf_reg_loss,
            render_return_value, tex_feature, tex_feat_feat, ws, ws_geo,
            sdf_feat, deformation_feat, tex_hard_mask``
          * shape only (14): the first 12 of the above, then ``ws, ws_geo``
          * otherwise (7): ``img, ws, syn_camera, mask_pyramid,
            render_return_value, ws, ws_geo``
        """
        # Step 1: Map the sampled z code to w-space
        ws = self.G_source.mapping(z, c, update_emas=update_emas)
        geo_z = torch.randn_like(z)
        ws_geo = self.G_source.mapping_geo(
            geo_z, c,
            update_emas=update_emas)

        # Step 2: Apply style mixing to the latent code
        if self.style_mixing_prob > 0:
            with torch.autograd.profiler.record_function('style_mixing'):
                cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                cutoff = torch.where(
                    torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff,
                    torch.full_like(cutoff, ws.shape[1]))
                ws[:, cutoff:] = self.G_source.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]

                cutoff = torch.empty([], dtype=torch.int64, device=ws_geo.device).random_(1, ws_geo.shape[1])
                cutoff = torch.where(
                    torch.rand([], device=ws_geo.device) < self.style_mixing_prob, cutoff,
                    torch.full_like(cutoff, ws_geo.shape[1]))
                ws_geo[:, cutoff:] = self.G_source.mapping_geo(torch.randn_like(z), c, update_emas=False)[:, cutoff:]

        # Step 3: Generate rendered image of 3D generated shapes.
        if return_shape and return_feats:
            img, sdf, syn_camera, deformation, v_deformed, mesh_v, mesh_f, mask_pyramid, sdf_reg_loss, render_return_value, tex_feature, tex_feat_feat, sdf_feat, deformation_feat, tex_hard_mask = self.G_source.synthesis(
                ws,
                return_shape=return_shape,
                ws_geo=ws_geo,
                return_feats=return_feats,
            )
            return img, sdf, ws, syn_camera, deformation, v_deformed, mesh_v, mesh_f, mask_pyramid, ws_geo, sdf_reg_loss, render_return_value, tex_feature, tex_feat_feat, ws, ws_geo, sdf_feat, deformation_feat, tex_hard_mask
        elif return_shape:
            img, sdf, syn_camera, deformation, v_deformed, mesh_v, mesh_f, mask_pyramid, sdf_reg_loss, render_return_value = self.G_source.synthesis(
                ws,
                return_shape=return_shape,
                ws_geo=ws_geo,
            )
            return img, sdf, ws, syn_camera, deformation, v_deformed, mesh_v, mesh_f, mask_pyramid, ws_geo, sdf_reg_loss, render_return_value, ws, ws_geo
        else:
            img, syn_camera, mask_pyramid, sdf_reg_loss, render_return_value = self.G_source.synthesis(
                ws, return_shape=return_shape,
                ws_geo=ws_geo)
        return img, ws, syn_camera, mask_pyramid, render_return_value, ws, ws_geo

    def run_G(
            self, ws, ws_geo, c, camera, update_emas=False, return_shape=False, return_feats=False,
    ):
        """Render the given latents with ``G.synthesis`` from the given ``camera``.

        ``c`` and ``update_emas`` are accepted for interface symmetry but unused.
        ``return_feats`` is only honoured together with ``return_shape``.

        Returns:
          * shape + feats (17): ``img, sdf, ws, syn_camera, deformation,
            v_deformed, mesh_v, mesh_f, mask_pyramid, ws_geo, sdf_reg_loss,
            render_return_value, tex_feature, tex_feat_feat, sdf_feat,
            deformation_feat, tex_hard_mask``
          * shape only (12): the first 12 of the above
          * otherwise (5): ``img, ws, syn_camera, mask_pyramid, render_return_value``
        """
        if return_shape and return_feats:
            img, sdf, syn_camera, deformation, v_deformed, mesh_v, mesh_f, mask_pyramid, sdf_reg_loss, render_return_value, tex_feature, tex_feat_feat, sdf_feat, deformation_feat, tex_hard_mask = self.G.synthesis(
                ws,
                return_shape=return_shape,
                ws_geo=ws_geo,
                return_feats=return_feats,
                camera=camera,
            )
            return img, sdf, ws, syn_camera, deformation, v_deformed, mesh_v, mesh_f, mask_pyramid, ws_geo, sdf_reg_loss, render_return_value, tex_feature, tex_feat_feat, sdf_feat, deformation_feat, tex_hard_mask
        elif return_shape:
            img, sdf, syn_camera, deformation, v_deformed, mesh_v, mesh_f, mask_pyramid, sdf_reg_loss, render_return_value = self.G.synthesis(
                ws,
                return_shape=return_shape,
                ws_geo=ws_geo,
                camera=camera
            )
            return img, sdf, ws, syn_camera, deformation, v_deformed, mesh_v, mesh_f, mask_pyramid, ws_geo, sdf_reg_loss, render_return_value
        else:
            img, syn_camera, mask_pyramid, sdf_reg_loss, render_return_value = self.G.synthesis(
                ws, return_shape=return_shape, camera=camera,
                ws_geo=ws_geo)
        return img, ws, syn_camera, mask_pyramid, render_return_value

    def run_D(self, img, c, update_emas=False, mask_pyramid=None):
        """Return the logits of the adapted discriminator ``D``."""
        logits = self.D(img, c, update_emas=update_emas, mask_pyramid=mask_pyramid)
        return logits

    def run_D_source(self, img, c, update_emas=False, mask_pyramid=None):
        """Return the logits of the source discriminator ``D_source``."""
        logits = self.D_source(img, c, update_emas=update_emas, mask_pyramid=mask_pyramid)
        return logits

    def _camera_condition(self, gen_camera):
        """Return the discriminator's camera conditioning, or ``None``.

        For the camera modes in ``_CONDITIONED_CAMERA_MODES`` this is the
        concatenation of the last two camera entries along the last dim. Any
        other mode yields ``None``. That was already the effective behavior of
        both branches: the generator branch previously used a no-op
        ``assert NotImplementedError``, and the discriminator branch explicitly
        used ``None``.
        """
        if self.G.synthesis.data_camera_mode in self._CONDITIONED_CAMERA_MODES:
            return torch.cat((gen_camera[-2], gen_camera[-1]), dim=-1)
        return None

    @staticmethod
    def _flat_row(t):
        """Flatten ``t`` into a single ``[1, numel]`` row for ``CosineSimilarity``."""
        return torch.unsqueeze(t.reshape(-1), 0)

    @classmethod
    def _geometry_pairs(cls, feat, feat_ind):
        """Pair builder comparing ``feat[i, :, feat_ind]`` with ``feat[j, :, feat_ind]``."""
        return lambda i, j: (cls._flat_row(feat[i, :, feat_ind]), cls._flat_row(feat[j, :, feat_ind]))

    @classmethod
    def _masked_texture_pairs(cls, feat, hard_mask, feat_ind):
        """Pair builder over texture features restricted to the two samples' shared mask."""
        def pair(i, j):
            # clamp(a + b, 1, 2) - 1 == clamp(a + b - 1, 0, 1): for {0, 1} masks
            # this is 1 only where both samples are covered.
            shared_mask = torch.clamp(hard_mask[i] + hard_mask[j], 1, 2) - 1
            return (cls._flat_row(feat[i, :, :, feat_ind] * shared_mask),
                    cls._flat_row(feat[j, :, :, feat_ind] * shared_mask))
        return pair

    @classmethod
    def _masked_rgb_pairs(cls, img):
        """Pair builder over RGB (channels ``:3``) quantised to [0, 255], masked by channel 3."""
        def to_pixel_range(x):
            # NOTE(review): torch.round has zero gradient, so this term only
            # back-propagates through the alpha-based mask, not through RGB.
            return torch.clamp(torch.round((x + 1) * (255 / 2)), 0, 255)

        def pair(i, j):
            share_mask = torch.clamp(img[i, 3] + img[j, 3], 1, 2) - 1
            return (cls._flat_row(to_pixel_range(img[i, :3]) * share_mask),
                    cls._flat_row(to_pixel_range(img[j, :3]) * share_mask))
        return pair

    @classmethod
    def _alpha_pairs(cls, img):
        """Pair builder over channel 3 of ``img`` mapped by ``2 * x - 1``."""
        return lambda i, j: (cls._flat_row(2 * img[i, 3] - 1), cls._flat_row(2 * img[j, 3] - 1))

    @staticmethod
    def _relational_distribution(batch_size, pair_feats, device):
        """Row-softmax of pairwise cosine similarities, shape ``[batch_size, batch_size - 1]``.

        Row ``i`` holds, for every ``j != i`` in increasing order, the cosine
        similarity of the two rows returned by ``pair_feats(i, j)``.
        """
        sim = torch.nn.CosineSimilarity()
        dist = torch.zeros([batch_size, batch_size - 1], device=device)
        for anchor_idx in range(batch_size):
            col = 0
            for other_idx in range(batch_size):
                if other_idx == anchor_idx:
                    continue
                anchor_feat, target_feat = pair_feats(anchor_idx, other_idx)
                dist[anchor_idx, col] = sim(anchor_feat, target_feat)
                col += 1
        return torch.nn.Softmax(dim=1)(dist)

    @classmethod
    def _relational_kl(cls, batch_size, source_pair_feats, target_pair_feats, device):
        """KL(source || target) between relational distributions (``KLDivLoss``, 'mean').

        Returns a zero scalar when ``batch_size < 2``. No pairs exist in that
        case, and the KL over an empty matrix would be NaN.
        """
        if batch_size < 2:
            return torch.zeros([], device=device)
        dist_source = cls._relational_distribution(batch_size, source_pair_feats, device)
        dist_target = cls._relational_distribution(batch_size, target_pair_feats, device)
        return torch.nn.KLDivLoss()(torch.log(dist_target), dist_source)

    def accumulate_gradients(
            self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        """Run forward + backward for one training ``phase`` and accumulate gradients.

        ``phase`` must be one of ``Gmain``, ``Greg``, ``Gboth``, ``Dmain``,
        ``Dreg``, ``Dboth``. It is remapped to skip disabled regularizers:
        ``pl_weight == 0`` or ``r1_gamma == 0``. ``cur_nimg`` is unused here.
        """
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        if self.pl_weight == 0:
            phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
        if self.r1_gamma == 0:
            phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)

        # Gmain: Maximize logits for generated images.
        if phase in ['Gmain', 'Gboth']:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                # The source generator supplies the latents and the camera, so
                # G renders exactly the same inputs.
                (gen_img_source, _, _, gen_camera, _, _, _, _, _, _, _, _,
                 source_tex_feature, _, ws, ws_geo, source_sdf_feat,
                 source_deformation_feat, source_tex_hard_mask) = self.run_G_source(
                    gen_z, gen_c, return_shape=True, return_feats=True
                )

                (gen_img, gen_sdf, _, _, _, _, _, _, mask_pyramid, _, sdf_reg_loss, _,
                 tex_feature, _, sdf_feat, deformation_feat, tex_hard_mask) = self.run_G(
                    ws, ws_geo, gen_c, camera=gen_camera, return_shape=True, return_feats=True
                )

                batch_size = tex_feature.shape[0]
                device = tex_feature.device

                # Geometry: concatenated SDF / deformation features. ``range``
                # indices trigger advanced indexing (not slicing). They are kept
                # so the element layout -- and broadcasting against the masks
                # below -- is unchanged.
                source_sdf_deformation_feat = torch.cat((source_sdf_feat, source_deformation_feat), dim=2)
                sdf_deformation_feat = torch.cat((sdf_feat, deformation_feat), dim=2)
                geo_ind = range(64)
                rel_loss_sdf_deformation_patch = 2e+4 * self._relational_kl(
                    batch_size,
                    self._geometry_pairs(source_sdf_deformation_feat, geo_ind),
                    self._geometry_pairs(sdf_deformation_feat, geo_ind),
                    device)

                # Texture features inside the two samples' shared hard mask.
                tex_ind = range(16)
                rel_loss_tex_patch = 5e+3 * self._relational_kl(
                    batch_size,
                    self._masked_texture_pairs(source_tex_feature, source_tex_hard_mask, tex_ind),
                    self._masked_texture_pairs(tex_feature, tex_hard_mask, tex_ind),
                    device)

                # Rendered RGB inside the two samples' shared alpha.
                rel_loss_rgb_patch = 1e+4 * self._relational_kl(
                    batch_size,
                    self._masked_rgb_pairs(gen_img_source),
                    self._masked_rgb_pairs(gen_img),
                    device)

                # Rendered alpha channel.
                rel_loss_mask_patch = 5e+3 * self._relational_kl(
                    batch_size,
                    self._alpha_pairs(gen_img_source),
                    self._alpha_pairs(gen_img),
                    device)

                camera_condition = self._camera_condition(gen_camera)

                gen_logits = self.run_D(gen_img, camera_condition, mask_pyramid=mask_pyramid)
                _, gen_logits_mask = gen_logits

                loss_Gmain = 0

                if self.learn_mask:
                    training_stats.report('Loss/scores/fake_mask', gen_logits_mask)
                    training_stats.report('Loss/signs/fake_mask', gen_logits_mask.sign())
                    loss_Gmask = torch.nn.functional.softplus(-gen_logits_mask).mean()
                    training_stats.report('Loss/G/loss_mask', loss_Gmask)
                    loss_Gmain += loss_Gmask
                    training_stats.report('Loss/G/loss', loss_Gmain)

                # Regularization loss for sdf predictions
                sdf_reg_loss_entropy = sdf_reg_loss_batch(gen_sdf, self.G.synthesis.dmtet_geometry.all_edges).mean() * 0.01
                training_stats.report('Loss/G/sdf_reg', sdf_reg_loss_entropy)
                loss_Gmain += sdf_reg_loss_entropy
                training_stats.report('Loss/G/sdf_reg_abs', sdf_reg_loss)
                loss_Gmain += sdf_reg_loss.mean()

            # G.mapping / G.mapping_geo are not used above (latents come from
            # G_source); their parameters are frozen before the backward pass.
            for parameter in self.G.mapping.parameters():
                parameter.requires_grad = False

            for parameter in self.G.mapping_geo.parameters():
                parameter.requires_grad = False

            with torch.autograd.profiler.record_function('Gmain_backward'):
                # NOTE(review): only the adversarial/SDF terms are scaled by
                # ``gain``; the relational terms are added unscaled.
                rel_loss = (rel_loss_tex_patch.mean() + rel_loss_sdf_deformation_patch.mean()
                            + rel_loss_rgb_patch.mean() + rel_loss_mask_patch.mean())
                (loss_Gmain.mean().mul(gain) + rel_loss).backward()

        # We didn't have Gpl regularization
        #######################################################
        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if phase in ['Dmain', 'Dboth']:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                # First generate the rendered image of generated 3D shapes
                (_, _, _, gen_camera, _, _, _, _, _, _, _, _, _, _,
                 ws, ws_geo, _, _, _) = self.run_G_source(gen_z, gen_c, return_shape=True, return_feats=True)

                (gen_img, _, _, _, _, _, _, _, mask_pyramid,
                 _, _, _, _, _, _, _, _) = self.run_G(ws, ws_geo, gen_c, camera=gen_camera, return_shape=True, return_feats=True)

                camera_condition = self._camera_condition(gen_camera)

                gen_logits = self.run_D(
                    gen_img, camera_condition, update_emas=True, mask_pyramid=mask_pyramid)

                _, gen_logits_mask = gen_logits

                if self.learn_mask:
                    training_stats.report('Loss/scores/fake_mask', gen_logits_mask)
                    training_stats.report('Loss/signs/fake_mask', gen_logits_mask.sign())
                    loss_Dgen_mask = torch.nn.functional.softplus(gen_logits_mask).mean()  # -log(1 - sigmoid(gen_logits))
                    training_stats.report('Loss/D/loss_gen_mask', loss_Dgen_mask)
                    loss_Dgen += loss_Dgen_mask

            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if phase in ['Dmain', 'Dreg', 'Dboth']:
            name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                # Optimize for the real image
                real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth'])

                real_logits = self.run_D(real_img_tmp, real_c, )
                _, real_logits_mask = real_logits

                # Computed in every D phase but only consumed by the RGB R1 term.
                real_logits_source = self.run_D_source(real_img_tmp, real_c, )
                real_logits_source, _ = real_logits_source

                training_stats.report('Loss/scores/real_mask', real_logits_mask)
                training_stats.report('Loss/signs/real_mask', real_logits_mask.sign())

                loss_Dreal = 0
                if phase in ['Dmain', 'Dboth']:
                    if self.learn_mask:
                        loss_Dreal_mask = torch.nn.functional.softplus(-real_logits_mask).mean()  # -log(sigmoid(real_logits))
                        training_stats.report('Loss/D/loss_real_mask', loss_Dreal_mask)
                        loss_Dreal += loss_Dreal_mask
                        training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                # Compute R1 regularization for discriminator
                if phase in ['Dreg', 'Dboth']:
                    # Compute R1 regularization for discriminator of RGB image.
                    # NOTE(review): these logits come from D_source, so this term
                    # produces gradients only for D_source, not for D.
                    if self.learn_rgb:
                        with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                            r1_grads = torch.autograd.grad(
                                outputs=[real_logits_source.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]

                        r1_penalty = r1_grads.square().sum([1, 2, 3])
                        loss_Dr1 = r1_penalty.mean() * (self.r1_gamma / 2)
                        training_stats.report('Loss/r1_penalty', r1_penalty)
                        training_stats.report('Loss/D/reg', loss_Dr1)
                    # Compute R1 regularization for discriminator of Mask image
                    if self.learn_mask:
                        with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                            r1_grads_mask = \
                                torch.autograd.grad(
                                    outputs=[real_logits_mask.sum()], inputs=[real_img_tmp], create_graph=True,
                                    only_inputs=True)[0]

                        r1_penalty_mask = r1_grads_mask.square().sum([1, 2, 3])
                        loss_Dr1_mask = r1_penalty_mask.mean() * (self.gamma_mask / 2)
                        training_stats.report('Loss/r1_penalty_mask', r1_penalty_mask)
                        training_stats.report('Loss/D/reg_mask', loss_Dr1_mask)
                        loss_Dr1 += loss_Dr1_mask

            with torch.autograd.profiler.record_function(name + '_backward'):
                (loss_Dreal + loss_Dr1).mean().mul(gain).backward()