"""
PS-VAE core component code, which contains the model: PSVAE and the loss: PixelSemanticLoss.

Key idea (high-level):
- Stage 1 (S-VAE): learn a compact, KL-regularized semantic latent z on top of a frozen
  pretrained representation encoder E, using a semantic reconstruction loss.
- Stage 2 (PS-VAE): unfreeze E and add a pixel reconstruction objective, while
  preserving semantics with the semantic reconstruction loss.

Due to institutional code release policies, this submission includes only an illustrative example of the PS-VAE training procedure. 
The complete implementation and training pipeline will be released following an internal compliance review.

"""


import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import copy
import math
from einops import rearrange

from taming.modules.losses.vqperceptual import LPIPS, adopt_weight, hinge_d_loss, vanilla_d_loss


from disc import build_discriminator
# For the discriminator, you can refer to
# UniTok: https://github.com/FoundationVision/UniTok
# RAE: https://github.com/bytetriper/RAE
# which freeze the DINOv2 model and only train the discriminator head.

from ldm.util import instantiate_from_config

from ldm.modules.diffusionmodules.model import Decoder
# decoder implementation is the same as the one in the LDM & VAVAE, you can refer to 
# https://github.com/CompVis/taming-transformers & https://github.com/hustvl/LightningDiT/tree/main/vavae
from ldm.models.utils import freeze_module, unfreeze_module



from model import get_representation_encoder

from transformers.activations import ACT2FN
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

def nchw_to_n_hw_c(x):
    """Flatten a channels-first feature map into a token sequence.

    Args:
        x: Tensor of shape ``(B, C, H, W)``.

    Returns:
        Tensor of shape ``(B, H * W, C)``; tokens are ordered row by row
        (all W positions of the first row, then the second row, ...).
    """
    b, c, h, w = x.shape
    # ``reshape`` instead of ``view``: the permuted tensor is usually
    # non-contiguous, and ``view`` raises when the H/W strides cannot be merged
    # (e.g. for strided slices). Spelling out ``h * w`` (rather than ``-1``)
    # also keeps empty batches working.
    return x.permute(0, 2, 3, 1).reshape(b, h * w, c)
def n_hw_c_to_nchw(x):
    """Fold a token sequence back into a square channels-first feature map.

    Args:
        x: Tensor of shape ``(B, N, C)`` where ``N`` must be a perfect square.

    Returns:
        Tensor of shape ``(B, C, H, W)`` with ``H == W`` and ``H * W == N``.
        The result is a permuted view of the reshaped input.

    Raises:
        ValueError: If ``N`` is not a perfect square, since the spatial layout
            cannot be recovered from the token count alone.
    """
    b, n, c = x.shape
    side = int(round(math.sqrt(n)))
    if side * side != n:
        raise ValueError(
            "n_hw_c_to_nchw expects a square token grid, "
            "but got {} tokens".format(n)
        )
    return x.reshape(b, side, side, c).permute(0, 3, 1, 2)

class MLPconnector(nn.Module):
    """Two-layer MLP that projects features from ``in_dim`` to ``out_dim``.

    Computes ``fc2(activation(fc1(x)))``; the hidden width equals ``out_dim``.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_act="gelu_pytorch_tanh"):
        super().__init__()
        # The activation is looked up by name in the ``ACT2FN`` registry.
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(in_features=in_dim, out_features=out_dim)
        self.fc2 = nn.Linear(in_features=out_dim, out_features=out_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the projection to the last dimension of ``hidden_states``."""
        projected = self.fc1(hidden_states)
        return self.fc2(self.activation_fn(projected))

class PSVAE(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 semantic_detach_pixel=True,
                 pixel_detach_semantic=True,
                 weight_init=None,
                 image_key="image",
                 accumulate_grad_batches=1,
                 grad_norm=100,
                 retrain_discriminator=False,
                 detach_semantic=True,
                 num_compress_encoder_layers=3,
                 num_compress_decoder_layers=3,
                 stage="svae",
                 detach_pixel_ratio=0,
                 representation_encoder_type="dinov2",
                 **kwargs
                 ):
        """Build the representation encoder, semantic VAE, pixel decoder and loss.

        Args:
            ddconfig: Keyword arguments for the pixel ``Decoder``. Must contain a
                truthy ``"double_z"`` entry. ``z_channels`` is overridden with
                ``embed_dim`` on a copy; the caller's mapping is not modified.
            lossconfig: Keyword arguments for ``PixelSemanticLoss``.
            embed_dim: Channel count of the latent ``z``.
            semantic_detach_pixel: Stored on the module; not read in this file.
            pixel_detach_semantic: Stored on the module; not read in this file.
            weight_init: Optional checkpoint path passed to ``init_from_ckpt``.
            image_key: Batch key under which the images are stored.
            accumulate_grad_batches: Number of ``training_step`` calls between
                optimizer updates.
            grad_norm: Max gradient norm used for clipping in ``training_step``.
            retrain_discriminator: If true, discriminator weights are dropped
                when loading ``weight_init``.
            detach_semantic: Stored on the module; not read in this file.
            num_compress_encoder_layers: Transformer blocks in the semantic encoder.
            num_compress_decoder_layers: Transformer blocks in the semantic decoder.
            stage: ``"svae"`` (representation encoder frozen) or ``"psvae"``
                (representation encoder trainable, with a frozen teacher copy).
            detach_pixel_ratio: In the ``"psvae"`` stage, fraction of the pixel
                decoder input that is taken from the detached latent.
            representation_encoder_type: Forwarded to ``get_representation_encoder``.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``stage`` is unknown or ``ddconfig["double_z"]`` is falsy.
        """
        super().__init__()
        if stage not in ("svae", "psvae"):
            raise ValueError(
                "stage must be 'svae' or 'psvae', got {!r}".format(stage)
            )
        if not ddconfig["double_z"]:
            raise ValueError("ddconfig['double_z'] must be enabled for PSVAE")

        self.image_key = image_key
        self.semantic_detach_pixel = semantic_detach_pixel
        self.pixel_detach_semantic = pixel_detach_semantic
        self.accumulate_grad_batches = accumulate_grad_batches
        self.real_step = 0  # counts training_step calls (not optimizer steps)
        self.retrain_discriminator = retrain_discriminator
        self.detach_semantic = detach_semantic
        self.stage = stage
        self.detach_pixel_ratio = detach_pixel_ratio
        self.embed_dim = embed_dim
        self.grad_norm = grad_norm

        self.representation_encoder = get_representation_encoder(type=representation_encoder_type)
        freeze_module(self.representation_encoder)
        if stage == "psvae":
            # Stage 2: the encoder is trained, and a frozen copy acts as the teacher.
            unfreeze_module(self.representation_encoder)
            self.teacher_representation_encoder = copy.deepcopy(self.representation_encoder)
            freeze_module(self.teacher_representation_encoder)

        # No explicit ``.cuda()`` here: the loss is a registered submodule, so it
        # follows the parent module when it is moved to a device, and
        # construction also works on CPU-only machines.
        self.loss = PixelSemanticLoss(**lossconfig)

        # Decoder implementation is the same as the one in LDM & VAVAE, see
        # https://github.com/CompVis/taming-transformers and
        # https://github.com/hustvl/LightningDiT/tree/main/vavae
        # Work on a copy so the caller's config is left untouched.
        decoder_config = {**ddconfig, "z_channels": embed_dim}
        self.pixel_decoder = Decoder(**decoder_config)
        unfreeze_module(self.pixel_decoder)

        # Optimizers are stepped manually in ``training_step``.
        self.automatic_optimization = False

        # One dict object is attached to every module registered so far.
        # NOTE(review): modules created below (semantic encoder/decoder) are not
        # visited by this loop — confirm that this is intended.
        share_cache = dict()
        for module in self.modules():
            module.share_cache = share_cache

        # The compress/decompress stacks start from copies of the last
        # transformer block of the representation encoder.
        transformer_block = copy.deepcopy(self.representation_encoder.encoder.layer[-1])
        representation_dim = self.representation_encoder.hidden_dim

        # Semantic encoder: transformer blocks followed by an MLP that emits
        # ``2 * embed_dim`` channels for the posterior.
        self.semantic_encoder = nn.ModuleList(
            [copy.deepcopy(transformer_block) for _ in range(num_compress_encoder_layers)]
        )
        self.semantic_encoder.append(MLPconnector(representation_dim, embed_dim * 2))
        unfreeze_module(self.semantic_encoder)

        # Semantic decoder: an MLP back to the representation width, followed by
        # transformer blocks.
        decompress_mlp = MLPconnector(embed_dim, representation_dim)
        self.semantic_decoder = nn.ModuleList(
            [copy.deepcopy(transformer_block) for _ in range(num_compress_decoder_layers)]
        )
        self.semantic_decoder.insert(0, decompress_mlp)
        unfreeze_module(self.semantic_decoder)

        if weight_init is not None:
            self.init_from_ckpt(weight_init)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load weights from a checkpoint, optionally dropping some entries.

        Args:
            path: Path to a checkpoint whose top-level ``"state_dict"`` entry
                holds the weights.
            ignore_keys: Iterable of key prefixes; every state-dict entry whose
                name starts with one of them is skipped.

        Entries containing ``"discriminator"`` are also skipped when
        ``self.retrain_discriminator`` is set. Loading is non-strict, so missing
        and unexpected keys are reported in the printed message instead of
        raising.
        """
        # NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]

        if self.retrain_discriminator:
            for k in list(sd):
                if "discriminator" in k:
                    print("Deleting discriminator key {} from state_dict.".format(k))
                    del sd[k]

        # Iterate over a snapshot of the remaining keys and delete each key at
        # most once, even if it matches several prefixes or was already removed
        # by the discriminator filter above.
        ignore_prefixes = tuple(ignore_keys)
        for k in list(sd):
            if k.startswith(ignore_prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]

        msg = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with msg: {msg}")

    def encode(self, x):
        """Encode an image batch into the mode of the latent posterior.

        Runs the same encoder pipeline as ``forward`` (representation encoder,
        then semantic encoder) but returns ``posterior.mode()`` instead of a
        random sample.
        """
        # Representation features, cast to float32 and flattened to tokens.
        tokens = nchw_to_n_hw_c(self.representation_encoder(x).float())

        # Compress with the semantic encoder stack.
        for block in self.semantic_encoder:
            tokens = block(tokens)

        # Back to a channels-first map, which parameterizes the posterior.
        posterior_params = n_hw_c_to_nchw(tokens)
        return DiagonalGaussianDistribution(posterior_params).mode()
    


    def decode(self, z):
        """Decode a latent ``z`` to pixel space with the pixel decoder."""
        return self.pixel_decoder(z)


   
    def forward(self, input):
        """Run the full pipeline and return everything the loss needs.

        Returns a 4-tuple ``(pixel_rec, posterior, rec_feature, teacher_feature)``:

        - ``"svae"`` stage: ``rec_feature`` is the output of the semantic decoder
          and ``teacher_feature`` is the representation-encoder output (computed
          without gradients).
        - ``"psvae"`` stage: ``rec_feature`` is the pair
          ``(semantic decoder output, trainable encoder output)`` and
          ``teacher_feature`` is the frozen teacher output repeated twice, so
          both elements are matched against the same target.
        """
        in_svae = self.stage == "svae"

        if in_svae:
            # Stage 1: the encoder is frozen and serves as its own teacher.
            with torch.no_grad():
                feature = self.representation_encoder(input)
            teacher_feature = feature
        else:
            assert self.stage == "psvae"
            feature = self.representation_encoder(input)
            feature_before_compress = feature.clone()
            with torch.no_grad():
                teacher_feature = self.teacher_representation_encoder(input)

        teacher_feature = teacher_feature.float()

        # Semantic encoder: tokens -> posterior parameters.
        tokens = nchw_to_n_hw_c(feature.float())
        for block in self.semantic_encoder:
            tokens = block(tokens)
        posterior = DiagonalGaussianDistribution(n_hw_c_to_nchw(tokens))

        z = posterior.sample()

        if in_svae:
            # Detach so the pixel reconstruction loss does not affect the
            # semantic branch in stage 1.
            decoder_input = z.detach()
        elif self.stage == "psvae":
            ratio = self.detach_pixel_ratio
            decoder_input = z.detach() * ratio + z * (1 - ratio)

        pixel_rec = self.pixel_decoder(decoder_input)

        # Semantic decoder: latent -> reconstructed representation features.
        rec_tokens = nchw_to_n_hw_c(z)
        for block in self.semantic_decoder:
            rec_tokens = block(rec_tokens)
        rec_feature = n_hw_c_to_nchw(rec_tokens)

        if in_svae:
            return pixel_rec, posterior, rec_feature, teacher_feature

        assert self.stage == "psvae"
        return (
            pixel_rec,
            posterior,
            (rec_feature, feature_before_compress),
            (teacher_feature, teacher_feature),
        )


    def get_input(self, batch, k, image_size=256):
        """Fetch images from ``batch`` and return them as ``B x C x S x S``.

        Args:
            batch: Mapping that holds the image tensor under key ``k``.
            k: Key of the image tensor.
            image_size: Target side length ``S`` (defaults to 256, the value
                that used to be hard-coded).

        Returns:
            A contiguous ``B x C x S x S`` tensor. Inputs of another size are
            resized with bicubic interpolation so that the short side equals
            ``S``, then center-cropped.
        """
        x = batch[k]
        if x.ndim == 3:
            # Single image without a batch dimension.
            x = x.unsqueeze(0)
        # x: B x C x H x W or B x H x W x C, we want B x C x H x W
        if x.shape[-1] == 3 and x.shape[1] != 3:
            # Possibly B x H x W x C
            x = x.permute(0, 3, 1, 2)

        h, w = x.shape[-2:]
        if h != image_size or w != image_size:
            scale = image_size / min(h, w)
            # ``max`` guards against rounding ever producing a side shorter
            # than the crop window.
            new_h = max(image_size, int(round(h * scale)))
            new_w = max(image_size, int(round(w * scale)))
            x = F.interpolate(x, size=(new_h, new_w), mode='bicubic', align_corners=False)
            # center crop
            top = (new_h - image_size) // 2
            left = (new_w - image_size) // 2
            x = x[:, :, top:top + image_size, left:left + image_size]

        return x.contiguous()

    def configure_gradient_clipping(
        self, norm_val=1.0, gradient_clip_algorithm="norm", name="global"
    ):
        """Clip the gradients of all parameters in place and log their norm.

        Args:
            norm_val: Max total norm (``"norm"``) or max absolute value
                (``"value"``) of the gradients.
            gradient_clip_algorithm: ``"norm"`` or ``"value"``.
            name: Suffix of the logged metric ``grad_norm/{name}``.

        Raises:
            ValueError: If ``gradient_clip_algorithm`` is not supported.

        The logged value is the total gradient norm measured before clipping.
        NOTE(review): this clips every parameter of the module, including the
        ones that live inside ``self.loss`` — confirm this is intended.
        """
        if gradient_clip_algorithm == "norm":
            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), norm_val)
        elif gradient_clip_algorithm == "value":
            # ``clip_grad_value_`` returns None, so measure the (pre-clip) total
            # norm ourselves to keep the logged metric meaningful.
            with_grad = [p for p in self.parameters() if p.grad is not None]
            if with_grad:
                norm = torch.norm(
                    torch.stack([torch.norm(p.grad.detach().float()) for p in with_grad])
                )
            else:
                norm = torch.tensor(0.0)
            torch.nn.utils.clip_grad_value_(with_grad, norm_val)
        else:
            raise ValueError(
                "gradient_clip_algorithm must be 'norm' or 'value', got {!r}".format(
                    gradient_clip_algorithm
                )
            )

        self.log(f"grad_norm/{name}", norm,  prog_bar=True, logger=True, on_step=True, on_epoch=True)

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step for the autoencoder and discriminator.

        The autoencoder loss is back-propagated on every call; its optimizer is
        stepped when ``real_step % accumulate_grad_batches == 0``. The
        discriminator loss is back-propagated on every call as well; its
        optimizer is stepped when ``(real_step - 1) % accumulate_grad_batches == 0``.

        NOTE(review): with ``accumulate_grad_batches > 1`` the autoencoder
        backward pass on non-update steps also appears to leave gradients on the
        discriminator parameters that are not cleared before the discriminator
        update — confirm whether that is intended.
        """
        inputs = self.get_input(batch, self.image_key)
        # ``z`` / ``aux_feature`` are the third and fourth outputs of ``forward``:
        # the reconstructed feature(s) and the teacher feature(s).
        reconstructions, posterior, z, aux_feature = self(inputs)
        ae_opt, disc_opt = self.optimizers()
        # lr_schedulers() returns several schedulers; index 0 is used for the
        # autoencoder optimizer and index 1 for the discriminator optimizer.
        schedulers = self.lr_schedulers()
        scheduler_ae = schedulers[0]
        scheduler_disc = schedulers[1]

        # Autoencoder pass (optimizer_idx == 0).
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="train", z=z, aux_feature=aux_feature)

        if self.real_step % self.accumulate_grad_batches == 0:
            if self.real_step == 0:
                ae_opt.zero_grad()
            self.manual_backward(aeloss)
            grad_norm = getattr(self, "grad_norm", 10)
            self.configure_gradient_clipping(norm_val=grad_norm, gradient_clip_algorithm="norm", name="ae")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            # Step the optimizer before its scheduler (the order PyTorch
            # documents, and the one the discriminator branch below uses);
            # the reverse order skips the first value of the LR schedule.
            ae_opt.step()
            scheduler_ae.step()
            ae_opt.zero_grad()

            # The autoencoder backward pass also deposits gradients on the
            # discriminator parameters; clear them before the discriminator pass.
            disc_opt.zero_grad()
        else:
            self.manual_backward(aeloss)

        # Discriminator pass (optimizer_idx == 1).
        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

        # The discriminator update uses ``real_step - 1`` in its schedule test.
        if (self.real_step - 1) % self.accumulate_grad_batches == 0:
            if (self.real_step - 1) == 0:
                disc_opt.zero_grad()
            self.manual_backward(discloss)

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            disc_opt.step()
            scheduler_disc.step()
            disc_opt.zero_grad()
        else:
            self.manual_backward(discloss)

        ae_lr = ae_opt.param_groups[0]["lr"]
        self.log("lr/ae", float(f"{ae_lr:.6f}"), prog_bar=True, on_step=True)  # keep 6 decimal places

        disc_lr = disc_opt.param_groups[0]["lr"]
        self.log("lr/disc", float(f"{disc_lr:.6f}"), prog_bar=True, on_step=True)

        self.real_step += 1



    def get_last_layer(self):
        """Return the weight of the pixel decoder's output convolution.

        It is passed to the loss as ``last_layer``, where the adaptive
        discriminator weight is computed from gradients w.r.t. this tensor.
        """
        output_conv = self.pixel_decoder.conv_out
        return output_conv.weight





class PixelSemanticLoss(nn.Module):
    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1e-6,
                 disc_factor=1.0, disc_weight=0.5,
                 perceptual_weight=1.0,
                 disc_loss="hinge", distmat_weight=1.0, cos_weight=1.0, 
                 l2_loss_weight=1.0, semantic_reconstruction_weight=0.1,
                 **kwargs):
        """Set up the loss weights, the LPIPS network and the discriminator.

        Args:
            disc_start: Global step passed to ``adopt_weight`` as the threshold
                for the discriminator factor.
            logvar_init: Initial value of the scalar ``logvar`` parameter.
            kl_weight: Weight of the KL term.
            disc_factor: Global multiplier of the adversarial terms.
            disc_weight: Scale applied to the adaptive discriminator weight.
            perceptual_weight: Weight of the LPIPS term.
            disc_loss: ``"hinge"`` or ``"vanilla"``.
            distmat_weight: Weight of the similarity-matrix term of the semantic loss.
            cos_weight: Weight of the cosine term of the semantic loss.
            l2_loss_weight: Weight of the MSE term of the semantic loss.
            semantic_reconstruction_weight: Weight of the whole semantic loss.
            **kwargs: Extra config fields, accepted and ignored.
        """
        super().__init__()
        del kwargs  # allow extra config fields in this minimal example
        assert disc_loss in ["hinge", "vanilla"]

        # Weights of the individual loss terms.
        self.kl_weight = kl_weight
        self.l2_loss_weight = l2_loss_weight
        self.perceptual_loss = LPIPS().eval()
        self.semantic_reconstruction_weight = semantic_reconstruction_weight
        self.perceptual_weight = perceptual_weight
        self.distmat_weight = distmat_weight
        self.cos_weight = cos_weight

        # Fixed discriminator configuration of this example.
        disc_cfg = {
            'arch': {
                'dino_ckpt_path': 'path_to_dino_vit_small_patch8_224.pth',
                'ks': 9,
                'norm_type': 'bn',
                'using_spec_norm': True,
                'recipe': 'S_8',
            },
            'optimizer': {
                'lr': 0.0002,
                'betas': [0.5, 0.9],
                'weight_decay': 0.0,
            },
            'scheduler': {
                'type': 'cosine',
                'warmup_epochs': 1,
                'decay_end_epoch': 16,
                'base_lr': 0.0002,
                'final_lr': 2e-05,
            },
            'augment': {
                'prob': 1.0,
                'cutout': 0.0,
            },
        }
        # Initialize on CPU, PyTorch Lightning will move it to the correct device
        self.discriminator, self.discriminator_aug = build_discriminator(disc_cfg, "cpu")

        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        # Adversarial settings.
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        else:
            self.disc_loss = vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight
    


    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, split="train",
                weights=None, z=None, aux_feature=None):
        """Compute the generator (``optimizer_idx == 0``) or discriminator (``1``) loss.

        Args:
            inputs: Ground-truth images.
            reconstructions: Decoded images.
            posteriors: Latent posterior; its ``kl(no_sum=True)`` is used when
                it has a ``kl`` attribute, otherwise the KL term is zero.
            optimizer_idx: ``0`` for the autoencoder loss, ``1`` for the
                discriminator loss.
            global_step: Current step, passed to ``adopt_weight`` together with
                ``self.discriminator_iter_start``.
            last_layer: Tensor used to compute the adaptive discriminator weight.
            split: Prefix of the logged metric names.
            weights: Optional multiplier for the element-wise reconstruction loss.
            z: Reconstructed semantic feature of shape ``(B, C, H, W)``, or a
                tuple/list of such tensors.
            aux_feature: Target semantic feature(s), matching ``z``.

        Returns:
            ``(loss, log)`` where ``log`` maps metric names to detached tensors.

        Raises:
            ValueError: If ``optimizer_idx`` is neither 0 nor 1.
        """
        if optimizer_idx not in (0, 1):
            raise ValueError(
                "optimizer_idx must be 0 or 1, got {!r}".format(optimizer_idx)
            )

        device = inputs.device

        # Pixel reconstruction: L1 plus (optionally) the LPIPS perceptual distance.
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        nll_loss = rec_loss
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss

        weighted_nll_loss = torch.mean(weighted_nll_loss)
        nll_loss = torch.mean(nll_loss)
        if hasattr(posteriors, "kl"):
            kl_loss = torch.mean(posteriors.kl(no_sum=True))
        else:
            kl_loss = torch.tensor(0.0, device=device)

        # Encoder/decoder parameters update.
        if optimizer_idx == 0:
            # This minimal example uses the (UniTok/RAE-style) DINO discriminator,
            # which is *not* a conditional GAN discriminator.
            fake_recons = self.discriminator_aug.aug(reconstructions.contiguous())
            logits_fake, _ = self.discriminator(fake_recons.contiguous(), None)
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    # Deliberate best-effort: fall back to a zero adversarial
                    # weight when the gradients cannot be computed.
                    print("RuntimeError in calculate_adaptive_weight")
                    d_weight = torch.tensor(0.0, device=device)
            else:
                d_weight = torch.tensor(0.0, device=device)

            # Semantic reconstruction loss.
            log_l2_loss = []
            log_cos_loss = []
            log_distmat_loss = []
            if z is not None and aux_feature is not None:
                if isinstance(aux_feature, (tuple, list)):
                    assert len(aux_feature) == len(z)
                    rec_features = z
                    target_features = aux_feature
                else:
                    rec_features = [z]
                    target_features = [aux_feature]

                semantic_reconstruction_loss = 0.0
                for rec_feat, target_feat in zip(rec_features, target_features):
                    if self.l2_loss_weight > 0:
                        l2_loss = F.mse_loss(rec_feat, target_feat)
                    else:
                        # Keep a tensor (not a float) so it can be logged below.
                        l2_loss = rec_feat.new_zeros(())

                    # Match the pairwise cosine-similarity matrices between
                    # spatial positions of the two feature maps.
                    rec_flat = rearrange(rec_feat, 'b c h w -> b c (h w)')
                    target_flat = rearrange(target_feat, 'b c h w -> b c (h w)')
                    rec_norm = F.normalize(rec_flat, dim=1)
                    target_norm = F.normalize(target_flat, dim=1)
                    rec_sim = torch.einsum('bci,bcj->bij', rec_norm, rec_norm)
                    target_sim = torch.einsum('bci,bcj->bij', target_norm, target_norm)
                    distmat_loss = torch.abs(rec_sim - target_sim).mean()

                    # Per-position cosine distance along the channel dimension.
                    cos_sim_loss = (1 - F.cosine_similarity(target_feat, rec_feat)).mean()

                    semantic_reconstruction_loss = (
                        distmat_loss * self.distmat_weight
                        + cos_sim_loss * self.cos_weight
                        + l2_loss * self.l2_loss_weight
                        + semantic_reconstruction_loss
                    )
                    log_l2_loss.append(l2_loss.detach().mean())
                    log_cos_loss.append(cos_sim_loss.detach().mean())
                    log_distmat_loss.append(distmat_loss.detach().mean())
            else:
                semantic_reconstruction_loss = torch.tensor(0.0, device=device)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)

            pixel_reconstruction_loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            loss = pixel_reconstruction_loss + self.semantic_reconstruction_weight * semantic_reconstruction_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean()
                   }
            # The first feature pair is logged under the plain names; further
            # pairs get an index suffix.
            for index, (l2_value, cos_value, distmat_value) in enumerate(
                    zip(log_l2_loss, log_cos_loss, log_distmat_loss)):
                if index == 0:
                    log["{}/l2_loss".format(split)] = l2_value
                    log["{}/cos_loss".format(split)] = cos_value
                    log["{}/distmat_loss".format(split)] = distmat_value
                else:
                    log["{}/l2_loss_{}".format(split, index)] = l2_value
                    log["{}/cos_loss_{}".format(split, index)] = cos_value
                    log["{}/distmat_loss_{}".format(split, index)] = distmat_value

            return loss, log

        # Discriminator parameters update (optimizer_idx == 1): second pass on
        # detached reconstructions clamped to [-1, 1].
        reconstructions = torch.clamp(reconstructions.contiguous().detach(), -1, 1)

        fake_recons = self.discriminator_aug.aug(reconstructions.contiguous())
        real_inputs = self.discriminator_aug.aug(inputs.contiguous())
        logits_fake, logits_real = self.discriminator(fake_recons.contiguous(), real_inputs.contiguous())

        disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
        d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

        log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
               "{}/logits_real".format(split): logits_real.detach().mean(),
               "{}/logits_fake".format(split): logits_fake.detach().mean()
               }
        return d_loss, log