import os
import numpy as np
from PIL import Image
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from einops import rearrange, repeat

import torch
import torch.nn.functional as F
import torchvision as tv
from torch import nn
from lightning import LightningModule
from lightning_utilities.core.rank_zero import rank_zero_only
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers import EulerDiscreteScheduler, UNetSpatioTemporalConditionModel, AutoencoderKLTemporalDecoder, StableVideoDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing
from src.models.ema import LitEma
from ..base import BaseSystem

def get_add_time_ids(
    unet,
    fps,
    motion_bucket_id,
    noise_aug_strength,
    dtype,
    batch_size
):
    """Build the "added time ids" conditioning tensor consumed by the SVD UNet.

    The three scalars ``fps``, ``motion_bucket_id`` and ``noise_aug_strength``
    are packed into one row, which is then tiled ``batch_size`` times.

    Args:
        unet: UNet whose ``config.addition_time_embed_dim`` and
            ``add_embedding.linear_1.in_features`` are checked for consistency.
        fps: Frame-rate conditioning value.
        motion_bucket_id: Motion-bucket conditioning value.
        noise_aug_strength: Noise-augmentation conditioning value.
        dtype: dtype of the returned tensor.
        batch_size: Number of rows in the returned tensor.

    Returns:
        Tensor of shape ``(batch_size, 3)`` with dtype ``dtype``.

    Raises:
        ValueError: if ``addition_time_embed_dim * 3`` differs from the input
            width of the UNet's ``add_embedding.linear_1`` layer.
    """
    time_id_values = (fps, motion_bucket_id, noise_aug_strength)

    passed_add_embed_dim = unet.config.addition_time_embed_dim * len(time_id_values)
    expected_add_embed_dim = unet.add_embedding.linear_1.in_features
    if passed_add_embed_dim != expected_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    single_row = torch.tensor([list(time_id_values)], dtype=dtype)
    return single_row.repeat(batch_size, 1)


class SVDSystem(BaseSystem):
    """Lightning system that fine-tunes a Stable Video Diffusion UNet wrapper (``mv_model``).

    The VAE, CLIP image encoder, scheduler and image processors are taken from a
    pretrained ``StableVideoDiffusionPipeline``. Only ``mv_model`` parameters are listed
    in ``trainable_parameters``. Training sums a preconditioned denoising loss, a
    reconstruction loss on ``lrm`` renders of the predicted clean latents, and a second
    denoising loss in which those renders are passed back to ``mv_model``.
    """

    def __init__(
        self,
        lr: float,
        mv_model = None,
        lrm = None,
        base_model_id: str = "stabilityai/stable-video-diffusion-img2vid",
        variant: str = "fp16",
        cfg: float = 0.2,
        report_to: str = "wandb",
        ema_decay_rate: float = 0.9999,
        compile: bool = False,
        use_ema: bool = False,
        **kwargs,
    ) -> None:
        """
        Args:
            lr: Base learning rate, multiplied per parameter group in ``configure_optimizers``.
            mv_model: Required callable that receives the pretrained SVD UNet and returns
                the trainable denoiser module.
            lrm: Callable applied to predicted clean latents in ``training_step`` to
                produce renders; stored as ``self.lrm``.
            base_model_id: Model id / path passed to ``from_pretrained``.
            variant: Weight variant passed to ``from_pretrained``.
            cfg: Probability of dropping the image conditioning in a training step.
            report_to: ``"wandb"`` or ``"tensorboard"``; selects the image logging backend.
            ema_decay_rate: Decay passed to ``LitEma`` when ``use_ema`` is True.
            compile: Wrap ``mv_model`` with ``torch.compile`` in ``setup("fit")``.
            use_ema: Keep an EMA copy of the ``mv_model`` weights.
            **kwargs: Forwarded to ``BaseSystem``.

        Raises:
            ValueError: if ``mv_model`` is not provided.
        """
        super().__init__(**kwargs)

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt (``mv_model`` is excluded)
        self.save_hyperparameters(logger=False, ignore=["mv_model"])

        if mv_model is None:
            raise ValueError("SVDSystem requires an `mv_model` callable that wraps the SVD UNet.")

        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(base_model_id, subfolder="pipeline", variant=variant)

        self.scheduler = self.pipeline.scheduler
        self.image_encoder = self.pipeline.image_encoder
        self.vae = self.pipeline.vae
        self.feature_extractor = self.pipeline.feature_extractor
        self.image_processor = self.pipeline.image_processor
        self.mv_model = mv_model(self.pipeline.unet)
        # Renders predicted clean latents in training_step.
        self.lrm = lrm

        # Metric objects; all expect images in [0, 1].
        self.psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)

        self.use_ema = use_ema
        if use_ema:
            # EMA tracks mv_model, the module whose parameters are optimized.
            # NOTE(review): if LitEma keys shadow params by parameter name, combining
            # ``use_ema`` with ``compile`` (which nests params under ``_orig_mod``) needs checking.
            self.model_ema = LitEma(self.mv_model, decay=ema_decay_rate)

        # A list (not a generator) so configure_optimizers can run more than once.
        self.trainable_parameters = [
            (list(self.mv_model.parameters()), 1.0),
        ]

        self.num_videos_per_prompt = 1
        self.fps = 7
        self.noise_aug_strength = 0.02
        self.motion_bucket_id = 127
        self.P_mean = 0.7
        self.P_std = 1.6
        self.num_inference_steps = 25
        self.min_guidance_scale = 1.0
        self.max_guidance_scale = 3.0

    @contextmanager
    def ema_scope(self, context=None):
        """Copy EMA weights into ``mv_model`` for the scope, restoring training weights on exit.

        A no-op when ``use_ema`` is False. ``context``, if given, is printed with each switch.
        """
        if self.use_ema:
            self.model_ema.store(self.mv_model.parameters())
            self.model_ema.copy_to(self.mv_model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.mv_model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def setup(self, stage: str) -> None:
        """Run base setup, then optionally wrap ``mv_model`` in ``torch.compile`` for fitting."""
        super().setup(stage)

        if self.hparams.compile and stage == "fit":
            self.mv_model = torch.compile(self.mv_model)

    def configure_optimizers(self):
        """Configure optimizers and learning rate schedulers for training.

        Each ``(params, lr_scale)`` entry in ``trainable_parameters`` becomes one AdamW
        parameter group with learning rate ``hparams.lr * lr_scale``.
        """
        param_groups = [
            {"params": list(params), "lr": self.hparams.lr * lr_scale}
            for params, lr_scale in self.trainable_parameters
        ]
        return torch.optim.AdamW(param_groups)

    def forward(self, latents, timestep, prompt_embd, meta) -> torch.Tensor:
        """Delegate directly to ``mv_model``."""
        return self.mv_model(latents, timestep, prompt_embd, meta)

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after every optimizer step when EMA is enabled."""
        if self.use_ema:
            self.model_ema(self.mv_model)

    def _get_added_time_ids(self, batch_size, dtype, device):
        """Return SVD added-time-ids (fps - 1, motion bucket, noise-aug strength) on ``device``."""
        added_time_ids = get_add_time_ids(
            self.mv_model.unet,
            self.fps - 1,
            self.motion_bucket_id,
            self.noise_aug_strength,
            dtype,
            batch_size,
        )
        return added_time_ids.to(device)

    def _encode_condition_latent(self, condition_image, num_frames):
        """Noise-augment the conditioning image, VAE-encode it and tile it over frames.

        Args:
            condition_image: Conditioning images ``b x c x h x w`` in [0, 1].
            num_frames: Number of frames to repeat the latent over.

        Returns:
            ``b x num_frames x c' x h' x w'`` latents (VAE sample, not multiplied by
            the VAE scaling factor).
        """
        condition_image = self.image_processor.preprocess(condition_image)  # 0 ~ 1 -> -1 ~ 1
        noise = randn_tensor(condition_image.shape, device=condition_image.device, dtype=condition_image.dtype)
        condition_image = condition_image + noise * self.noise_aug_strength
        condition_image_latent = self.upcasting_vae(condition_image)
        return repeat(condition_image_latent, "b c h w -> b f c h w", f=num_frames)

    def training_step(self, batch, batch_idx):
        """Compute the combined training loss for one batch.

        The returned loss is the sum of:
          * ``loss_zero``: preconditioned denoising loss without render feedback;
          * ``recon_loss``: ``self.render_loss`` between ``self.lrm`` renders of the
            predicted clean latents and ``batch["recon_images"]``;
          * ``loss_cond``: the same denoising loss with the renders passed to ``mv_model``.
        """
        condition_image = batch["condition_image"]  # cond image b x c x h x w, in [0, 1]
        diffusion_images = batch["diffusion_images"]  # b x m x c x h x w, in [0, 1]
        (bs, m, _, _, _), device = diffusion_images.shape, diffusion_images.device

        image_embeddings = self._encode_cond(condition_image)
        condition_image_latent = self._encode_condition_latent(condition_image, m)
        added_time_ids = self._get_added_time_ids(bs, image_embeddings.dtype, device)

        # [0, 1] -> [-1, 1] before VAE encoding; _encode_image applies the scaling factor.
        diffusion_latents = self._encode_image(diffusion_images * 2 - 1, self.vae)

        # Per-sample log-normal noise level and preconditioning coefficients.
        rnd_normal = torch.randn([bs, 1, 1, 1, 1], device=device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        c_skip = 1 / (sigma**2 + 1)
        c_out = -sigma / (sigma**2 + 1) ** 0.5
        c_in = 1 / (sigma**2 + 1) ** 0.5
        c_noise = (sigma.log() / 4).reshape([bs])
        loss_weight = (sigma**2 + 1) / sigma**2

        # x0 = c_skip * x_noisy + c_out * F(c_in * x_noisy): c_skip multiplies the
        # *unscaled* noisy latent, matching the x0 estimate used when sampling, where the
        # scheduler steps on unscaled latents. Only the UNet input is scaled by c_in.
        noisy_latents = diffusion_latents + torch.randn_like(diffusion_latents) * sigma
        unet_noisy_latents = c_in * noisy_latents

        # Conditioning dropout (whole batch) for classifier-free guidance.
        if torch.rand(1) < self.hparams.cfg:
            image_embeddings = torch.zeros_like(image_embeddings)
            condition_image_latent = torch.zeros_like(condition_image_latent)

        # Noisy latents first, then the image latent: the same channel order as
        # _generate_images and the pretrained SVD pipeline.
        input_latents = torch.cat([unet_noisy_latents, condition_image_latent], dim=2)  # b x f x 2c x h x w

        model_pred_zero = self.mv_model(
            input_latents,
            c_noise,
            encoder_hidden_states=image_embeddings,
            added_time_ids=added_time_ids,
        )
        pred_x0_zero = c_out * model_pred_zero + c_skip * noisy_latents
        loss_zero = ((pred_x0_zero - diffusion_latents) ** 2 * loss_weight).mean()

        render_x1 = self.lrm(pred_x0_zero)
        # NOTE(review): render_loss is not defined in this class; presumably provided by BaseSystem.
        recon_loss = self.render_loss(render_x1, batch["recon_images"])

        # NOTE(review): passed positionally because a positional argument cannot follow
        # keywords; assumes mv_model's signature is
        # (latents, timestep, encoder_hidden_states, added_time_ids, render_cond) — confirm.
        model_pred_cond = self.mv_model(input_latents, c_noise, image_embeddings, added_time_ids, render_x1)
        pred_x0_cond = c_out * model_pred_cond + c_skip * noisy_latents
        loss_cond = ((pred_x0_cond - diffusion_latents) ** 2 * loss_weight).mean()

        self.log("train_loss_zero", loss_zero, prog_bar=True)
        self.log("recon_loss", recon_loss, prog_bar=True)
        self.log("train_loss_cond", loss_cond, prog_bar=True)

        return loss_zero + recon_loss + loss_cond

    def inference_step(self, batch, batch_idx, dataloader_idx=0, stage = "val"):
        """Generate frames, save a comparison image and log PSNR / SSIM / LPIPS for ``stage``.

        Returns:
            Path of the saved comparison image (``None`` on non-zero ranks).
        """
        images_pred = self._generate_images(batch)  # image in [0, 1]
        images_gt = batch['diffusion_images']
        image_fp = self._save_image(images_pred, images_gt, batch["prompt"], f"{batch_idx}_{self.global_rank}", stage=stage)

        images_gt = rearrange(images_gt, "b m c h w -> (b m) c h w")
        images_pred = rearrange(images_pred, "b m c h w -> (b m) c h w")
        # torchmetrics convention: metric(preds, target).
        psnr = self.psnr(images_pred, images_gt)
        ssim = self.ssim(images_pred, images_gt)
        lpips = self.lpips(images_pred, images_gt)
        self.log(f"{stage}_psnr", psnr, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False)
        self.log(f"{stage}_ssim", ssim, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False)
        self.log(f"{stage}_lpips", lpips, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False)
        return image_fp

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run ``inference_step`` with ``stage="val"``."""
        return self.inference_step(batch, batch_idx, dataloader_idx, stage = "val")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Run ``inference_step`` with ``stage="test"``.

        ``dataloader_idx`` defaults to 0 because Lightning omits it for a single dataloader.
        """
        return self.inference_step(batch, batch_idx, dataloader_idx, stage = "test")

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""

        # log images
        if self.hparams.report_to == "wandb":
            self._log_to_wandb("test")

    @torch.no_grad()
    def _encode_cond(self, image, do_classifier_free_guidance=False):
        """Return CLIP image embeddings ``b x 1 x D`` for ``image`` in [0, 1].

        The image is resized to 224x224 with antialiasing and normalized by the feature
        extractor. With ``do_classifier_free_guidance``, zero embeddings are prepended
        along the batch dimension (result ``2b x 1 x D``).
        """
        device, dtype = image.device, image.dtype
        image = image * 2.0 - 1.0
        image = _resize_with_antialiasing(image, (224, 224)).to(torch.half)
        image = (image + 1.0) / 2.0
        # Normalize the image for CLIP input
        image = self.feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)  # b x 1 x 768

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
        return image_embeddings

    @torch.no_grad()
    def upcasting_vae(self, x_input, vae = None):
        """Sample a VAE latent for ``x_input`` (not multiplied by the scaling factor).

        An fp16 VAE with ``config.force_upcast`` is moved to fp32 for the encode and
        back to fp16 afterwards. Defaults to ``self.vae``.
        """
        vae = vae if vae is not None else self.vae
        needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
        if needs_upcasting:
            vae.to(dtype=torch.float32)
        z = vae.encode(x_input).latent_dist.sample()
        if needs_upcasting:
            vae.to(dtype=torch.float16)
        return z

    @torch.no_grad()
    def _encode_image(self, x_input, vae, scale=True):
        """Encode a ``b x m x c x h x w`` batch frame-by-frame into ``b x m x c' x h' x w'`` latents.

        Latents are multiplied by ``vae.config.scaling_factor`` when ``scale`` is True
        and returned as float32.
        """
        b = x_input.shape[0]
        x_input = x_input.reshape(-1, *x_input.shape[-3:])
        z = self.upcasting_vae(x_input, vae)
        z = z.reshape(b, -1, *z.shape[-3:])
        # use the scaling factor from the vae config
        if scale:
            z = z * vae.config.scaling_factor
        return z.float()

    @torch.no_grad()
    def _generate_images(self, batch):
        """Sample frames for ``batch`` with the scheduler and per-frame guidance scales.

        Returns:
            Predicted frames ``b x m x c x h x w`` in [0, 1], where ``m`` and the
            resolution are taken from ``batch["diffusion_images"]``.
        """
        do_classifier_free_guidance = self.max_guidance_scale > 1.0

        condition_image = batch["condition_image"]  # cond image b x c x h x w
        diffusion_images = batch["diffusion_images"]  # b x m x c x h x w; only its shape is used

        (b, m, _, h, w), device = diffusion_images.shape, diffusion_images.device
        num_frames = m

        self.scheduler.set_timesteps(self.num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        image_embeddings = self._encode_cond(condition_image, do_classifier_free_guidance)

        condition_image_latent = self._encode_condition_latent(condition_image, num_frames)
        if do_classifier_free_guidance:
            # Unconditional half first, matching the zero embeddings prepended by _encode_cond.
            condition_image_latent = torch.cat([torch.zeros_like(condition_image_latent), condition_image_latent])

        added_time_ids = self._get_added_time_ids(b, image_embeddings.dtype, device)

        latents = torch.randn((b, num_frames, 4, h // 8, w // 8), device=device) * self.scheduler.init_noise_sigma

        # Guidance scale grows linearly from min to max across the frame axis.
        guidance_scale = torch.linspace(self.min_guidance_scale, self.max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device, latents.dtype)
        guidance_scale = rearrange(guidance_scale, "b m -> b m 1 1 1")

        self._num_timesteps = len(timesteps)
        if do_classifier_free_guidance:
            added_time_ids = torch.cat([added_time_ids] * 2)

        for t in timesteps:
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Concatenate image latents over the channel dimension
            latent_model_input = torch.cat([latent_model_input, condition_image_latent], dim=2)
            # predict the noise residual
            noise_pred = self.mv_model(
                latent_model_input,
                t,
                encoder_hidden_states=image_embeddings,
                added_time_ids=added_time_ids,
            )

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        frames = self.pipeline.decode_latents(latents, num_frames = m)  # b c m h w
        frames = rearrange(frames, "b c m h w -> (b m) c h w")
        images_pred = self.image_processor.postprocess(frames, output_type="pt")
        return rearrange(images_pred, "(b m) c h w -> b m c h w", b=b, m=m)

    @torch.no_grad()
    @rank_zero_only
    def _save_image(self, images_pred, images, prompt, batch_idx, stage="validation"):
        """Save a ground-truth-over-prediction grid PNG and the prompts to ``self.save_dir``.

        Each sample is one row with its frames side by side; ground-truth rows come first.
        Also logs the grid to TensorBoard when ``report_to == "tensorboard"``.

        Returns:
            Path of the saved PNG.
        """
        save_dir = self.save_dir
        images = rearrange(images, "b m c h w -> (b h) (m w) c")
        images_pred = rearrange(images_pred, "b m c h w -> (b h) (m w) c")
        full_image = torch.cat([images, images_pred], dim=0)
        # Clamp so out-of-range values do not wrap around in the uint8 cast.
        full_image = (full_image.clamp(0, 1) * 255).cpu().numpy().astype(np.uint8)
        with open(
            os.path.join(save_dir, f"{stage}_{self.global_step}_{batch_idx}.txt"), "w", encoding="utf-8"
        ) as f:
            f.write("\n".join(prompt))

        im = Image.fromarray(full_image)
        im_fp = os.path.join(
            save_dir,
            f"{stage}_{self.global_step}_{batch_idx}--{prompt[0].replace(' ', '_').replace('/', '_')}.png",
        )
        im.save(im_fp)

        # add image to logger
        if self.hparams.report_to == "tensorboard":
            log_image = torch.tensor(full_image / 255.).permute(2, 0, 1).float().cpu()
            self.logger.experiment.add_image(
                f"{stage}/{self.global_step}_{batch_idx}",
                log_image,
                global_step=self.global_step,
            )

        return im_fp

    @torch.no_grad()
    @rank_zero_only
    def _log_to_wandb(self, stage, output_images_fp: Optional[List[Any]] = None):
        """Log saved PNGs to Weights & Biases under the key ``stage``.

        Without ``output_images_fp``, every ``{stage}_{global_step}*.png`` in
        ``self.save_dir`` is logged (sorted by name), captioned by its file name.
        """
        import wandb

        captions, images = [], []
        if output_images_fp is None:
            # get images which start with {stage}_{self.global_step} from self.save_dir
            for f in sorted(os.listdir(self.save_dir)):
                if f.startswith(f"{stage}_{self.global_step}") and f.endswith(".png"):
                    captions.append(f)
                    images.append(os.path.join(self.save_dir, f))
        else:
            images = output_images_fp
            captions = [os.path.basename(fp) for fp in output_images_fp]

        self.logger.experiment.log(
            {
                stage: [
                    wandb.Image(im_fp, caption=caption)
                    for im_fp, caption in zip(images, captions)
                ]
            },
            step=self.global_step,
        )
