import torch
import torch.nn as nn
import torch.nn.functional as F

import lpips
from prompt_tuning.models.unet3d import UNet3D
from diffusers import AutoencoderKLCogVideoX
from taming.modules.losses.vqperceptual import NLayerDiscriminator, weights_init


class PromptVAE(nn.Module):
    """Prompt encoder/decoder pair wrapped around a pretrained CogVideoX VAE.

    ``forward`` maps a scene-flow tensor to a "pseudo video" with
    ``encoder_prompt``, round-trips it through the VAE, maps the result back
    with ``decoder_prompt`` and returns a combined reconstruction +
    perceptual + KL + adversarial (generator-side) loss.

    Args:
        args: Configuration namespace. Attributes read by this class:
            ``vae_model_path``, ``perceptual_scale``, ``kl_scale``,
            ``disc_scale``, ``disc_factor`` and ``disc_start``.
    """

    def __init__(self, args):
        # nn.Module must be initialised before any sub-module is assigned;
        # otherwise the first assignment below raises AttributeError.
        super().__init__()
        self.args = args

        self.encoder_prompt = UNet3D(in_channels=3, out_channels=3, final_activation="sigmoid")
        self.decoder_prompt = UNet3D(in_channels=3, out_channels=3, final_activation=None)
        self.vae = AutoencoderKLCogVideoX.from_pretrained(args.vae_model_path, subfolder="vae")

        self.perceptual_loss = lpips.LPIPS(net="vgg")
        discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
        # Convert BatchNorm layers to SyncBatchNorm and keep the converted
        # module (the conversion returns a new module tree).
        self.discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

    def encode_video(self, video):
        """Encode a video with the VAE and sample a scaled latent.

        Args:
            video: Tensor laid out as [B, C, F, H, W]. No device or dtype
                cast is performed here.

        Returns:
            Tuple ``(latent, posterior)``: the posterior sample multiplied by
            ``vae.config.scaling_factor`` and the posterior distribution
            itself. Per the diffusers CogVideoX VAE, the latent keeps the
            [B, C, F, H, W] layout of the input.
        """
        posterior = self.vae.encode(video).latent_dist
        latent = posterior.sample() * self.vae.config.scaling_factor
        return latent, posterior

    def decode_latents(self, latents):
        """Decode scaled latents given as [B, F, C, H, W] back to frames.

        Args:
            latents: Scaled latents with frames in dim 1 and channels in
                dim 2; they are permuted to [B, C, F, H, W] before decoding.

        Returns:
            The ``sample`` attribute of the VAE decoder output.
        """
        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
        latents = 1 / self.vae.config.scaling_factor * latents

        frames = self.vae.decode(latents).sample
        return frames

    def _adaptive_disc_weight(self, nll_loss, g_loss, last_layer):
        """Balance the generator loss against the NLL loss at ``last_layer``.

        Returns the detached ratio of the two gradient norms, clamped to
        [0, 1e4] and multiplied by ``args.disc_scale``.
        """
        # Both calls keep the graph alive so the caller can still backprop
        # through the combined loss afterwards.
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        weight = torch.clamp(weight, 0.0, 1e4).detach()
        return weight * self.args.disc_scale

    def forward(self, batch):
        """Compute the training loss for one batch.

        Args:
            batch: Mapping with ``"scene_flow"`` (input tensor for
                ``encoder_prompt``) and ``"global_step"`` (compared against
                ``args.disc_start`` to switch the adversarial term on).

        Returns:
            Tuple ``(loss, info)``. ``info["recon_info"]`` holds detached CPU
            copies of the input, pseudo video and reconstructions;
            ``info["loss_info"]`` holds the individual loss terms as Python
            numbers.
        """
        scene_flow = batch["scene_flow"]
        pseudo_video = self.encoder_prompt(scene_flow)

        # NOTE(review): the VAE round trip runs under no_grad, so no gradient
        # reaches encoder_prompt and kl_loss below is a constant w.r.t. the
        # trainable parameters. Confirm this is intended.
        with torch.no_grad():
            latent, posterior = self.encode_video(pseudo_video)
            # encode_video yields [B, C, F, H, W] while decode_latents expects
            # [B, F, C, H, W] (it swaps dims 1 and 2 itself), so swap here.
            recon_video = self.decode_latents(latent.permute(0, 2, 1, 3, 4))
        recon_scene_flow = self.decoder_prompt(recon_video)

        rec_loss = F.mse_loss(recon_scene_flow, scene_flow)
        # NOTE(review): p_loss is computed under no_grad, so it changes the
        # reported loss value but contributes no gradient. Also verify that
        # the perceptual net and discriminator accept the rank of these
        # tensors.
        with torch.no_grad():
            p_loss = self.perceptual_loss(recon_scene_flow, scene_flow)
        rec_loss = rec_loss + self.args.perceptual_scale * p_loss
        nll_loss = torch.sum(rec_loss) / rec_loss.shape[0]

        kl_loss = posterior.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        logits_fake = self.discriminator(recon_scene_flow)
        g_loss = -torch.mean(logits_fake)
        last_layer = self.decoder_prompt.conv_out.weight
        disc_weight = self._adaptive_disc_weight(nll_loss, g_loss, last_layer)
        # The adversarial term is disabled until global_step reaches disc_start.
        disc_factor = self.args.disc_factor if batch["global_step"] >= self.args.disc_start else 0.0

        loss = nll_loss + self.args.kl_scale * kl_loss + disc_weight * disc_factor * g_loss

        info = {
            "recon_info": {
                "scene_flow": scene_flow.detach().cpu(),
                "recon_scene_flow": recon_scene_flow.detach().cpu(),
                "pseudo_video": pseudo_video.detach().cpu(),
                "recon_video": recon_video.detach().cpu(),
            },
            "loss_info": {
                "loss": loss.detach().mean().item(),
                "nll_loss": nll_loss.detach().mean().item(),
                "rec_loss": rec_loss.detach().mean().item(),
                "p_loss": p_loss.detach().mean().item(),
                "kl_loss": kl_loss.detach().mean().item(),
                "disc_weight": disc_weight.detach().mean().item(),
                "disc_factor": disc_factor,
                "g_loss": g_loss.detach().mean().item(),
            },
        }
        return loss, info