# Modified from:
#   taming-transformers:  https://github.com/CompVis/taming-transformers
#   muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelling.lpips import LPIPS, LPIPSTimm
from modelling.discriminators import PatchGANDiscriminator, StyleGANDiscriminator, PatchGANMaskBitDiscriminator, DinoDiscriminator
from utils.diff_aug import DiffAugment

import torch.distributed as tdist
from einops import rearrange

def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss.

    Penalises real logits below +1 and fake logits above -1, averaging each
    penalty over its tensor and then averaging the two terms.

    Returns:
        Scalar tensor ``0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake)))``.
    """
    real_term = F.relu(1. - logits_real).mean()
    fake_term = F.relu(1. + logits_fake).mean()
    return 0.5 * (real_term + fake_term)


def vanilla_d_loss(logits_real, logits_fake):
    """Standard (log-loss) GAN discriminator loss written with softplus.

    ``softplus(-x)`` is ``-log(sigmoid(x))`` and ``softplus(x)`` is
    ``-log(1 - sigmoid(x))``, so this is binary cross-entropy with label 1 for
    real logits and label 0 for fake logits, the two halves averaged.
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)


def non_saturating_d_loss(logits_real, logits_fake):
    """Non-saturating (binary cross-entropy) discriminator loss.

    Real logits are pushed towards label 1 and fake logits towards label 0;
    the two mean BCE terms are averaged.

    Note: ``F.binary_cross_entropy_with_logits`` takes ``(input, target)`` --
    the logits must be the first argument and the constant labels the second.
    """
    loss_real = F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real), reduction='mean')
    loss_fake = F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake), reduction='mean')
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def hinge_gen_loss(logit_fake):
    """Hinge-style generator loss: the negated mean discriminator logit on fakes."""
    mean_logit = logit_fake.mean()
    return mean_logit.neg()


def non_saturating_gen_loss(logit_fake):
    """Non-saturating generator loss: mean BCE of fake logits against label 1.

    Equivalent to ``mean(softplus(-logit_fake))``. The logits are the first
    (``input``) argument of ``binary_cross_entropy_with_logits`` and the
    all-ones labels the second (``target``).
    """
    return F.binary_cross_entropy_with_logits(logit_fake, torch.ones_like(logit_fake), reduction='mean')


def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return ``value`` before step ``threshold`` and ``weight`` from then on.

    Used to switch a loss term on only once training reaches ``threshold``.
    """
    return value if global_step < threshold else weight


class LeCAM_EMA(object):
    """Exponential moving averages of the mean real and fake discriminator logits.

    The two running scalars (``logits_real_ema`` / ``logits_fake_ema``) are read
    by ``lecam_reg`` as anchors for the LeCAM regulariser.
    """

    def __init__(self, init=0., decay=0.999):
        self.decay = decay
        # Both averages start from the same initial value.
        self.logits_real_ema = self.logits_fake_ema = init

    @staticmethod
    def _blend(previous, logits, decay):
        # previous * decay + batch_mean * (1 - decay); .item() keeps the EMA a Python float.
        return previous * decay + torch.mean(logits).item() * (1 - decay)

    def update(self, logits_real, logits_fake):
        """Fold the batch means of ``logits_real`` and ``logits_fake`` into the EMAs."""
        self.logits_real_ema = self._blend(self.logits_real_ema, logits_real, self.decay)
        self.logits_fake_ema = self._blend(self.logits_fake_ema, logits_fake, self.decay)


def lecam_reg(real_pred, fake_pred, lecam_ema):
    """LeCAM regulariser for the discriminator.

    Squared-hinge penalty on real predictions above the running fake-logit
    average plus fake predictions below the running real-logit average, using
    the EMAs stored on ``lecam_ema``.
    """
    real_penalty = F.relu(real_pred - lecam_ema.logits_fake_ema).pow(2).mean()
    fake_penalty = F.relu(lecam_ema.logits_real_ema - fake_pred).pow(2).mean()
    return real_penalty + fake_penalty


class VQLoss(nn.Module):
    """Training objective for a VQ autoencoder trained against a discriminator.

    ``forward`` returns a dict of loss terms:

    * ``optimizer_idx == 0`` (generator / autoencoder step): weighted
      reconstruction, perceptual, time-shift reconstruction, adversarial and
      per-frame std terms, plus any caller-supplied ``extra_loss`` entries.
    * ``optimizer_idx == 1`` (discriminator step): the configured
      discriminator loss, optionally with the LeCAM regulariser and a
      consistency term between weakly and strongly augmented logits.

    Inputs are images ``(b, c, h, w)`` or videos ``(b, c, f, h, w)``; for
    videos the frames are flattened into the batch before the
    discriminator, and the perceptual loss is evaluated per frame.
    """

    def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256,
                disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight = False,
                gen_adv_loss='hinge',
                reconstruction_loss='l2', reconstruction_weight=1.0, codebook_weight=1.0, t_shift_rec_weight=1.0, reconstruction_std_weight=0.0, perceptual_std_weight=0.0,
                perceptual_loss='vgg', perceptual_weight=1.0, perceptual_model='vgg', perceptual_intermediate_loss=False, perceptural_logit_loss=False, perceptual_resize=False, perceptual_dino_variants='depth12_no_train', perceptual_warmup=None,
                lecam_loss_weight=None,
                disc_cr_loss_weight=0.0,
                use_diff_aug=False,
    ):
        """Build the discriminator, perceptual network and loss configuration.

        Raises:
            AssertionError / ValueError: for an unknown ``disc_type``,
                ``disc_loss``, ``gen_adv_loss``, ``perceptual_loss`` or
                ``reconstruction_loss``.
        """
        super().__init__()
        # discriminator network
        assert disc_type in ["patchgan", "stylegan", "maskbit", "dino"]
        assert disc_loss in ["hinge", "vanilla", "non-saturating"]
        if disc_type == "patchgan":
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == "stylegan":
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        elif disc_type == "maskbit":
            self.discriminator = PatchGANMaskBitDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == "dino":
            self.discriminator = DinoDiscriminator()
        else:
            raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")

        # discriminator loss
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "non-saturating":
            self.disc_loss = non_saturating_d_loss
        else:
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight

        # generator adversarial loss
        assert gen_adv_loss in ["hinge", "non-saturating"]
        if gen_adv_loss == "hinge":
            self.gen_adv_loss = hinge_gen_loss
        elif gen_adv_loss == "non-saturating":
            self.gen_adv_loss = non_saturating_gen_loss
        else:
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")

        # t_shift loss
        self.t_shift_rec_weight = t_shift_rec_weight

        # perceptual loss
        if perceptual_loss == "vgg":
            self.perceptual_loss = LPIPS().eval()
        elif perceptual_loss == "timm":
            self.perceptual_loss = LPIPSTimm(perceptual_model, perceptual_intermediate_loss, perceptural_logit_loss, perceptual_resize, eval=True, dino_variants=perceptual_dino_variants).eval()
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(f"Unknown perceptual loss '{perceptual_loss}'.")
        self.perceptual_weight = perceptual_weight
        # Number of steps over which the perceptual weight ramps linearly; None/0 disables the ramp.
        self.perceptual_warmup = perceptual_warmup

        # reconstruction loss
        self.reconstruction_loss = reconstruction_loss
        if reconstruction_loss == "l1":
            self.rec_loss = F.l1_loss
        elif reconstruction_loss == "l2":
            self.rec_loss = F.mse_loss
        else:
            raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
        self.rec_weight = reconstruction_weight

        # weights of the across-frame std penalties (video inputs only)
        self.rec_std_weight = reconstruction_std_weight
        self.p_std_weight = perceptual_std_weight

        # codebook loss (stored; not used by forward())
        self.codebook_weight = codebook_weight

        self.lecam_loss_weight = lecam_loss_weight
        if self.lecam_loss_weight is not None:
            self.lecam_ema = LeCAM_EMA()

        # differentiable augmentation and discriminator consistency regularisation
        self.use_diff_aug = use_diff_aug
        self.disc_cr_loss_weight = disc_cr_loss_weight

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Ratio of gradient norms used to balance the adversarial term.

        Returns ``||d nll / d last_layer|| / (||d g / d last_layer|| + 1e-4)``,
        clamped to ``[0, 1e4]`` and detached. ``retain_graph=True`` keeps the
        graph alive for the later backward pass over the full loss.
        """
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        return torch.clamp(d_weight, 0.0, 1e4).detach()

    def t_shift_rec_loss(self, reconstructions, shift_reconstructions, t_shift):
        """Reconstruction loss between a sample and its time-shifted reconstruction.

        For each sample ``x`` / ``y`` with shift ``t``, ``x[:, t:]`` is compared
        with ``y[:, :-t]`` (the full tensors when ``t == 0``) using
        ``self.rec_loss``; the per-sample losses are averaged.
        """
        losses = []
        for x, y, t in zip(reconstructions, shift_reconstructions, t_shift):
            # ``y[:, :-0]`` would be empty, so a zero shift compares the whole tensors.
            y_part = y if t == 0 else y[:, :-t]
            losses.append(self.rec_loss(x[:, t:].contiguous(), y_part.contiguous()))
        return torch.stack(losses).mean()

    @staticmethod
    def _frame_std(per_frame):
        """Batch mean of the std across frames of a ``(b, f)`` tensor; zero when ``f < 2``."""
        if per_frame.shape[1] < 2:
            # Unbiased std of a single sample is NaN, and 0 * NaN would poison the
            # total loss even when the std weight is zero.
            return per_frame.new_zeros(())
        return per_frame.std(dim=1).mean()

    def rec_std_loss(self, inputs, reconstructions):
        """Mean reconstruction loss and its across-frame std for ``(b, c, f, h, w)`` inputs."""
        rec_loss = self.rec_loss(inputs, reconstructions, reduction='none')  # b c f h w
        rec_loss_per_f = rec_loss.mean(dim=(1, 3, 4))  # b f
        return rec_loss.mean(), self._frame_std(rec_loss_per_f)

    def percept_std_loss(self, inputs, reconstructions):
        """Mean per-frame perceptual loss and its across-frame std for ``(b, c, f, h, w)`` inputs.

        Assumes ``self.perceptual_loss`` returns one value per frame image
        (``(b f) 1 1 1``), so that it can be viewed as ``(b, f)``.
        """
        b, c, f, h, w = inputs.shape
        inputs = rearrange(inputs, 'b c f h w -> (b f) c h w')
        reconstructions = rearrange(reconstructions, 'b c f h w -> (b f) c h w')
        p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())  # (b f) 1 1 1
        p_loss_per_f = p_loss.view(b, f)
        return p_loss.mean(), self._frame_std(p_loss_per_f)

    def forward(self, extra_loss, inputs, reconstructions, shift_reconstructions, t_shift, optimizer_idx, global_step, last_layer=None):
        """Compute the generator or discriminator loss terms.

        Args:
            extra_loss: dict of additional generator loss terms; each value is
                added unweighted to ``loss`` and copied into the returned dict.
                ``None`` is treated as empty. Unused in the discriminator step.
            inputs: real images ``(b, c, h, w)`` or videos ``(b, c, f, h, w)``.
            reconstructions: reconstructions with the same shape as ``inputs``.
            shift_reconstructions: optional reconstructions compared against
                ``reconstructions`` by ``t_shift_rec_loss``; ``None`` disables
                that term.
            t_shift: shifts zipped per sample with the reconstructions.
            optimizer_idx: 0 for the generator step, 1 for the discriminator step.
            global_step: training step; gates the discriminator and drives the
                perceptual warmup.
            last_layer: parameter used for the adaptive adversarial weight;
                required when ``disc_adaptive_weight`` is set and ``disc_weight > 0``.

        Returns:
            dict of loss terms (see ``_generator_loss`` / ``_discriminator_loss``).

        Raises:
            ValueError: if ``optimizer_idx`` is not 0 or 1.
        """
        if inputs.ndim == 5:
            inputs_5d, recon_5d = inputs, reconstructions
            inputs = rearrange(inputs, 'b c f h w -> (b f) c h w')
            reconstructions = rearrange(reconstructions, 'b c f h w -> (b f) c h w')
        else:
            inputs_5d = recon_5d = None

        if optimizer_idx == 0:
            return self._generator_loss(extra_loss, inputs, reconstructions, inputs_5d, recon_5d,
                                        shift_reconstructions, t_shift, global_step, last_layer)
        if optimizer_idx == 1:
            return self._discriminator_loss(inputs, reconstructions, global_step)
        raise ValueError(f"optimizer_idx must be 0 (generator) or 1 (discriminator), got {optimizer_idx!r}.")

    def _generator_loss(self, extra_loss, inputs, reconstructions, inputs_5d, recon_5d,
                        shift_reconstructions, t_shift, global_step, last_layer):
        """Autoencoder objective; ``inputs_5d`` / ``recon_5d`` are ``None`` for image inputs."""
        is_video = inputs_5d is not None

        # reconstruction loss & across-frame std
        if is_video:
            rec_loss, rec_std_loss = self.rec_std_loss(inputs_5d.contiguous(), recon_5d.contiguous())
        else:
            rec_loss = self.rec_loss(inputs, reconstructions, reduction='mean')
            rec_std_loss = 0

        # perceptual loss & across-frame std
        if is_video:
            p_loss, p_std_loss = self.percept_std_loss(inputs_5d.contiguous(), recon_5d.contiguous())
        else:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()).mean()
            p_std_loss = 0

        if self.perceptual_warmup:
            perceptual_weight = min(1.0, global_step / self.perceptual_warmup) * self.perceptual_weight
        else:
            # None (or 0) means no warmup; avoids dividing by zero.
            perceptual_weight = self.perceptual_weight

        # t_shift loss
        if shift_reconstructions is not None:
            # Per-sample frame shifts need the (b, c, f, h, w) layout rather than the
            # frame-flattened one when both tensors are videos.
            if is_video and shift_reconstructions.ndim == 5:
                recon_for_shift = recon_5d
            else:
                recon_for_shift = reconstructions
            t_shift_rec_loss = self.t_shift_rec_loss(recon_for_shift.contiguous(), shift_reconstructions.contiguous(), t_shift)
        else:
            t_shift_rec_loss = 0

        nll_loss = self.rec_weight * rec_loss + perceptual_weight * p_loss + self.t_shift_rec_weight * t_shift_rec_loss

        # adversarial loss for the generator
        if self.disc_weight > 0:
            if self.use_diff_aug:
                reconstructions = DiffAugment(reconstructions.contiguous(), policy='color,translation,cutout_0.2', prob=0.5)

            logits_fake = self.discriminator(reconstructions.contiguous())
            generator_adv_loss = self.gen_adv_loss(logits_fake)

            if self.disc_adaptive_weight:
                # The adaptive weight balances against rec + perceptual only (t-shift term excluded).
                nll_for_weight = self.rec_weight * rec_loss + perceptual_weight * p_loss
                disc_adaptive_weight = self.calculate_adaptive_weight(nll_for_weight, generator_adv_loss, last_layer=last_layer)
            else:
                disc_adaptive_weight = 1

            disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
        else:
            disc_weight, disc_adaptive_weight, generator_adv_loss = 0, 0, 0

        loss = nll_loss + disc_adaptive_weight * disc_weight * generator_adv_loss + \
            self.rec_std_weight * rec_std_loss + self.p_std_weight * p_std_loss

        # caller-supplied extra terms are added unweighted (out of place, to stay autograd-safe)
        extra_loss = extra_loss or {}
        for val in extra_loss.values():
            loss = loss + val

        loss_info = {
            'loss': loss,
            'rec_loss': self.rec_weight * rec_loss,
            'perceptual_loss': perceptual_weight * p_loss,
            't_shift_rec_loss': self.t_shift_rec_weight * t_shift_rec_loss,
            'generator_adv_loss': disc_adaptive_weight * disc_weight * generator_adv_loss,
            'disc_adaptive_weight': disc_adaptive_weight,
            'disc_weight': disc_weight,
            'std_loss': self.rec_std_weight * rec_std_loss,
            'p_std_loss': self.p_std_weight * p_std_loss,
        }
        loss_info.update(extra_loss)
        return loss_info

    def _discriminator_loss(self, inputs, reconstructions, global_step):
        """Discriminator objective on detached real inputs vs. reconstructions."""
        real = inputs.contiguous().detach()
        fake = reconstructions.contiguous().detach()
        if self.use_diff_aug:
            logits_real = self.discriminator(DiffAugment(real, policy='color,translation,cutout_0.2', prob=0.5))
            logits_fake = self.discriminator(DiffAugment(fake, policy='color,translation,cutout_0.2', prob=0.5))
        else:
            logits_real = self.discriminator(real)
            logits_fake = self.discriminator(fake)

        disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)

        if self.lecam_loss_weight is not None:
            self.lecam_ema.update(logits_real, logits_fake)
            lecam_loss = lecam_reg(logits_real, logits_fake, self.lecam_ema)
            base_d_loss = self.disc_loss(logits_real, logits_fake)
            d_adversarial_loss = disc_weight * (lecam_loss * self.lecam_loss_weight + base_d_loss)
        else:
            d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)

        d_cr = None
        if self.disc_cr_loss_weight:
            # Consistency term: logits on strongly augmented inputs should match the logits above.
            logits_real_s = self.discriminator(DiffAugment(real, policy='color,translation,cutout_0.5', prob=1.0))
            logits_fake_s = self.discriminator(DiffAugment(fake, policy='color,translation,cutout_0.5', prob=1.0))
            disc_cr_loss_weight = self.disc_cr_loss_weight if global_step >= self.discriminator_iter_start else 0.0
            d_cr = F.mse_loss(torch.cat([logits_real, logits_fake], dim=0), torch.cat([logits_real_s, logits_fake_s])) * disc_cr_loss_weight
            d_adversarial_loss = d_adversarial_loss + d_cr

        loss_info = {
            'discriminator_adv_loss': d_adversarial_loss,
            'disc_weight': disc_weight,
            'logits_real': logits_real.detach().mean(),
            'logits_fake': logits_fake.detach().mean(),
        }
        if d_cr is not None:
            # The key name is kept for existing consumers; the value is the weighted CR loss.
            loss_info['disc_cr_loss_weight'] = d_cr
        return loss_info
