import os
from lightning.pytorch.core.optimizer import LightningOptimizer
import numpy as np
from PIL import Image
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from einops import rearrange, repeat

import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
import torchvision 
from torch import nn
import lightning
from lightning import LightningModule
import wandb
from lightning_utilities.core.rank_zero import rank_zero_only
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers import EulerDiscreteScheduler, UNetSpatioTemporalConditionModel, AutoencoderKLTemporalDecoder, StableVideoDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing
from src.models.ema import LitEma
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from ..base import BaseSystem
from safetensors import safe_open
from torch.cuda.amp import autocast
from torch import inf
from src.utils import RankedLogger
from src.models.unet.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from src.models.unet.models.position_map import get_position_map_from_depth
# Module-level logger. `rank_zero_only=True` is forwarded to src.utils.RankedLogger
# (presumably restricting output to the rank-0 process — confirm in src.utils).
log = RankedLogger(__name__, rank_zero_only=True)

# Allow TensorFloat-32 math for CUDA matmuls and cuDNN convolutions.
# NOTE: these are process-wide settings applied as a side effect of importing this module.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def get_position_map(depth, cam2world_matrix, intrinsics, resolution, sensor_size=1):
    """
    Unproject a per-pixel depth map into world-space 3D points.

    Pixel (row, col) is placed at camera-plane coordinates
    ``(x, y) = (col, row) * sensor_size / resolution`` (no half-pixel offset),
    lifted through an inverse pinhole model that includes skew, scaled by the
    depth value, and transformed by ``cam2world_matrix``.

    Args:
        depth: (B, F, 1, H, W) depth used as the camera-space z coordinate;
            H and W must both equal ``resolution``.
        cam2world_matrix: (B, F, 4, 4) camera-to-world transforms applied to
            homogeneous camera-space points. Its device is used for the pixel
            grid and its dtype for the output.
        intrinsics: (B, F, 3, 3) intrinsics (leading dims must broadcast to
            (B, F)); fx=[0,0], fy=[1,1], cx=[0,2], cy=[1,2], skew=[0,1], in the
            same units as the pixel grid described above.
        resolution: side length, in pixels, of the square depth map.
        sensor_size: extent spanned by ``resolution`` pixels (default 1).

    Returns:
        (B, F, 3, H, W) world-space points.

    Raises:
        ValueError: if ``depth`` is not (B, F, 1, H, W) or its spatial size
            differs from ``resolution``.
    """
    if depth.dim() != 5 or depth.shape[2] != 1:
        raise ValueError(
            f"depth must have shape (b, f, 1, h, w), got {tuple(depth.shape)}"
        )
    b, f, _, h, w = depth.shape
    if h != resolution or w != resolution:
        raise ValueError(
            f"depth spatial size {(h, w)} does not match resolution={resolution}"
        )
    z_cam = depth[:, :, 0]  # (b, f, h, w)

    def _intrinsic(row, col):
        # One intrinsic entry with two trailing axes so it broadcasts over (h, w).
        return intrinsics[..., row, col][..., None, None]

    fx, fy = _intrinsic(0, 0), _intrinsic(1, 1)
    cx, cy = _intrinsic(0, 2), _intrinsic(1, 2)
    sk = _intrinsic(0, 1)

    # Pixel grid: `rows` varies along H, `cols` along W; scaled into sensor units.
    coords = torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device)
    rows, cols = torch.meshgrid(coords, coords, indexing="ij")
    scale = 1.0 * sensor_size / resolution
    x_cam = cols * scale  # (h, w); broadcasts over (b, f) below
    y_cam = rows * scale

    # Inverse pinhole projection with skew, scaled by depth.
    x_lift = (x_cam - cx + cy * sk / fy - sk * y_cam / fy) / fx * z_cam
    y_lift = (y_cam - cy) / fy * z_cam

    # (b, f, h, w, 4) homogeneous camera-space points.
    cam_rel_points = torch.stack(
        (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1
    ).to(cam2world_matrix.dtype)

    world_rel_points = torch.einsum(
        "b f i j, b f h w j -> b f i h w", cam2world_matrix, cam_rel_points
    )
    # Drop the homogeneous coordinate.
    return world_rel_points[:, :, :3]

def get_grad_norm(
        parameters, norm_type: float = 2.0) -> torch.Tensor:
    r"""
    Return the total gradient norm of ``parameters`` without modifying them.

    Uses the same norm computation as ``torch.nn.utils.clip_grad_norm_`` but
    performs no clipping. The norm is computed over all gradients together, as
    if they were concatenated into a single vector.

    Args:
        parameters (Iterable[Tensor] or Tensor): tensors whose ``.grad`` is
            inspected; entries whose ``grad`` is ``None`` are skipped.
        norm_type (float or int): order of the p-norm. Can be ``inf`` for the
            infinity (max-abs) norm.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector), on
        the device of the first gradient, or a CPU ``tensor(0.)`` when no
        parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device
    # Norm of the per-tensor norms equals the norm of the concatenation for any
    # p-norm (including inf), so no special case is needed.
    per_grad_norms = torch.stack(
        [torch.linalg.vector_norm(g.detach(), norm_type).to(device) for g in grads]
    )
    return torch.linalg.vector_norm(per_grad_norms, norm_type)

def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor:
    r"""
    Compute the total gradient norm of ``parameters`` and optionally clip it in place.

    Adapted from ``torch.nn.utils.clip_grad_norm_``. The norm is computed over
    all gradients together, as if they were concatenated into a single vector.
    When ``clip_grad`` is true, gradients are scaled in place so the total norm
    is at most ``max_norm``.

    Args:
        parameters (Iterable[Tensor] or Tensor): tensors whose ``.grad`` is
            used; entries whose ``grad`` is ``None`` are skipped.
        max_norm (float or int): max norm of the gradients.
        norm_type (float or int): order of the p-norm. Can be ``inf`` for the
            infinity norm.
        error_if_nonfinite (bool): if True and ``clip_grad`` is True, raise when
            the total norm is ``nan``, ``inf`` or ``-inf``. Default: False.
        clip_grad (bool): if False, only compute and return the norm; gradients
            are left untouched and the non-finite check is skipped.

    Returns:
        Total norm of the parameter gradients before clipping (viewed as a
        single vector), or a CPU ``tensor(0.)`` when no parameter has a gradient.

    Raises:
        RuntimeError: if ``error_if_nonfinite`` and ``clip_grad`` are set and the
            total norm is non-finite.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device

    # Norm of per-tensor norms == norm of the concatenation, for every p (incl. inf).
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g.detach(), norm_type).to(device) for g in grads]),
        norm_type,
    )

    if not clip_grad:
        return total_norm

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')
    clip_coef = max_norm / (total_norm + 1e-6)
    # Multiplying by the clamped coefficient (a no-op when it is 1) avoids an
    # `if clip_coef < 1:` branch that would force a host/device synchronization.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))

    return total_norm

def _get_add_time_ids(
    unet,
    fps,
    motion_bucket_id,
    noise_aug_strength,
    dtype,
    batch_size,
):
    """Build the SVD "added time ids" conditioning tensor.

    Each row is ``[fps, motion_bucket_id, noise_aug_strength]``. The UNet must be
    configured so that ``add_embedding.linear_1.in_features`` equals
    ``3 * config.addition_time_embed_dim``.

    Args:
        unet: model exposing ``config.addition_time_embed_dim`` and
            ``add_embedding.linear_1.in_features``.
        fps: frame-rate conditioning value (same for every row).
        motion_bucket_id: motion-bucket conditioning value (same for every row).
        noise_aug_strength: a Python number, a one-element tensor (broadcast to
            every row), or a tensor with ``batch_size`` elements (one per row).
        dtype: dtype of the returned tensor.
        batch_size: number of rows.

    Returns:
        CPU tensor of shape (batch_size, 3) with dtype ``dtype``.

    Raises:
        ValueError: if the UNet's embedding width does not fit three time ids, or
            if a tensor ``noise_aug_strength`` has neither 1 nor ``batch_size``
            elements.
    """
    num_time_ids = 3  # fps, motion_bucket_id, noise_aug_strength
    passed_add_embed_dim = unet.config.addition_time_embed_dim * num_time_ids
    expected_add_embed_dim = unet.add_embedding.linear_1.in_features

    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.addition_time_embed_dim` and `unet.add_embedding.linear_1.in_features`."
        )

    if isinstance(noise_aug_strength, torch.Tensor):
        # float64 on CPU so the final cast to `dtype` rounds exactly as converting
        # a Python float would.
        strengths = noise_aug_strength.detach().reshape(-1).to(device="cpu", dtype=torch.float64)
        if strengths.numel() == 1:
            strengths = strengths.expand(batch_size)
        elif strengths.numel() != batch_size:
            raise ValueError(
                f"noise_aug_strength must have 1 or {batch_size} elements, got {strengths.numel()}."
            )
    else:
        strengths = torch.full((batch_size,), float(noise_aug_strength), dtype=torch.float64)

    add_time_ids = torch.empty((batch_size, num_time_ids), dtype=torch.float64)
    add_time_ids[:, 0] = fps
    add_time_ids[:, 1] = motion_bucket_id
    add_time_ids[:, 2] = strengths
    return add_time_ids.to(dtype)

def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Sample log-normal values by pushing uniform noise through the normal inverse CDF.

    ``log(result)`` is distributed as Normal(``loc``, ``scale``).
    """
    uniform = torch.rand(shape, dtype=dtype, device=device)
    # Squeeze into (1e-7, 1 - 1e-7) so icdf never evaluates at 0 or 1 (-inf / +inf).
    uniform = uniform * (1 - 2e-7) + 1e-7
    gaussian = torch.distributions.Normal(loc, scale).icdf(uniform)
    return torch.exp(gaussian)


class SVDSystem(BaseSystem):
    """Lightning system that trains a multi-view SVD UNet together with a reconstruction model.

    ``mv_model`` (built around a Stable Video Diffusion UNet) is trained with an
    EDM-preconditioned latent denoising loss. ``recon_model`` (initialised from a
    safetensors checkpoint) is trained with L1 + LPIPS losses between its
    ``images_pred`` output and the ground-truth frames. Validation/test sample
    frames with the pipeline's scheduler and classifier-free guidance.
    """

    def __init__(
        self,
        lr: float,
        mv_model: torch.nn.Module,
        recon_model: torch.nn.Module,
        base_model_id: str = "stabilityai/stable-video-diffusion-img2vid",
        variant: str = "fp16",
        cfg: float = 0.1,
        report_to: str = "wandb",
        ema_decay_rate: float = 0.9999,
        compile: bool = False,
        use_ema: bool = False,
        recon_ckpt_path: str = "pretrain/LGM/model_fp16.safetensors",
        **kwargs,
    ) -> None:
        """
        Args:
            lr: base learning rate (scaled per parameter group).
            mv_model: despite the annotation, called as ``mv_model(unet)`` with the
                loaded SVD UNet (e.g. a partial); its result is the trained model.
            recon_model: reconstruction module; weights are loaded non-strictly
                from ``recon_ckpt_path``.
            base_model_id: id/path passed to ``from_pretrained`` for the SVD
                pipeline and UNet.
            variant: weight variant passed to ``from_pretrained``.
            cfg: conditioning-dropout probability used during training.
            report_to: not read by this class; kept in hparams.
            ema_decay_rate: decay of the EMA copy of ``mv_model`` (if ``use_ema``).
            compile: wrap ``mv_model`` with ``torch.compile`` at fit setup.
            use_ema: track an EMA of ``mv_model`` and swap it in for inference.
            recon_ckpt_path: safetensors file holding ``recon_model`` weights.
            **kwargs: forwarded to ``BaseSystem``.
        """
        super().__init__(**kwargs)

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False, ignore=["mv_model", "recon_model"])

        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(base_model_id, subfolder="pipeline", variant=variant)
        # self.pipeline.enable_model_cpu_offload()

        self.scheduler = self.pipeline.scheduler
        self.image_encoder = self.pipeline.image_encoder
        self.vae = self.pipeline.vae
        self.feature_extractor = self.pipeline.feature_extractor
        self.image_processor = self.pipeline.image_processor
        # The CLIP image encoder and the VAE stay frozen.
        self.image_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)

        # NOTE(review): the pipeline's own UNet is not optimised here (mv_model wraps
        # the separately loaded UNet below); this call looks vestigial — confirm.
        self.pipeline.unet.requires_grad_(True)
        unet = UNetSpatioTemporalConditionModel.from_pretrained(base_model_id, subfolder="unet", variant=variant)

        self.mv_model = mv_model(unet)
        self.mv_model.set_use_memory_efficient_attention_xformers(True)
        self.mv_model.set_gradient_checkpointing(True)
        self.mv_model.train()
        self.recon_model = recon_model

        with safe_open(recon_ckpt_path, framework="pt", device="cpu") as f:
            tensors = {key: f.get_tensor(key) for key in f.keys()}
        missing_keys, unexpected_keys = self.recon_model.load_state_dict(tensors, strict=False)
        if missing_keys or unexpected_keys:
            # strict=False would otherwise hide architecture/checkpoint mismatches.
            log.warning(
                f"Non-strict load of {recon_ckpt_path}: "
                f"{len(missing_keys)} missing keys (first: {missing_keys[:10]}), "
                f"{len(unexpected_keys)} unexpected keys (first: {unexpected_keys[:10]})"
            )

        # metric objects for calculating and averaging accuracy across batches
        self.psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.ssim = StructuralSimilarityIndexMeasure()
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)

        self.use_ema = use_ema
        if use_ema:
            self.model_ema = LitEma(self.mv_model, decay=ema_decay_rate)

        # (parameters, lr_scale) pairs consumed by configure_optimizers.
        # NOTE: `.parameters()` returns single-use generators, evaluated lazily on
        # the first configure_optimizers call.
        self.trainable_parameters = [
            (self.mv_model.parameters(), 1.0),
            (self.recon_model.parameters(), 1.0),
        ]

        self.num_inference_steps = 25
        self.min_guidance_scale = 1.0
        self.max_guidance_scale = 3.0
        self.conditioning_dropout_prob = cfg

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap the EMA weights into ``mv_model`` (no-op unless ``use_ema``).

        The training weights are restored on exit, even if the body raises.
        """
        if self.use_ema:
            # The EMA tracks mv_model (see __init__); there is no `self.model`.
            self.model_ema.store(self.mv_model.parameters())
            self.model_ema.copy_to(self.mv_model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.mv_model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def setup(self, stage: str) -> None:
        """Optionally compile ``mv_model`` for fitting and pick an image-logging backend."""
        super().setup(stage)
        if self.hparams.compile and stage == "fit":
            self.mv_model = torch.compile(self.mv_model)

        self.log_image = None
        if isinstance(self.logger, lightning.pytorch.loggers.TensorBoardLogger):
            self.log_image = self.tensorboard_log_image
        elif isinstance(self.logger, lightning.pytorch.loggers.WandbLogger):
            self.log_image = self.wandb_log_image
            self.logger.watch(self.mv_model, log_graph=False)

    def configure_optimizers(self):
        """Build one AdamW param group per ``trainable_parameters`` entry, with lr = hparams.lr * scale."""
        param_groups = [
            {"params": params, "lr": self.hparams.lr * lr_scale}
            for params, lr_scale in self.trainable_parameters
        ]
        return torch.optim.AdamW(param_groups)

    def forward(self, latents, timestep, encoder_hidden_states, added_time_ids, cond, cameras=None):
        """Run ``mv_model`` on UNet inputs (noisy latents concatenated with condition latents)."""
        return self.mv_model(latents, timestep, encoder_hidden_states, added_time_ids, cond, cameras=cameras)

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA of ``mv_model`` after each optimisation step (if enabled)."""
        if self.use_ema:
            self.model_ema(self.mv_model)

    def training_step(self, batch, batch_idx):
        """Compute the diffusion loss plus the reconstruction loss for one batch.

        Reads ``condition_image`` (b, c, h, w), ``diffusion_images`` (b, m, c, h, w)
        in [-1, 1], ``intrinsics`` and ``c2w`` from ``batch``; the whole batch is
        also passed to ``recon_model``.

        Returns:
            ``loss + recon_loss``: the EDM-weighted latent MSE plus the average of
            the reconstruction losses from the ground-truth frames and from the
            decoded denoised prediction.
        """
        condition_image = batch["condition_image"]  # cond image b x c x h x w
        diffusion_images = batch["diffusion_images"]  # b x m x c x h x w, in [-1, 1]
        bsz, m, c, h, w = diffusion_images.shape
        dtype = diffusion_images.dtype
        latents = self.tensor_to_vae_latent(diffusion_images, self.vae)

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)

        # Noise-augment the condition image; the same strength is fed to the UNet
        # through added_time_ids.
        cond_sigmas = rand_log_normal(shape=[bsz, 1, 1, 1], loc=-3.0, scale=0.5).to(latents.device)
        noise_aug_strength = cond_sigmas[0]  # TODO: support batch > 1 (only sample 0's strength is used)
        condition_image = torch.randn_like(condition_image) * cond_sigmas + condition_image

        conditional_latents = self.tensor_to_vae_latent(condition_image, self.vae)
        # Condition latents are used without the VAE scaling factor.
        conditional_latents = conditional_latents / self.vae.config.scaling_factor

        # Per-sample noise level: log(sigma) ~ N(1.0, 1.6^2).
        sigmas = rand_log_normal(shape=[bsz, 1, 1, 1, 1], loc=1.0, scale=1.6).to(latents.device)
        # Forward diffusion: add noise with magnitude sigma.
        noisy_latents = latents + noise * sigmas
        # EDM noise conditioning c_noise = 0.25 * ln(sigma), one value per sample.
        timesteps = (0.25 * sigmas.log()).reshape(bsz)

        # EDM input scaling c_in = 1 / sqrt(sigma^2 + 1).
        inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)

        # CLIP image embedding of the (noise-augmented) condition image.
        encoder_hidden_states = self._encode_cond(condition_image)

        added_time_ids = _get_add_time_ids(
            self.mv_model.unet,
            7 - 1,  # fps, fixed
            127,  # motion_bucket_id, fixed
            noise_aug_strength,  # == cond_sigmas[0]
            encoder_hidden_states.dtype,
            bsz,
        ).to(latents.device)

        # Conditioning dropout to support classifier-free guidance during inference
        # (section 3.2.1 of https://arxiv.org/abs/2211.09800). With p the dropout
        # probability: random_p < 2p drops the CLIP embedding and
        # p <= random_p < 3p drops the condition latents, so [p, 2p) drops both.
        if self.conditioning_dropout_prob is not None:
            drop_p = self.conditioning_dropout_prob
            random_p = torch.rand(bsz, device=latents.device)
            prompt_mask = (random_p < 2 * drop_p).reshape(bsz, 1, 1)
            encoder_hidden_states = torch.where(
                prompt_mask, torch.zeros_like(encoder_hidden_states), encoder_hidden_states
            )
            image_mask_dtype = conditional_latents.dtype
            image_mask = 1 - (
                (random_p >= drop_p).to(image_mask_dtype)
                * (random_p < 3 * drop_p).to(image_mask_dtype)
            )
            conditional_latents = image_mask.reshape(bsz, 1, 1, 1) * conditional_latents

        # Repeat the condition latents for every frame and concatenate along channels.
        conditional_latents = conditional_latents.unsqueeze(1).repeat(1, noisy_latents.shape[1], 1, 1, 1)
        inp_noisy_latents = torch.cat([inp_noisy_latents, conditional_latents], dim=2)

        # EDM preconditioning (https://arxiv.org/abs/2206.00364): x0 = c_skip * x + c_out * F(x).
        target = latents
        c_out = -sigmas / ((sigmas**2 + 1) ** 0.5)
        c_skip = 1 / (sigmas**2 + 1)
        # Loss weight (sigma^2 + 1) / sigma^2.
        weighing = (1 + sigmas ** 2) * (sigmas**-2.0)

        model_pred_0 = self.forward(
            inp_noisy_latents.to(dtype), timesteps, encoder_hidden_states, added_time_ids=added_time_ids,
            cond=None,
            cameras={
                "intrinsics": batch["intrinsics"],
                "extrinsics": batch["c2w"]
            }
        )  # add depth

        pred_x0_0 = model_pred_0 * c_out + c_skip * noisy_latents

        loss = torch.mean((weighing.float() * (pred_x0_0.float() - target.float()) ** 2))

        # Decode the denoised prediction without gradients; recon_model then
        # consumes it as an input.
        with torch.no_grad():
            pred_images = self.pipeline.decode_latents(pred_x0_0.to(dtype), num_frames=diffusion_images.shape[1])

        pred_images = rearrange(pred_images, "b c m h w -> b m c h w").to(dtype)
        render_x1 = self.recon_model(pred_images, batch, timesteps)
        recon_images = diffusion_images * 0.5 + 0.5
        render_x0 = self.recon_model(diffusion_images, batch, 0)

        recon_loss0 = torch.abs(render_x0['images_pred'] - recon_images).mean()
        recon_loss1 = torch.abs(render_x1['images_pred'] - recon_images).mean()

        # LPIPS in full precision: autocast is disabled and inputs cast to float32.
        # (float32 is not a supported CUDA autocast dtype, so the previous
        # `autocast(dtype=torch.float32)` only disabled autocast with a warning.)
        with autocast(enabled=False):
            recon_loss0 = recon_loss0 + self.lpips(
                render_x0['images_pred'].reshape(-1, 3, h, w).float(),
                recon_images.reshape(-1, 3, h, w).float()
            )
            recon_loss1 = recon_loss1 + self.lpips(
                render_x1['images_pred'].reshape(-1, 3, h, w).float(),
                recon_images.reshape(-1, 3, h, w).float()
            )
        recon_loss = (recon_loss0 + recon_loss1) / 2

        # Experimental (disabled): condition mv_model on the reconstruction.
        # if torch.rand(1) > 0.5:
        #     # cond = torch.concat([render_x1['images_pred'], render_x1['depths_pred']], axis=2).detach()
        #     cond = render_x1['images_pred']
        #     # depth = render_x1['depths_pred']
        #     # position_map = get_position_map(depth=depth, cam2world_matrix=batch['c2w'], intrinsics=batch['intrinsics'], resolution=depth.shape[-1])
        #     # cond = position_map
        #     model_pred_cond = self.forward(inp_noisy_latents.to(dtype), timesteps, encoder_hidden_states, added_time_ids=added_time_ids, cond=cond, cameras={
        #         "intrinsics": batch["intrinsics"],
        #         "extrinsics": batch["c2w"]
        #     })
        #     # TODO: add render_x1
        #     pred_x0_1 = model_pred_cond * c_out + c_skip * noisy_latents
        #     loss = torch.mean((weighing.float()*(pred_x0_1.float() - target.float()) ** 2))

        self.log("train_loss", loss, prog_bar=True)
        self.log("recon_loss", recon_loss, prog_bar=True)
        self.log("recon_loss0", recon_loss0, prog_bar=True)
        self.log("recon_loss1", recon_loss1, prog_bar=True)
        return loss + recon_loss

    def inference_step(self, batch, batch_idx, dataloader_idx=0, stage = "val"):
        """Sample frames (with EMA weights if enabled), save them, and log image metrics.

        Logs PSNR/SSIM/LPIPS of the generated frames and of the reconstruction
        (``cond``) against the ground truth, under ``{stage}_*`` keys.

        Returns:
            Path of the saved PNG (None on non-zero ranks, since ``_save_image``
            is rank-zero only).
        """
        with self.ema_scope():
            images_pred, cond = self._generate_images(batch)  # images in [0, 1], b x m x 3 x h x w
        images_gt = batch['diffusion_images'] * 0.5 + 0.5
        image_fp = self._save_image(images_pred, images_gt, cond, batch["prompt"], f"{batch_idx}_{self.global_rank}", stage=stage)

        images_gt = rearrange(images_gt, "b m c h w -> (b m) c h w")
        images_pred = rearrange(images_pred, "b m c h w -> (b m) c h w")
        cond = rearrange(cond, "b m c h w -> (b m) c h w")

        # Metrics in full precision with autocast disabled.
        with autocast(enabled=False):
            psnr = self.psnr(images_gt.float(), images_pred.float())
            ssim = self.ssim(images_gt.float(), images_pred.float())
            lpips = self.lpips(images_gt.float(), images_pred.float())

            cond_psnr = self.psnr(images_gt.float(), cond.float())
            cond_ssim = self.ssim(images_gt.float(), cond.float())
            cond_lpips = self.lpips(images_gt.float(), cond.float())

        self.log(f"{stage}_psnr", psnr, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True)
        self.log(f"{stage}_ssim", ssim, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True)
        self.log(f"{stage}_lpips", lpips, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True)

        self.log(f"{stage}_cond_psnr", cond_psnr, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True)
        self.log(f"{stage}_cond_ssim", cond_ssim, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True)
        self.log(f"{stage}_cond_lpips", cond_lpips, on_step=False, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True)
        return image_fp

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run ``inference_step`` with stage ``"val"``."""
        return self.inference_step(batch, batch_idx, dataloader_idx, stage = "val")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Run ``inference_step`` with stage ``"test"``.

        ``dataloader_idx`` defaults to 0 so a single test dataloader, where
        Lightning calls ``test_step(batch, batch_idx)``, also works.
        """
        return self.inference_step(batch, batch_idx, dataloader_idx, stage = "test")

    @torch.no_grad()
    def _encode_cond(self, image, do_classifier_free_guidance=False):
        """CLIP-embed a [-1, 1] condition image into (b, 1, D) image embeddings.

        With ``do_classifier_free_guidance`` the result is prefixed by an all-zero
        (unconditional) copy, giving (2b, 1, D).
        """
        device, dtype = image.device, image.dtype
        image = image.to(torch.float32)
        image = _resize_with_antialiasing(image, (224, 224))
        image = (image + 1.0) / 2.0
        # Normalize the image with for CLIP input
        image = self.feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(device).to(dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)  # b x 1 x D

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
        return image_embeddings

    @torch.no_grad()
    def tensor_to_vae_latent(self, t, vae, needs_upcasting = True, micro_bs = 1):
        """Encode images (b, c, h, w) or videos (b, f, c, h, w) into scaled VAE latents.

        Args:
            t: input tensor; 4-D inputs are treated as single-frame videos.
            vae: VAE whose ``encode(...).latent_dist`` is sampled.
            needs_upcasting: encode in float32 and temporarily cast ``vae`` to
                float32; its original dtype is restored afterwards, even on error.
            micro_bs: number of chunks the flattened (b*f) batch is split into
                (``torch.chunk`` semantics, not a chunk size); 1 encodes in one call.

        Returns:
            Latents multiplied by ``vae.config.scaling_factor``, shaped like the
            input with the VAE latent channels/resolution.
        """
        ori_shape_len = len(t.shape)
        model_dtype = next(vae.parameters()).dtype
        if ori_shape_len == 4:
            t = t.unsqueeze(1)
        video_length = t.shape[1]
        t = rearrange(t, "b f c h w -> (b f) c h w")
        if needs_upcasting:
            vae.to(dtype=torch.float32)
            t = t.to(torch.float32)
        try:
            chunk_outs = [vae.encode(t_chunk).latent_dist.sample() for t_chunk in t.chunk(micro_bs, dim=0)]
        finally:
            # Never leave the shared VAE in float32 if encoding fails.
            if needs_upcasting:
                vae.to(dtype=model_dtype)
        latents = torch.cat(chunk_outs, dim=0)
        if needs_upcasting:
            latents = latents.to(model_dtype)

        latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
        if ori_shape_len == 4:
            latents = latents.squeeze(1)
        latents = latents * vae.config.scaling_factor
        return latents

    @torch.no_grad()
    def _generate_images(self, batch):
        """Sample frames with classifier-free guidance, running ``recon_model`` at every step.

        Returns:
            ``(images_pred, cond)``: post-processed frames (b, m, c, h, w) and the
            first three channels of ``recon_model``'s ``images_pred`` from the last
            denoising step.
        """
        do_classifier_free_guidance = self.max_guidance_scale > 1.0

        condition_image = batch["condition_image"]  # cond image b x c x h x w
        diffusion_images = batch["diffusion_images"]  # b x m x c x h x w

        (b, m, c, h, w), device = diffusion_images.shape, diffusion_images.device
        num_frames = m
        dtype = diffusion_images.dtype
        self.scheduler.set_timesteps(self.num_inference_steps, device=device)

        timesteps = self.scheduler.timesteps
        image_embeddings = self._encode_cond(condition_image, do_classifier_free_guidance)

        # Fixed noise augmentation of 0.02, matching the strength in added_time_ids below.
        condition_image = condition_image + torch.randn_like(condition_image) * 0.02
        cond_image_latent = self.tensor_to_vae_latent(condition_image, self.vae)
        cond_image_latent = cond_image_latent / self.vae.config.scaling_factor

        cond_image_latent = repeat(cond_image_latent, "b c h w -> b f c h w", f=num_frames)
        if do_classifier_free_guidance:
            cond_image_latent = torch.cat([torch.zeros_like(cond_image_latent), cond_image_latent])

        added_time_ids = _get_add_time_ids(
            self.mv_model.unet,
            7 - 1,
            127,
            0.02,
            dtype,
            b,
        ).to(device)

        latents = torch.randn((b, num_frames, 4, h // 8, w // 8), device=device, dtype=dtype) * self.scheduler.init_noise_sigma

        # Guidance scale ramps linearly across frames.
        guidance_scale = torch.linspace(self.min_guidance_scale, self.max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device)
        guidance_scale = rearrange(guidance_scale, "b m -> b m 1 1 1")

        self._num_timesteps = len(timesteps)
        added_time_ids = torch.cat([added_time_ids] * 2) if do_classifier_free_guidance else added_time_ids

        if do_classifier_free_guidance:
            cameras = {
                "intrinsics": torch.cat([batch["intrinsics"]] * 2),
                "extrinsics": torch.cat([batch["c2w"]] * 2),
            }
        else:
            cameras = {
                "intrinsics": batch["intrinsics"],
                "extrinsics": batch["c2w"],
            }

        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Concatenate image_latents over channels dimension
            latent_model_input = torch.cat([latent_model_input, cond_image_latent], dim=2)

            # predict the noise residual
            noise_pred = self.forward(
                latent_model_input.to(dtype),
                t,
                encoder_hidden_states=image_embeddings,
                added_time_ids=added_time_ids,
                cond=None,
                cameras=cameras,
            )

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            output = self.scheduler.step(noise_pred, t, latents)
            pred_x0_latent = output.pred_original_sample.to(dtype)
            pred_x0 = self.pipeline.decode_latents(pred_x0_latent, num_frames = m)
            pred_x0 = rearrange(pred_x0, "b c m h w -> b m c h w").to(dtype)

            recon_results = self.recon_model(pred_x0, batch, t)
            # Experimental (disabled): feed the reconstruction / position map back as `cond`.
            # depth = recon_results['depths_pred']
            # position_map = get_position_map(depth=depth, cam2world_matrix=batch['c2w'], intrinsics=batch['intrinsics'], resolution=depth.shape[-1])
            # cond = position_map.detach().to(dtype)
            # compute the previous noisy sample x_t -> x_t-1
            latents = output.prev_sample
        cond = recon_results['images_pred'].to(dtype)
        latents = latents.to(dtype)
        frames = self.pipeline.decode_latents(latents, num_frames = m)  # b c m h w
        frames = rearrange(frames, "b c m h w -> (b m) c h w")
        images_pred = self.image_processor.postprocess(frames, output_type="pt")
        images_pred = rearrange(images_pred, "(b m) c h w -> b m c h w", b=b, m=m)
        return images_pred, cond[:, :, :3]

    @torch.no_grad()
    @rank_zero_only
    def _save_image(self, images_pred, images, cond, prompt, batch_idx, stage="validation"):
        """Log and save a ground-truth / reconstruction / prediction comparison grid.

        Writes ``{stage}_{global_step}_{batch_idx}.txt`` (prompts) and a PNG to
        ``self.save_dir``; rank zero only.

        Returns:
            Path of the saved PNG.
        """
        save_dir = self.save_dir
        if self.log_image is not None:
            _images = rearrange(images, "b m c h w -> 1 c (b h) (m w)")
            _cond = rearrange(cond, "b m c h w -> 1 c (b h) (m w)")
            _images_pred = rearrange(images_pred, "b m c h w ->1 c (b h) (m w)")
            _full_image = torch.concat([_images, _cond, _images_pred], axis=2)
            grid = torchvision.utils.make_grid(_full_image, nrow=2)
            self.log_image(
                tag="{}_images/{}".format(stage, batch_idx),
                image_tensor=grid,
            )
        images = rearrange(images, "b m c h w -> (b h) (m w) c")
        cond = rearrange(cond, "b m c h w -> (b h) (m w) c")
        images_pred = rearrange(images_pred, "b m c h w -> (b h) (m w) c")
        full_image = torch.concat([images, cond, images_pred], axis=0)
        # Clamp before the uint8 cast (out-of-range values, e.g. from recon_model,
        # would otherwise overflow); .float() also lets bf16 tensors reach numpy.
        full_image = (full_image.float().clamp(0, 1) * 255).cpu().numpy().astype(np.uint8)
        with open(
            os.path.join(save_dir, f"{stage}_{self.global_step}_{batch_idx}.txt"), "w", encoding="utf-8"
        ) as f:
            f.write("\n".join(prompt))

        im = Image.fromarray(full_image)
        im_fp = os.path.join(
            save_dir,
            f"{stage}_{self.global_step}_{batch_idx}--{prompt[0].replace(' ', '_').replace('/', '_')}.png",
        )
        im.save(im_fp)

        return im_fp

    def tensorboard_log_image(self, tag: str, image_tensor):
        """Log ``image_tensor`` under ``tag`` to TensorBoard at the current global step."""
        self.logger.experiment.add_image(
            tag,
            image_tensor,
            self.trainer.global_step,
        )

    def wandb_log_image(self, tag: str, image_tensor):
        """Log ``image_tensor`` under ``tag`` to Weights & Biases at the current global step."""
        image_dict = {
            tag: wandb.Image(image_tensor),
        }
        self.logger.experiment.log(
            image_dict,
            step=self.trainer.global_step,
        )

