"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager, nullcontext
from functools import partial
import itertools
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
from omegaconf import ListConfig

from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, make_gamma_schedule, extract_into_tensor, noise_like
from dydiff.ddim_dydiff import DDIMSampler


# Lookup table from a conditioning key ('concat', 'crossattn', 'adm') to the
# name string associated with it.
__conditioning_keys__ = dict(
    concat='c_concat',
    crossattn='c_crossattn',
    adm='y',
)


def disabled_train(self, mode=True):
    """No-op stand-in for ``model.train``.

    Assign this over a module's ``train`` attribute so that later attempts to
    switch between train and eval mode have no effect: ``mode`` is ignored and
    the object is handed back unchanged.
    """
    return self


def uniform_on_device(r1, r2, shape, device):
    """Draw uniform random values between ``r2`` and ``r1`` on ``device``.

    A ``torch.rand`` sample of the requested ``shape`` (values in [0, 1)) is
    scaled by ``r1 - r2`` and shifted by ``r2``.
    """
    span = r1 - r2
    unit_noise = torch.rand(*shape, device=device)
    return span * unit_noise + r2


class DynamicalDPM(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 gamma_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=[],
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 first_stage_key_prev="prev",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 linear_start_gamma=1e-4,
                 linear_end_gamma=2e-2,
                 cosine_s_gamma=8e-3,
                 given_gammas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.,
                 make_it_fit=False,
                 ucg_training=None,
                 reset_ema=False,
                 reset_num_ema_updates=False,
                 video_mode=False,  # newly added for video prediction
                 ensemble=False,  # newly added for ensembling with raw diffusion
                 frame_weighting=False,  # newly added for frame weighting, version 1 only
                 new_prev_ema=False,  # newly added for more stable ema calculation, if True, version 3,4
                 rescale_ema=False,  # newly added for rescaling ema, if True, version 4
                 input_length=None, # newly added
                 use_x_ema=False,  # newly added for using x_ema
                 ):
        """Build the wrapped denoiser, its optional EMA copy and the noise schedules.

        Args:
            unet_config: config handed to ``DiffusionWrapper`` together with
                ``conditioning_key``.
            timesteps, beta_schedule, gamma_schedule, linear_start, linear_end,
                cosine_s, given_betas, linear_start_gamma, linear_end_gamma,
                cosine_s_gamma, given_gammas: forwarded to
                ``register_schedule``.
            loss_type: stored as ``self.loss_type``; read by ``get_loss``.
            ckpt_path: when not ``None``, weights are restored through
                ``init_from_ckpt`` using ``ignore_keys`` and
                ``load_only_unet``.
            use_ema: keep a ``LitEma`` copy of ``self.model``.
            reset_ema: rebuild the EMA from the freshly loaded model weights;
                requires ``ckpt_path`` and ``use_ema``.
            reset_num_ema_updates: call ``reset_num_updates`` on the EMA;
                requires ``use_ema``.
            conditioning_key: must be a string starting with
                ``'concat-video-mask'``.
            parameterization: one of ``"eps"``, ``"x0"``, ``"v"``,
                ``"v_standard"``; ``"v"`` cannot be combined with ``ensemble``.
            learn_logvar: register ``logvar`` (filled with ``logvar_init``, one
                entry per timestep) as a trainable parameter instead of a
                buffer.
            ucg_training: optional dict; when non-empty a NumPy ``RandomState``
                is created as ``self.ucg_prng``.

        The remaining arguments are stored on the instance under the same name.

        Raises:
            NotImplementedError: if ``conditioning_key`` is ``None`` or does
                not start with ``'concat-video-mask'``.
        """
        super().__init__()
        assert parameterization in ["eps", "x0", "v", "v_standard"], 'currently only supporting "eps", "x0, "v" and "v_standard"'

        if ensemble:
            assert parameterization != "v", "Ensemble mode not supported with v parameterization"

        self.parameterization = parameterization
        self.ensemble = ensemble
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.first_stage_key_prev = first_stage_key_prev
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        # print(f"UNet config: {unet_config}")

        # Guard against the default ``None`` explicitly: calling ``.startswith``
        # on it would surface as an unrelated AttributeError.
        if conditioning_key is None or not conditioning_key.startswith('concat-video-mask'):
            raise NotImplementedError(
                f"conditioning_key must start with 'concat-video-mask', got {conditioning_key!r}"
            )

        self.model = DiffusionWrapper(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        # Must be set before init_from_ckpt, which reads it.
        self.make_it_fit = make_it_fit
        if reset_ema: assert exists(ckpt_path)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
            if reset_ema:
                assert self.use_ema
                print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
                self.model_ema = LitEma(self.model)
        if reset_num_ema_updates:
            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
            assert self.use_ema
            self.model_ema.reset_num_updates()

        self.register_schedule(given_betas=given_betas, given_gammas=given_gammas,
                               beta_schedule=beta_schedule, gamma_schedule=gamma_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s,
                               linear_start_gamma=linear_start_gamma, linear_end_gamma=linear_end_gamma, cosine_s_gamma=cosine_s_gamma)

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            # Wrap the freshly built tensor; ``self.logvar`` does not exist yet
            # at this point, so it cannot be used as the source.
            self.logvar = nn.Parameter(logvar, requires_grad=True)
        else:
            self.register_buffer('logvar', logvar)

        self.ucg_training = ucg_training or dict()
        if self.ucg_training:
            self.ucg_prng = np.random.RandomState()

        self.video_mode = video_mode
        self.frame_weighting = frame_weighting
        self.new_prev_ema = new_prev_ema
        self.rescale_ema = rescale_ema
        self.input_length = input_length
        self.use_x_ema = use_x_ema

    def register_schedule(self, given_betas=None, given_gammas=None, beta_schedule="linear", gamma_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, linear_start_gamma=1e-4, linear_end_gamma=2e-2, linear_cumprod_s_gamma=8e-3,
                          cosine_s_gamma=None):
        """Compute the beta/gamma schedules and register the derived buffers.

        Args:
            given_betas: precomputed betas; when ``None`` they are produced by
                ``make_beta_schedule(beta_schedule, timesteps, ...)``.
            given_gammas: precomputed gammas; when ``None`` they are produced
                by ``make_gamma_schedule(gamma_schedule, timesteps, ...)``.
            linear_start, linear_end, cosine_s: passed to
                ``make_beta_schedule``.
            linear_start_gamma, linear_end_gamma, linear_cumprod_s_gamma:
                passed to ``make_gamma_schedule`` (the last one as
                ``linear_cumprod_s``).
            cosine_s_gamma: alias for ``linear_cumprod_s_gamma``. ``__init__``
                forwards its ``cosine_s_gamma`` argument under this keyword
                name, which would otherwise be rejected with a ``TypeError``.
                When not ``None`` it takes precedence.

        Sets ``num_timesteps``, ``linear_start`` and ``linear_end`` and
        registers float32 buffers for ``betas``, ``gammas``, the cumulative
        products of ``alphas = 1 - betas`` and ``thetas = 1 - gammas`` (plus
        their one-step-shifted "prev" versions) and several square-root terms
        derived from them.
        """
        if cosine_s_gamma is not None:
            linear_cumprod_s_gamma = cosine_s_gamma

        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)

        if exists(given_gammas):
            gammas = given_gammas
        else:
            gammas = make_gamma_schedule(gamma_schedule, timesteps, linear_start=linear_start_gamma, linear_end=linear_end_gamma,
                                        linear_cumprod_s=linear_cumprod_s_gamma)

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # "prev" arrays are the cumulative products shifted by one step, with 1 at t=0.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        thetas = 1. - gammas
        thetas_cumprod = np.cumprod(thetas, axis=0)
        thetas_cumprod_prev = np.append(1., thetas_cumprod[:-1])

        # The number of timesteps is taken from the betas actually used, which
        # may differ from the ``timesteps`` argument when ``given_betas`` is set.
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        self.register_buffer('gammas', to_torch(gammas))
        self.register_buffer('thetas_cumprod', to_torch(thetas_cumprod))
        self.register_buffer('thetas_cumprod_prev', to_torch(thetas_cumprod_prev))

        # calculations for prior q(x_t | x_0) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        # self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        self.register_buffer('sqrt_thetas_cumprod', to_torch(np.sqrt(thetas_cumprod)))
        self.register_buffer('sqrt_one_minus_thetas_cumprod', to_torch(np.sqrt(1. - thetas_cumprod)))

        # self.register_buffer('one_minus_thetas_cumprod', to_torch(1. - thetas_cumprod))
        # self.register_buffer('recip_thetas_cumprod', to_torch(1. / thetas_cumprod))
        # self.register_buffer('recipm1_thetas_cumprod', to_torch(1. / thetas_cumprod - 1))
        # self.register_buffer('recip_thetas_cumprod_sqrt_recip_alphas_cumprod', to_torch(1. / thetas_cumprod * np.sqrt(1. / alphas_cumprod)))
        # self.register_buffer('recip_thetas_cumprod_sqrt_recipm1_alphas_cumprod', to_torch(1. / thetas_cumprod * np.sqrt(1. / alphas_cumprod - 1)))

        # self.register_buffer('prior_variance', to_torch(1. - thetas ** 2 * alphas - (1 - thetas ** 2) * alphas_cumprod))


        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
        #         1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alphas_cumprod_tm1) + alphas_t / beta_t)
        # posterior_variance = (1 - self.v_posterior) * self.prior_variance * (1. - alphas_cumprod_prev) / (
        #     1. - alphas_cumprod) + self.v_posterior * self.prior_variance
        
        # self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        # self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))

        # coefficient of x_0
        # self.register_buffer('posterior_mean_coef1', to_torch(
        #     (1. - thetas ** 2 * alphas - (1 - thetas ** 2) * alphas_cumprod) * 
        #     np.sqrt(alphas_cumprod_prev) * thetas_cumprod_prev / (1. - alphas_cumprod))
        # )
        # # coefficient of x_t
        # self.register_buffer('posterior_mean_coef2', to_torch(
        #     (1. - alphas_cumprod_prev) * np.sqrt(alphas) * thetas / (1. - alphas_cumprod))
        # )
        # # coefficient of x_prev
        # self.register_buffer('posterior_mean_coef3', to_torch(
        #     (1. - thetas ** 2 * alphas - (1 - thetas ** 2) * alphas_cumprod) * 
        #     np.sqrt(alphas_cumprod_prev) * (1. - thetas_cumprod_prev) / (1. - alphas_cumprod) - 
        #     (1. - alphas_cumprod_prev) * np.sqrt(alphas_cumprod) * thetas * (1. - thetas) / (1. - alphas_cumprod))
        # )

        # self.register_buffer('v_standard_to_start_coef_1', to_torch(
        #     1. / (1. - (1. - thetas_cumprod) * alphas_cumprod)
        # ))
        # self.register_buffer('v_standard_to_start_coef_2', to_torch(
        #     (1. - thetas_cumprod) * alphas_cumprod / (1. - (1. - thetas_cumprod) * alphas_cumprod)
        # ))

        # self.register_buffer('v_standard_to_eps_interpolate_coef', to_torch(
        #     self.sqrt_alphas_cumprod * thetas_cumprod
        # ))
        # self.register_buffer('v_standard_to_eps_coef_1', to_torch(
        #     1. / (1. - (1. - thetas_cumprod) * alphas_cumprod)
        # ))
        # self.register_buffer('v_standard_to_eps_coef_2', to_torch(
        #     self.sqrt_alphas_cumprod * self.sqrt_one_minus_alphas_cumprod * (1. - thetas_cumprod) / (1. - (1. - thetas_cumprod) * alphas_cumprod)
        # ))

        # if self.parameterization == "eps":
        #     lvlb_weights = self.betas ** 2 / (
        #             2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        # elif self.parameterization == "x0":
        #     lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        # elif self.parameterization == "v":
        #     lvlb_weights = torch.ones_like(self.betas ** 2 / (
        #             2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
        # elif self.parameterization == "v_standard":
        #     lvlb_weights = torch.ones_like(self.betas ** 2 / (
        #             2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
        # else:
        #     raise NotImplementedError("mu not supported")
        # lvlb_weights[0] = lvlb_weights[1]
        # self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        # assert not torch.isnan(self.lvlb_weights).all()

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    @torch.no_grad()
    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        if self.make_it_fit:
            n_params = len([name for name, _ in
                            itertools.chain(self.named_parameters(),
                                            self.named_buffers())])
            for name, param in tqdm(
                    itertools.chain(self.named_parameters(),
                                    self.named_buffers()),
                    desc="Fitting old weights to new weights",
                    total=n_params
            ):
                if not name in sd:
                    continue
                old_shape = sd[name].shape
                new_shape = param.shape
                assert len(old_shape) == len(new_shape)
                if len(new_shape) > 2:
                    # we only modify first two axes
                    assert new_shape[2:] == old_shape[2:]
                # assumes first axis corresponds to output dim
                if not new_shape == old_shape:
                    new_param = param.clone()
                    old_param = sd[name]
                    if len(new_shape) == 1:
                        for i in range(new_param.shape[0]):
                            new_param[i] = old_param[i % old_shape[0]]
                    elif len(new_shape) >= 2:
                        for i in range(new_param.shape[0]):
                            for j in range(new_param.shape[1]):
                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]

                        n_used_old = torch.ones(old_shape[1])
                        for j in range(new_param.shape[1]):
                            n_used_old[j % old_shape[1]] += 1
                        n_used_new = torch.zeros(new_shape[1])
                        for j in range(new_param.shape[1]):
                            n_used_new[j] = n_used_old[j % old_shape[1]]

                        n_used_new = n_used_new[None, :]
                        while len(n_used_new.shape) < len(new_shape):
                            n_used_new = n_used_new.unsqueeze(-1)
                        new_param /= n_used_new

                    sd[name] = new_param

        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys:\n {missing}")
        if len(unexpected) > 0:
            print(f"\nUnexpected Keys:\n {unexpected}")
    
    ### RAW USAGES ###
    # def raw_q_mean_variance(self, x_start, t):
    #     """
    #     Get the distribution q(x_t | x_0).
    #     :param x_start: the [N x C x ...] tensor of noiseless inputs.
    #     :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    #     :return: A tuple (mean, variance, log_variance), all of x_start's shape.
    #     """
    #     mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
    #     variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
    #     log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
    #     return mean, variance, log_variance
    
    def raw_predict_start_from_noise(self, x_t, t, noise):
        """Estimate the clean sample from ``x_t`` and a noise prediction.

        Returns ``sqrt(1 / alphas_cumprod[t]) * x_t
        - sqrt(1 / alphas_cumprod[t] - 1) * noise``, with the coefficients
        looked up at ``t`` through ``extract_into_tensor`` for ``x_t.shape``.
        """
        target_shape = x_t.shape
        x_t_scale = extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, target_shape)
        noise_scale = extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, target_shape)
        return x_t_scale * x_t - noise_scale * noise

    # def raw_predict_start_from_z_and_v(self, x_t, t, v):
    #     return (
    #         extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
    #         extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
    #     )

    # def raw_predict_eps_from_z_and_v(self, x_t, t, v):
    #     return (
    #         extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
    #         extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
    #     )

    # def raw_q_posterior(self, x_start, x_t, t):
    #     posterior_mean = (
    #         extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
    #         extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
    #     )
    #     posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
    #     posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
    #     return posterior_mean, posterior_variance, posterior_log_variance_clipped
    ### RAW USAGES ###

    # def q_mean_variance(self, x_start, x_prev, t):
    #     """
    #     Get the distribution q(x_t | x_0, x_prev).
    #     :param x_start: the [N x C x ...] tensor of noiseless inputs.
    #     :param x_prev: the [N x C x ...] tensor of previous frame inputs.
    #     :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    #     :return: A tuple (mean, variance, log_variance), all of x_start's shape.
    #     """
    #     mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * (
    #         extract_into_tensor(self.sqrt_thetas_cumprod, t, x_start.shape) * x_start +
    #         extract_into_tensor(self.sqrt_one_minus_thetas_cumprod, t, x_start.shape) * x_prev
    #     )
    #     variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
    #     log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
    #     return mean, variance, log_variance
    
    def predict_start_from_noise(self, x_t, x_prev, t, noise):
        """Estimate the clean sample from ``x_t`` and a noise prediction.

        The plain estimate comes from ``raw_predict_start_from_noise``. When
        ``self.use_x_ema`` is set, that estimate is additionally passed through
        ``get_reverse_emas`` with the theta coefficients looked up at ``t``.
        ``x_prev`` is accepted but not used here.
        """
        x_start = self.raw_predict_start_from_noise(x_t, t, noise)
        if not self.use_x_ema:
            return x_start

        # Coefficients are extracted for a single-channel slice of x_t.
        forward_coef = extract_into_tensor(self.sqrt_thetas_cumprod, t, x_t[:, :1].shape)
        carry_coef = extract_into_tensor(self.sqrt_one_minus_thetas_cumprod, t, x_t[:, :1].shape)
        return self.get_reverse_emas(x_start, forward_coef, carry_coef)

    # def predict_start_from_z_and_v(self, x_t, x_prev, t, v):
    #     x_interpolate = (
    #         extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
    #         extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
    #     )
    #     return (
    #         extract_into_tensor(self.recip_thetas_cumprod, t, x_t.shape) * x_interpolate -
    #         extract_into_tensor(self.recipm1_thetas_cumprod, t, x_t.shape) * x_prev
    #     )

    # def predict_start_from_z_and_v_standard(self, x_t, x_prev, t, v):
    #     x_interpolate = (
    #         extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
    #         extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
    #     )
    #     return (
    #         extract_into_tensor(self.v_standard_to_start_coef_1, t, x_t.shape) * x_interpolate -
    #         extract_into_tensor(self.v_standard_to_start_coef_2, t, x_t.shape) * x_prev
    #     )

    # def predict_eps_from_z_and_v(self, x_t, x_prev, t, v):
    #     return (
    #         extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
    #         extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
    #     )

    # def predict_eps_from_z_and_v_standard(self, x_t, x_prev, t, v):
    #     eps_interpolate = (
    #         extract_into_tensor(self.v_standard_to_eps_interpolate_coef, t, x_t.shape) * v +
    #         extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
    #     )
    #     return (
    #         extract_into_tensor(self.v_standard_to_eps_coef_1, t, x_t.shape) * eps_interpolate -
    #         extract_into_tensor(self.v_standard_to_eps_coef_2, t, x_t.shape) * x_prev
    #     )
    
    # def q_posterior(self, x_start, x_t, x_prev, t):
    #     posterior_mean = (
    #         extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
    #         extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + 
    #         extract_into_tensor(self.posterior_mean_coef3, t, x_t.shape) * x_prev
    #     )
    #     posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
    #     posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
    #     return posterior_mean, posterior_variance, posterior_log_variance_clipped

    # def p_mean_variance(self, x, x_prev, t, clip_denoised: bool):
    #     model_out = self.model(x, t)

    #     if self.ensemble:
    #         x, x_raw = x.chunk(2, dim=1)  # channel-wise split
    #         if self.parameterization == "eps":
    #             x_recon_raw = self.raw_predict_start_from_noise(x_raw, t, noise=model_out)
    #         elif self.parameterization == "x0":
    #             x_recon_raw = model_out
    #         elif self.parameterization == "v_standard":
    #             x_recon_raw = self.predict_start_from_z_and_v_standard(x, x_prev, t, model_out)
    #         if clip_denoised:
    #             x_recon_raw.clamp_(-1., 1.)
    #         model_mean_raw, posterior_variance_raw, posterior_log_variance_raw = self.q_posterior(x_start=x_recon_raw, x_t=x_raw, x_prev=x_prev, t=t)
        
    #     if self.parameterization == "eps":
    #         x_recon = self.predict_start_from_noise(x, x_prev, t=t, noise=model_out)
    #     elif self.parameterization == "x0":
    #         x_recon = model_out
    #     elif self.parameterization == "v":
    #         x_recon = self.predict_start_from_z_and_v(x, x_prev, t, model_out)
    #     elif self.parameterization == "v_standard":
    #         x_recon = self.predict_start_from_z_and_v_standard(x, x_prev, t, model_out)
    #     else:
    #         raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
    #     if clip_denoised:
    #         x_recon.clamp_(-1., 1.)

    #     model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, x_prev=x_prev, t=t)

    #     if self.ensemble:
    #         model_mean = torch.cat([model_mean, model_mean_raw], dim=1)
    #         posterior_variance = torch.cat([posterior_variance, posterior_variance_raw], dim=1)
    #         posterior_log_variance = torch.cat([posterior_log_variance, posterior_log_variance_raw], dim=1)

    #     return model_mean, posterior_variance, posterior_log_variance

    # @torch.no_grad()
    # def p_sample(self, x, x_prev, t, clip_denoised=True, repeat_noise=False):
    #     b, *_, device = *x.shape, x.device
    #     model_mean, _, model_log_variance = self.p_mean_variance(x=x, x_prev=x_prev, t=t, clip_denoised=clip_denoised)
    #     if self.ensemble:
    #         # use the same noise for both tensors
    #         x_half, _ = x.chunk(2, dim=1)
    #         noise = noise_like(x_half.shape, device, repeat_noise)
    #         noise = torch.cat([noise, noise], dim=1)
    #     else:
    #         noise = noise_like(x.shape, device, repeat_noise)
    #     # no noise when t == 0
    #     nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    #     return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    # @torch.no_grad()
    # def p_sample_loop(self, shape, x_prev, return_intermediates=False):
    #     device = self.betas.device
    #     b = shape[0]
    #     img = torch.randn(shape, device=device)
    #     if self.ensemble:
    #         img = torch.cat([img, img], dim=1)
    #     intermediates = [img]
    #     for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
    #         img = self.p_sample(img, x_prev, torch.full((b,), i, device=device, dtype=torch.long),
    #                             clip_denoised=self.clip_denoised)
    #         if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
    #             intermediates.append(img)
    #     if self.ensemble:
    #         _, img = img.chunk(2, dim=1)
    #         # for idx in range(len(intermediates)):
    #         #     _, intermediates[idx] = intermediates[idx].chunk(2, dim=1)
    #     if return_intermediates:
    #         return img, intermediates
    #     return img

    # @torch.no_grad()
    # def sample(self, x_prev, batch_size=16, return_intermediates=False):
    #     image_size = self.image_size
    #     channels = self.channels
    #     return self.p_sample_loop(x_prev.shape,
    #                               x_prev=x_prev,
    #                               return_intermediates=return_intermediates)

    def raw_q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to timestep ``t``.

        Returns ``sqrt(alphas_cumprod[t]) * x_start
        + sqrt(1 - alphas_cumprod[t]) * noise``. When ``noise`` is ``None`` a
        standard-normal tensor shaped like ``x_start`` is drawn.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        target_shape = x_start.shape
        signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, target_shape)
        noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, target_shape)
        return signal_scale * x_start + noise_scale * noise
    
    def get_prev_ema(self, x_prev, sqrt_thetas, sqrt_one_minus_thetas):
        x_prev_ema = x_prev[:, 0]
        for i in range(1, x_prev.shape[1]):
            x_prev_ema = x_prev_ema * sqrt_one_minus_thetas + x_prev[:, i] * sqrt_thetas
            if self.rescale_ema:
                x_prev_ema = x_prev_ema / (sqrt_one_minus_thetas + sqrt_thetas)
        return x_prev_ema
    
    def get_emas(self, x, sqrt_thetas, sqrt_one_minus_thetas):
        """Replace every frame of ``x`` by a running blend over the frames so far.

        ``x`` is split along its channel axis into groups of
        ``self.z_channels`` channels. The first group is kept as is; each later
        one becomes ``previous_blend * sqrt_one_minus_thetas
        + group * sqrt_thetas``. The blends are concatenated back along the
        channel axis. ``z_channels`` is not set in this class's ``__init__`` -
        presumably provided by a subclass.
        """
        frames = rearrange(x, 'b (s c) h w -> b s c h w', c=self.z_channels)
        running = frames[:, 0]
        blends = [running]
        for idx in range(1, frames.shape[1]):
            running = running * sqrt_one_minus_thetas + frames[:, idx] * sqrt_thetas
            blends.append(running)
        return torch.cat(blends, dim=1)

    def get_reverse_emas(self, x_emas, sqrt_thetas, sqrt_one_minus_thetas):
        """Undo the running blend built by ``get_emas``.

        ``x_emas`` is split along its channel axis into groups of
        ``self.z_channels`` channels. The first group is returned unchanged;
        every later one is recovered as
        ``(group - previous_group * sqrt_one_minus_thetas) / sqrt_thetas``.
        The results are concatenated back along the channel axis.
        """
        blends = rearrange(x_emas, 'b (s c) h w -> b s c h w', c=self.z_channels)
        num_frames = blends.shape[1]
        recovered = [blends[:, 0]]
        # Each frame only needs its own blend and the one right before it.
        recovered.extend(
            (blends[:, idx] - blends[:, idx - 1] * sqrt_one_minus_thetas) / sqrt_thetas
            for idx in range(1, num_frames)
        )
        return torch.cat(recovered, dim=1)
    
    def q_sample(self, x_start, x_prev, t, noise=None):
        """Produce the noisy sample for timestep ``t``.

        When ``self.use_x_ema`` is set, ``x_start`` is first replaced by
        ``get_emas(x_start, ...)`` using the theta coefficients looked up at
        ``t``; the result is then noised with ``raw_q_sample``. Missing
        ``noise`` is drawn as standard normal with the shape of the incoming
        ``x_start``. ``x_prev`` is accepted but not used here.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        if self.use_x_ema:
            # Coefficients are extracted for a single-channel slice of x_start.
            forward_coef = extract_into_tensor(self.sqrt_thetas_cumprod, t, x_start[:, :1].shape)
            carry_coef = extract_into_tensor(self.sqrt_one_minus_thetas_cumprod, t, x_start[:, :1].shape)
            x_start = self.get_emas(x_start, forward_coef, carry_coef)
        return self.raw_q_sample(x_start, t, noise)
    
    def raw_get_v(self, x, noise, t):
        """Return the v target
        ``sqrt(alphas_cumprod[t]) * noise - sqrt(1 - alphas_cumprod[t]) * x``.
        """
        target_shape = x.shape
        noise_weight = extract_into_tensor(self.sqrt_alphas_cumprod, t, target_shape)
        signal_weight = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, target_shape)
        return noise_weight * noise - signal_weight * x
    
    def get_v(self, x, x_prev, noise, t):
        """Return the v target for ``x``; delegates to ``raw_get_v`` and ignores ``x_prev``."""
        v_target = self.raw_get_v(x, noise, t)
        return v_target
    
    # def get_v_standard(self, x, x_prev, noise, t):
    #     return (
    #             extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
    #             extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
    #     )

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            raise NotImplementedError("unknown loss type '{loss_type}'")

        return loss

    def p_losses(self, x_start, x_prev, t, noise=None):
        """Compute the training loss for one batch at timesteps ``t``.

        ``x_start`` is noised with ``q_sample``, the result is fed to
        ``self.model`` together with ``t``, and the output is compared with the
        target selected by ``self.parameterization`` (``noise`` for ``"eps"``,
        ``x_start`` for ``"x0"``, ``get_v(...)`` for ``"v"``).

        Args:
            x_start: clean input batch; the loss is averaged over axes 1-3, so
                a 4-D tensor is expected.
            x_prev: passed through to ``q_sample`` and the v-target helpers.
            t: timestep indices for the batch.
            noise: optional noise; drawn as standard normal when ``None``.

        Returns:
            ``(loss, loss_dict)`` where ``loss`` is the batch-mean loss times
            ``self.l_simple_weight`` and ``loss_dict`` holds
            ``"<prefix>/loss_simple"`` and ``"<prefix>/loss"`` with prefix
            ``train`` or ``val`` depending on ``self.training``.

        Raises:
            NotImplementedError: for an unknown parameterization, or for
                ``"v_standard"`` when no ``get_v_standard`` method is
                available on the instance.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, x_prev=x_prev, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "v":
            target = self.get_v(x=x_start, x_prev=x_prev, noise=noise, t=t)
        elif self.parameterization == "v_standard":
            # ``get_v_standard`` is commented out in this class; look it up
            # dynamically so a subclass may still supply it, and fail with a
            # clear message instead of an AttributeError otherwise.
            get_v_standard = getattr(self, "get_v_standard", None)
            if get_v_standard is None:
                raise NotImplementedError(
                    "Parameterization v_standard requires a get_v_standard method, "
                    "which is not defined on this model"
                )
            target = get_v_standard(x=x_start, x_prev=x_prev, noise=noise, t=t)
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        # Per-sample loss: average over all non-batch axes.
        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        # The VLB term is disabled:
        # loss_vlb = (self.lvlb_weights[t] * loss).mean()
        # loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict

    def forward(self, x, x_prev, *args, **kwargs):
        """Sample one random timestep per batch element and return ``p_losses`` for it.

        Timesteps are drawn uniformly from ``[0, self.num_timesteps)`` on
        ``self.device``; extra positional and keyword arguments are forwarded
        to ``p_losses``.
        """
        batch_size = x.shape[0]
        t = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device).long()
        return self.p_losses(x, x_prev, t, *args, **kwargs)

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` and return it as a contiguous float tensor in ``b c h w`` layout.

        A 3-D entry gets a trailing singleton axis first, so the input to the
        rearrange is always treated as ``b h w c``.
        """
        images = batch[k]
        if images.ndim == 3:
            images = images[..., None]
        channels_first = rearrange(images, 'b h w c -> b c h w')
        return channels_first.to(memory_format=torch.contiguous_format).float()

    def shared_step(self, batch):
        """Extract the current and previous inputs from ``batch`` and run the module on them.

        :return: the ``(loss, loss_dict)`` pair produced by calling ``self``.
        """
        current = self.get_input(batch, self.first_stage_key)
        previous = self.get_input(batch, self.first_stage_key_prev)
        loss, loss_dict = self(current, previous)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        """Run one training step and log its metrics.

        For every key in ``self.ucg_training``, each entry of ``batch[key]`` is
        replaced in place by the configured ``val`` (``""`` when ``val`` is
        None) with probability ``p``, using ``self.ucg_prng``. The loss
        dictionary, the global step and, when a scheduler is used, the current
        learning rate are then logged.

        :return: the loss returned by ``shared_step``.
        """
        for key in self.ucg_training:
            spec = self.ucg_training[key]
            drop_prob = spec["p"]
            replacement = spec["val"]
            if replacement is None:
                replacement = ""
            for idx in range(len(batch[key])):
                # choice(2, ...) yields 1 with probability drop_prob
                if self.ucg_prng.choice(2, p=[1 - drop_prob, drop_prob]):
                    batch[key][idx] = replacement

        loss, loss_dict = self.shared_step(batch)

        step_only = dict(prog_bar=True, logger=True, on_step=True, on_epoch=False)
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step, **step_only)

        if self.use_scheduler:
            current_lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', current_lr, **step_only)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Evaluate ``batch`` with the current weights and again inside ``ema_scope``.

        Both loss dictionaries are logged per epoch; keys from the EMA pass
        get an ``_ema`` suffix.
        """
        epoch_only = dict(prog_bar=False, logger=True, on_step=False, on_epoch=True)

        _, plain_losses = self.shared_step(batch)
        with self.ema_scope():
            _, ema_raw = self.shared_step(batch)
            ema_losses = {name + '_ema': value for name, value in ema_raw.items()}

        self.log_dict(plain_losses, **epoch_only)
        self.log_dict(ema_losses, **epoch_only)

    def on_train_batch_end(self, *args, **kwargs):
        """After each training batch, call ``self.model_ema`` on ``self.model`` when EMA is enabled."""
        if not self.use_ema:
            return
        self.model_ema(self.model)

    def _get_rows_from_list(self, samples):
        """Arrange a list of ``(b, c, h, w)`` batches into a single image grid.

        The list axis ``n`` is folded into the batch axis in batch-major order,
        and the grid is built with ``len(samples)`` images per row.
        """
        per_row = len(samples)
        # 'n b c h w' -> '(b n) c h w': all n steps of one sample are adjacent.
        flat = rearrange(samples, 'n b c h w -> (b n) c h w')
        return make_grid(flat, nrow=per_row)

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        """Collect tensors for logging: inputs, previous inputs and a forward-diffusion grid.

        :param batch: batch read through ``self.get_input``.
        :param N: maximum number of inputs kept (capped at the batch size).
        :param n_row: maximum number of inputs shown in the diffusion grid
            (capped at the batch size).
        :param sample: when truthy, the sampling branch is entered, which
            currently raises ``NotImplementedError``.
        :param return_keys: optional keys to restrict the result to; ignored
            when none of them is present in the log.
        :return: dict of logged tensors.
        """
        x = self.get_input(batch, self.first_stage_key)
        x_prev = self.get_input(batch, self.first_stage_key_prev)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        x_prev = x_prev.to(self.device)[:N]
        log = {"inputs": x, "prev": x_prev}

        # get diffusion row
        x_start = x[:n_row]
        logged_steps = [
            step for step in range(self.num_timesteps)
            if step % self.log_every_t == 0 or step == self.num_timesteps - 1
        ]
        diffusion_row = []
        for step in logged_steps:
            t = repeat(torch.tensor([step]), '1 -> b', b=n_row)
            t = t.to(self.device).long()
            noise = torch.randn_like(x_start)
            diffusion_row.append(self.q_sample(x_start=x_start, x_prev=x_prev, t=t, noise=noise))

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                # Sampling is disabled; everything below the raise is unreachable.
                raise NotImplementedError
                samples, denoise_row = self.sample(x_prev=x_prev, batch_size=N, return_intermediates=True)

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if not return_keys:
            return log
        if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
            return log
        return {key: log[key] for key in return_keys}

    def configure_optimizers(self):
        """Build an AdamW optimizer over the model parameters.

        ``self.logvar`` is added to the optimized parameters when
        ``self.learn_logvar`` is set; the learning rate is ``self.learning_rate``.
        """
        lr = self.learning_rate
        trainable = list(self.model.parameters())
        if self.learn_logvar:
            trainable.append(self.logvar)
        return torch.optim.AdamW(trainable, lr=lr)


class DynamicalLatentDiffusion(DynamicalDPM):
    """main class"""

    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 shift_factor=0.0,
                 scale_by_std=False,
                 force_null_conditioning=False,
                 *args, **kwargs):
        """Set up the latent diffusion model on top of the base DPM.

        :param first_stage_config: config passed to ``instantiate_first_stage``.
        :param cond_stage_config: config passed to ``instantiate_cond_stage``;
            ``'__is_unconditional__'`` disables conditioning unless
            ``force_null_conditioning`` is set.
        :param num_timesteps_cond: number of conditioning timesteps (defaults to 1).
        :param cond_stage_key: batch key of the conditioning input.
        :param cond_stage_trainable: whether the conditioning model is trained.
        :param concat_mode: selects ``'concat'`` vs ``'crossattn'`` when
            ``conditioning_key`` is not given.
        :param cond_stage_forward: optional name of the conditioning-model
            method to call instead of ``encode`` / ``__call__``.
        :param conditioning_key: explicit conditioning mode.
        :param scale_factor: latent scale; a Python number or a sequence
            (stored as a tensor).
        :param shift_factor: latent shift; a Python number or a sequence
            (stored as a tensor).
        :param scale_by_std: whether the scale is estimated from the first batch.
        :param force_null_conditioning: keep the conditioning key even for an
            unconditional cond-stage config.
        :param kwargs: forwarded to the base class; ``ckpt_path``,
            ``reset_ema``, ``reset_num_ema_updates`` and ``ignore_keys`` are
            consumed here.
        """
        self.force_null_conditioning = force_null_conditioning
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        # 'timesteps' may be omitted by the caller; the base class then uses
        # its own default of 1000, so compare against that instead of raising
        # a KeyError.
        assert self.num_timesteps_cond <= kwargs.get('timesteps', 1000)
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        reset_ema = kwargs.pop("reset_ema", False)
        reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            # Config has no ddconfig.ch_mult (or it has no length): treat as no
            # downsampling. Narrowed from a bare except so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            self.num_downs = 0
        # Plain numbers are stored as Python floats so the
        # ``type(...) is float`` checks used elsewhere keep working.
        # torch.Tensor(int) would allocate an uninitialised tensor of that
        # size rather than hold the value, so ints must not reach it.
        if isinstance(shift_factor, (int, float)):
            self.shift_factor = float(shift_factor)
        else:
            self.shift_factor = torch.Tensor(shift_factor)
        if isinstance(scale_factor, (int, float)):
            self.scale_factor = float(scale_factor)
        else:
            self.scale_factor = torch.Tensor(scale_factor)
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
            if reset_ema:
                assert self.use_ema
                print(
                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
                self.model_ema = LitEma(self.model)
        if reset_num_ema_updates:
            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
            assert self.use_ema
            self.model_ema.reset_num_updates()

    def make_cond_schedule(self, ):
        """Build ``self.cond_ids``, the lookup from a timestep index to a conditioning timestep.

        The first ``num_timesteps_cond`` entries are evenly spaced (rounded)
        values between 0 and ``num_timesteps - 1``; every remaining entry is
        ``num_timesteps - 1``.
        """
        last_step = self.num_timesteps - 1
        schedule = torch.full(size=(self.num_timesteps,), fill_value=last_step, dtype=torch.long)
        spaced = torch.linspace(0, last_step, self.num_timesteps_cond)
        schedule[:self.num_timesteps_cond] = torch.round(spaced).long()
        self.cond_ids = schedule

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        """On the very first training batch, optionally set ``scale_factor`` to ``1 / std`` of the encodings.

        Only acts when ``scale_by_std`` is set, training starts from step 0 of
        epoch 0, and the model was not restored from a checkpoint. The existing
        ``scale_factor`` attribute is replaced by a registered buffer.
        """
        if not self.scale_by_std:
            return
        # only for very first batch
        if self.current_epoch != 0 or self.global_step != 0 or batch_idx != 0:
            return
        if self.restarted_from_ckpt:
            return

        assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
        # set rescale weight to 1./std of encodings
        print("### USING STD-RESCALING ###")
        images = super().get_input(batch, self.first_stage_key)
        images = images.to(self.device)
        has_batched_encode = hasattr(self, 'batched_encode') and callable(self.batched_encode)
        if has_batched_encode:
            posterior = self.batched_encode(images)
        else:
            posterior = self.encode_first_stage(images)
        z = self.get_first_stage_encoding(posterior).detach()
        # The plain attribute must be removed before a buffer of the same name
        # can be registered.
        del self.scale_factor
        self.register_buffer('scale_factor', 1. / z.flatten().std())
        print(f"setting self.scale_factor to {self.scale_factor}")
        print("### USING STD-RESCALING ###")

    def register_schedule(self, given_betas=None, given_gammas=None,
                          beta_schedule="linear", gamma_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3,
                          linear_start_gamma=1e-4, linear_end_gamma=2e-2, cosine_s_gamma=8e-3):
        """Register the base-class schedule, then build the conditioning schedule if needed.

        All arguments are forwarded positionally to the parent implementation.
        ``make_cond_schedule`` is only called when ``num_timesteps_cond > 1``.
        """
        schedule_args = (
            given_betas, given_gammas, beta_schedule, gamma_schedule, timesteps,
            linear_start, linear_end, cosine_s,
            linear_start_gamma, linear_end_gamma, cosine_s_gamma,
        )
        super().register_schedule(*schedule_args)

        needs_cond_schedule = self.num_timesteps_cond > 1
        self.shorten_cond_schedule = needs_cond_schedule
        if needs_cond_schedule:
            self.make_cond_schedule()

    def instantiate_first_stage(self, config):
        """Instantiate the first-stage model from ``config`` and freeze it.

        The model is put in eval mode, its ``train`` attribute is replaced by
        ``disabled_train`` so the mode cannot be switched back, and all of its
        parameters get ``requires_grad = False``.
        """
        first_stage = instantiate_from_config(config).eval()
        self.first_stage_model = first_stage
        self.first_stage_model.train = disabled_train
        for weight in self.first_stage_model.parameters():
            weight.requires_grad = False

    def instantiate_cond_stage(self, config):
        """Create ``self.cond_stage_model`` according to ``config``.

        Trainable case: the model is instantiated as-is; the two special
        string configs are rejected. Non-trainable case:
        ``"__is_first_stage__"`` reuses the first-stage model,
        ``"__is_unconditional__"`` sets the model to None, and any other
        config is instantiated, put in eval mode and frozen.
        """
        if self.cond_stage_trainable:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            self.cond_stage_model = instantiate_from_config(config)
            return

        if config == "__is_first_stage__":
            print("Using first stage also as cond stage.")
            self.cond_stage_model = self.first_stage_model
        elif config == "__is_unconditional__":
            print(f"Training {self.__class__.__name__} as an unconditional model.")
            self.cond_stage_model = None
        else:
            frozen = instantiate_from_config(config).eval()
            self.cond_stage_model = frozen
            self.cond_stage_model.train = disabled_train
            for weight in self.cond_stage_model.parameters():
                weight.requires_grad = False

    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
        """Decode a list of latent batches and arrange the images into one grid.

        Each latent batch is moved to ``self.device`` and decoded with
        ``decode_first_stage``; the decoded batches are stacked, flattened in
        batch-major order and laid out with ``len(samples)`` images per row.
        """
        decoded = [
            self.decode_first_stage(latent.to(self.device),
                                    force_not_quantize=force_no_decoder_quantization)
            for latent in tqdm(samples, desc=desc)
        ]
        per_row = len(decoded)
        stacked = torch.stack(decoded)  # n_log_step, n_row, C, H, W
        flat = rearrange(stacked, 'n b c h w -> (b n) c h w')
        return make_grid(flat, nrow=per_row)

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn the first-stage encoder output into a shifted and scaled latent.

        :param encoder_posterior: a ``DiagonalGaussianDistribution`` (sampled)
            or a tensor (used as-is).
        :return: ``(z + shift_factor) * scale_factor``. Tensor factors with at
            least one dimension are indexed with ``[None, :, None, None]``
            before being applied; 0-dim tensor factors (such as the buffer
            registered by std-rescaling) and plain numbers are applied as
            scalars.
        :raises NotImplementedError: for any other ``encoder_posterior`` type.
        """
        if isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        elif isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")

        shift = self.shift_factor
        if isinstance(shift, torch.Tensor):
            # Keep the factor on the latent's device for subsequent calls.
            shift = self.shift_factor = shift.to(z.device)
            if shift.dim() > 0:
                shift = shift[None, :, None, None]
        z = z + shift

        scale = self.scale_factor
        if isinstance(scale, torch.Tensor):
            scale = self.scale_factor = scale.to(z.device)
            # A 0-dim tensor cannot be indexed with ':'; it broadcasts as a
            # scalar instead.
            if scale.dim() > 0:
                scale = scale[None, :, None, None]
        z = z * scale
        return z

    def get_learned_conditioning(self, c):
        """Encode the raw conditioning ``c`` with the conditioning-stage model.

        When ``cond_stage_forward`` is set, the model method of that name is
        used. Otherwise the model's ``encode`` is preferred (taking the mode of
        a ``DiagonalGaussianDistribution`` result), falling back to calling the
        model directly.
        """
        encoder = self.cond_stage_model
        if self.cond_stage_forward is not None:
            assert hasattr(encoder, self.cond_stage_forward)
            return getattr(encoder, self.cond_stage_forward)(c)

        if hasattr(encoder, 'encode') and callable(encoder.encode):
            encoded = encoder.encode(c)
            if isinstance(encoded, DiagonalGaussianDistribution):
                encoded = encoded.mode()
            return encoded

        return encoder(c)

    # def meshgrid(self, h, w):
    #     y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
    #     x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)

    #     arr = torch.cat([y, x], dim=-1)
    #     return arr

    # def delta_border(self, h, w):
    #     """
    #     :param h: height
    #     :param w: width
    #     :return: normalized distance to image border,
    #      with min distance = 0 at border and max dist = 0.5 at image center
    #     """
    #     lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
    #     arr = self.meshgrid(h, w) / lower_right_corner
    #     dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
    #     dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
    #     edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
    #     return edge_dist

    # def get_weighting(self, h, w, Ly, Lx, device):
    #     weighting = self.delta_border(h, w)
    #     weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
    #                            self.split_input_params["clip_max_weight"], )
    #     weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

    #     if self.split_input_params["tie_braker"]:
    #         L_weighting = self.delta_border(Ly, Lx)
    #         L_weighting = torch.clip(L_weighting,
    #                                  self.split_input_params["clip_min_tie_weight"],
    #                                  self.split_input_params["clip_max_tie_weight"])

    #         L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
    #         weighting = weighting * L_weighting
    #     return weighting

    # def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
    #     """
    #     :param x: img of size (bs, c, h, w)
    #     :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
    #     """
    #     bs, nc, h, w = x.shape

    #     # number of crops in image
    #     Ly = (h - kernel_size[0]) // stride[0] + 1
    #     Lx = (w - kernel_size[1]) // stride[1] + 1

    #     if uf == 1 and df == 1:
    #         fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
    #         unfold = torch.nn.Unfold(**fold_params)

    #         fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

    #         weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
    #         normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
    #         weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

    #     elif uf > 1 and df == 1:
    #         fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
    #         unfold = torch.nn.Unfold(**fold_params)

    #         fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
    #                             dilation=1, padding=0,
    #                             stride=(stride[0] * uf, stride[1] * uf))
    #         fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

    #         weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
    #         normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
    #         weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

    #     elif df > 1 and uf == 1:
    #         fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
    #         unfold = torch.nn.Unfold(**fold_params)

    #         fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
    #                             dilation=1, padding=0,
    #                             stride=(stride[0] // df, stride[1] // df))
    #         fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

    #         weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
    #         normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
    #         weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

    #     else:
    #         raise NotImplementedError

    #     return fold, unfold, normalization, weighting

    @torch.no_grad()
    def get_input(self, batch, k, k_prev, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, bs=None, return_x=False):
        """Encode the current and previous inputs to latents and assemble the conditioning.

        :param batch: input batch.
        :param k: batch key of the current input.
        :param k_prev: batch key of the previous input.
        :param return_first_stage_outputs: also return the input and its reconstruction.
        :param force_c_encode: encode the conditioning even when the
            conditioning stage is trainable.
        :param cond_key: conditioning key; defaults to ``self.cond_stage_key``.
        :param return_original_cond: also return the un-encoded conditioning.
        :param bs: optional cap on the number of samples.
        :param return_x: also return the input.
        :return: list ``[z, z_prev, c]`` followed by the optional extras in the
            order ``x, xrec`` / ``x`` / ``xc``.
        """
        x = super().get_input(batch, k)
        x_prev = super().get_input(batch, k_prev)
        if bs is not None:
            x, x_prev = x[:bs], x_prev[:bs]
        x = x.to(self.device)
        x_prev = x_prev.to(self.device)

        z = self.get_first_stage_encoding(self.encode_first_stage(x)).detach()
        z_prev = self.get_first_stage_encoding(self.encode_first_stage(x_prev)).detach()

        conditioned = self.model.conditioning_key is not None and not self.force_null_conditioning
        if conditioned:
            if cond_key is None:
                cond_key = self.cond_stage_key

            # Pick the raw conditioning input.
            if cond_key == self.first_stage_key:
                xc = x
            elif cond_key in ['caption', 'coordinates_bbox', "txt"]:
                xc = batch[cond_key]
            elif cond_key in ['class_label', 'cls']:
                xc = batch
            else:
                xc = super().get_input(batch, cond_key).to(self.device)

            # A trainable conditioning stage is applied later, unless encoding is forced.
            if self.cond_stage_trainable and not force_c_encode:
                c = xc
            elif isinstance(xc, dict) or isinstance(xc, list):
                c = self.get_learned_conditioning(xc)
            else:
                c = self.get_learned_conditioning(xc.to(self.device))

            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
        else:
            c, xc = None, None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {'pos_x': pos_x, 'pos_y': pos_y}

        out = [z, z_prev, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z)
            out += [x, xrec]
        if return_x:
            out.append(x)
        if return_original_cond:
            out.append(xc)
        return out

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        """Undo the latent scaling and shifting, then decode with the first-stage model.

        :param z: latent tensor (or codebook logits / ids when ``predict_cids``).
        :param predict_cids: when set, ``z`` is first mapped to codebook entries
            through ``first_stage_model.quantize``.
        :param force_not_quantize: accepted for interface compatibility; not
            used by this implementation.
        :return: ``first_stage_model.decode(z / scale_factor - shift_factor)``.
            Tensor factors with at least one dimension are indexed with
            ``[None, :, None, None]`` before being applied; 0-dim tensor
            factors (such as the buffer registered by std-rescaling) and plain
            numbers are applied as scalars.
        """
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        scale = self.scale_factor
        if isinstance(scale, torch.Tensor):
            # Keep the factor on the latent's device for subsequent calls.
            scale = self.scale_factor = scale.to(z.device)
            # A 0-dim tensor cannot be indexed with ':'; it broadcasts as a
            # scalar instead.
            if scale.dim() > 0:
                scale = scale[None, :, None, None]
            z = z / scale
        else:
            z = 1. / scale * z

        shift = self.shift_factor
        if isinstance(shift, torch.Tensor):
            shift = self.shift_factor = shift.to(z.device)
            if shift.dim() > 0:
                shift = shift[None, :, None, None]
        z = z - shift
        return self.first_stage_model.decode(z)

    @torch.no_grad()
    def encode_first_stage(self, x):
        """Return ``first_stage_model.encode(x)``, computed with gradients disabled."""
        first_stage = self.first_stage_model
        return first_stage.encode(x)

    def shared_step(self, batch, **kwargs):
        """Read latents and conditioning from ``batch`` and run the module on them.

        :return: whatever calling ``self`` returns for ``(z, z_prev, c)``.
        """
        inputs = self.get_input(batch, self.first_stage_key, self.first_stage_key_prev)
        x, x_prev, c = inputs
        return self(x, x_prev, c)

    def forward(self, x, x_prev, c, *args, **kwargs):
        """Sample one random timestep per batch element and return ``p_losses`` for it.

        When the model is conditioned, ``c`` must be given; it is encoded first
        if the conditioning stage is trainable and noised if the shortened
        conditioning schedule is active.
        """
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        conditioned = self.model.conditioning_key is not None
        if conditioned:
            assert c is not None
            if self.cond_stage_trainable:
                c = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                # NOTE(review): q_sample is called with an x_prev argument
                # elsewhere in this file - confirm it is optional here.
                cond_steps = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=cond_steps, noise=torch.randn_like(c.float()))
        return self.p_losses(x, x_prev, c, t, *args, **kwargs)

    def apply_model(self, x_noisy, t, cond, return_ids=False):
        """Call the wrapped model with conditioning passed as keyword arguments.

        :param x_noisy: noised input passed as first positional argument.
        :param t: timesteps passed as second positional argument.
        :param cond: either a dict that is forwarded unchanged (hybrid case),
            or a single item / list that is wrapped under the keyword chosen
            from ``self.model.conditioning_key`` (``c_concat``,
            ``c_crossattn`` or ``c_mcvd``).
        :param return_ids: when the model returns a tuple, return it whole
            instead of only its first element.
        :return: the model output (first element of a tuple output unless
            ``return_ids`` is set).
        :raises NotImplementedError: if ``cond`` is not a dict and the
            conditioning key maps to none of the supported keywords.
        """
        if not isinstance(cond, dict):
            if not isinstance(cond, list):
                cond = [cond]
            conditioning_key = self.model.conditioning_key
            if conditioning_key is not None and 'concat' in conditioning_key:
                key = 'c_concat'
            elif conditioning_key == 'crossattn':
                key = 'c_crossattn'
            elif conditioning_key == 'mcvd':
                key = 'c_mcvd'
            else:
                # Previously this fell through with `key` unbound (or failed on
                # `'concat' in None`); report the actual problem instead.
                raise NotImplementedError(
                    f"Conditioning key {conditioning_key!r} is not supported for non-dict conditioning")
            cond = {key: cond}
        # else: hybrid case, cond is expected to be a dict and is passed through

        x_recon = self.model(x_noisy, t, **cond)

        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        return x_recon

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        """Recover the noise estimate from ``x_t`` and a predicted ``x_start``.

        Uses the ``sqrt_recip_alphas_cumprod`` and
        ``sqrt_recipm1_alphas_cumprod`` buffers gathered at timestep ``t``.
        """
        recip = extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
        recip_minus_one = extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return (recip * x_t - pred_xstart) / recip_minus_one

    # def _prior_bpd(self, x_start):
    #     """
    #     Get the prior KL term for the variational lower-bound, measured in
    #     bits-per-dim.
    #     This term can't be optimized, as it only depends on the encoder.
    #     :param x_start: the [N x C x ...] tensor of inputs.
    #     :return: a batch of [N] KL values (in bits), one per batch element.
    #     """
    #     batch_size = x_start.shape[0]
    #     t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
    #     qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
    #     kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
    #     return mean_flat(kl_prior) / np.log(2.0)

    def p_losses(self, x_start, x_prev, cond, t, noise=None):
        """Compute the conditioned training loss for one batch at timesteps ``t``.

        The noise (drawn with ``torch.randn_like(x_start)`` when omitted) is
        passed through ``self.get_emas`` unless ``self.use_x_ema == 'only'``.
        ``x_start`` is then noised with ``q_sample`` and the model output is
        regressed onto the target selected by ``self.parameterization``. The
        per-sample loss is weighted by ``self.logvar[t]`` as
        ``loss / exp(logvar) + logvar``.

        :return: ``(loss, loss_dict)``; the dict keys use the prefix ``train``
            or ``val`` depending on ``self.training``.
        :raises NotImplementedError: for an unknown parameterization.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        if self.use_x_ema != 'only':
            # Coefficients are gathered with a single-channel shape.
            coef_shape = x_start[:, :1].shape
            noise = self.get_emas(
                noise,
                extract_into_tensor(self.sqrt_thetas_cumprod, t, coef_shape),
                extract_into_tensor(self.sqrt_one_minus_thetas_cumprod, t, coef_shape),
            )
        x_noisy = self.q_sample(x_start=x_start, x_prev=x_prev, t=t, noise=noise)
        model_output = self.apply_model(x_noisy, t, cond)

        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        elif self.parameterization == "v":
            target = self.get_v(x=x_start, x_prev=x_prev, noise=noise, t=t)
        elif self.parameterization == "v_standard":
            target = self.get_v_standard(x=x_start, x_prev=x_prev, noise=noise, t=t)
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        loss_dict = {}
        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict[f'{prefix}/loss_simple'] = loss_simple.mean()

        logvar_t = self.logvar[t].to(self.device)
        weighted = loss_simple / torch.exp(logvar_t) + logvar_t
        if self.learn_logvar:
            loss_dict[f'{prefix}/loss_gamma'] = weighted.mean()
            loss_dict['logvar'] = self.logvar.data.mean()

        # The VLB term (lvlb_weights / original_elbo_weight) is not added here.
        loss = self.l_simple_weight * weighted.mean()
        loss_dict[f'{prefix}/loss'] = loss

        return loss, loss_dict

    # def p_mean_variance(self, x, x_prev, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
    #                     return_x0=False, score_corrector=None, corrector_kwargs=None, 
    #                     unconditional_guidance_scale=1., unconditional_conditioning=None,):
    #     if unconditional_conditioning is None or unconditional_guidance_scale == 1.:  # unconditional
    #         model_out = self.apply_model(x, t, c, return_ids=return_codebook_ids)
    #     else:
    #         x_in = torch.cat([x] * 2)
    #         t_in = torch.cat([t] * 2)
    #         c_in = torch.cat([unconditional_conditioning, c])
    #         model_uncond, model_t = self.apply_model(x_in, t_in, c_in, return_ids=return_codebook_ids).chunk(2)
    #         model_out = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

    #     if score_corrector is not None:
    #         assert self.parameterization == "eps"
    #         model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

    #     if return_codebook_ids:
    #         model_out, logits = model_out
        
    #     if self.ensemble:
    #         x, x_raw = x.chunk(2, dim=1)
    #         if self.parameterization == "eps":
    #             x_recon_raw = self.raw_predict_start_from_noise(x_raw, t=t, noise=model_out)
    #         elif self.parameterization == "x0":
    #             x_recon_raw = model_out
    #         elif self.parameterization == "v_standard":
    #             x_recon_raw = self.predict_start_from_z_and_v_standard(x, x_prev, t, model_out)
    #         if clip_denoised:
    #             x_recon_raw.clamp_(-1., 1.)
    #         if quantize_denoised:
    #             x_recon_raw, _, [_, _, indices] = self.first_stage_model.quantize(x_recon_raw)
    #         model_mean_raw, posterior_variance_raw, posterior_log_variance_raw = self.q_posterior(x_start=x_recon_raw, x_t=x_raw, x_prev=x_prev, t=t)

    #     if self.parameterization == "eps":
    #         x_recon = self.predict_start_from_noise(x, x_prev, t=t, noise=model_out)
    #     elif self.parameterization == "x0":
    #         x_recon = model_out
    #     elif self.parameterization == "v":
    #         x_recon = self.predict_start_from_z_and_v(x, x_prev, t=t, v=model_out)
    #     elif self.parameterization == "v_standard":
    #         x_recon = self.predict_start_from_z_and_v_standard(x, x_prev, t=t, v=model_out)
    #     else:
    #         raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

    #     if clip_denoised:
    #         x_recon.clamp_(-1., 1.)
    #     if quantize_denoised:
    #         x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
    #     model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_prev=x_prev, x_t=x, t=t)
        
    #     if self.ensemble:
    #         model_mean = torch.cat([model_mean, model_mean_raw], dim=1)
    #         posterior_variance = torch.cat([posterior_variance, posterior_variance_raw], dim=1)
    #         posterior_log_variance = torch.cat([posterior_log_variance, posterior_log_variance_raw], dim=1)
        
    #     if return_codebook_ids:
    #         return model_mean, posterior_variance, posterior_log_variance, logits
    #     elif return_x0:
    #         return model_mean, posterior_variance, posterior_log_variance, x_recon
    #     else:
    #         return model_mean, posterior_variance, posterior_log_variance

    # @torch.no_grad()
    # def p_sample(self, x, x_prev, c, t, clip_denoised=False, repeat_noise=False,
    #              return_codebook_ids=False, quantize_denoised=False, return_x0=False,
    #              temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, **kwargs):
    #     b, *_, device = *x.shape, x.device
    #     outputs = self.p_mean_variance(x=x, x_prev=x_prev, c=c, t=t, clip_denoised=clip_denoised,
    #                                    return_codebook_ids=return_codebook_ids,
    #                                    quantize_denoised=quantize_denoised,
    #                                    return_x0=return_x0,
    #                                    score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, **kwargs)
    #     if return_codebook_ids:
    #         raise DeprecationWarning("Support dropped.")
    #         model_mean, _, model_log_variance, logits = outputs
    #     elif return_x0:
    #         model_mean, _, model_log_variance, x0 = outputs
    #     else:
    #         model_mean, _, model_log_variance = outputs

    #     if self.ensemble:
    #         # use the same noise for both tensors
    #         x_half, _ = x.chunk(2, dim=1)
    #         noise = noise_like(x_half.shape, device, repeat_noise) * temperature
    #         if noise_dropout > 0.:
    #             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    #         noise = torch.cat([noise, noise], dim=1)
    #     else:
    #         noise = noise_like(x.shape, device, repeat_noise) * temperature
    #         if noise_dropout > 0.:
    #             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        
    #     # no noise when t == 0
    #     nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

    #     if return_codebook_ids:
    #         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
    #     if return_x0:
    #         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
    #     else:
    #         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    # @torch.no_grad()
    # def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
    #                           img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
    #                           score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
    #                           log_every_t=None):
    #     if not log_every_t:
    #         log_every_t = self.log_every_t
    #     timesteps = self.num_timesteps
    #     if batch_size is not None:
    #         b = batch_size if batch_size is not None else shape[0]
    #         shape = [batch_size] + list(shape)
    #     else:
    #         b = batch_size = shape[0]
    #     if x_T is None:
    #         img = torch.randn(shape, device=self.device)
    #     else:
    #         img = x_T
    #     intermediates = []
    #     if cond is not None:
    #         if isinstance(cond, dict):
    #             cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
    #             list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
    #         else:
    #             cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

    #     if start_T is not None:
    #         timesteps = min(timesteps, start_T)
    #     iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
    #                     total=timesteps) if verbose else reversed(
    #         range(0, timesteps))
    #     if type(temperature) == float:
    #         temperature = [temperature] * timesteps

    #     for i in iterator:
    #         ts = torch.full((b,), i, device=self.device, dtype=torch.long)
    #         if self.shorten_cond_schedule:
    #             assert self.model.conditioning_key != 'hybrid'
    #             tc = self.cond_ids[ts].to(cond.device)
    #             cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

    #         img, x0_partial = self.p_sample(img, cond, ts,
    #                                         clip_denoised=self.clip_denoised,
    #                                         quantize_denoised=quantize_denoised, return_x0=True,
    #                                         temperature=temperature[i], noise_dropout=noise_dropout,
    #                                         score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
    #         if mask is not None:
    #             assert x0 is not None
    #             img_orig = self.q_sample(x0, ts)
    #             img = img_orig * mask + (1. - mask) * img

    #         if i % log_every_t == 0 or i == timesteps - 1:
    #             intermediates.append(x0_partial)
    #         if callback: callback(i)
    #         if img_callback: img_callback(img, i)
    #     return img, intermediates

    # @torch.no_grad()
    # def p_sample_loop(self, cond, shape, x_prev, return_intermediates=False,
    #                   x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
    #                   mask=None, x0=None, img_callback=None, start_T=None,
    #                   log_every_t=None, **kwargs):

    #     if not log_every_t:
    #         log_every_t = self.log_every_t
    #     device = self.betas.device
    #     b = shape[0]
    #     if x_T is None:
    #         img = torch.randn(shape, device=device)
    #     else:
    #         img = x_T
    #     if self.ensemble:
    #         img = torch.cat([img, img], dim=1)

    #     intermediates = [img]
    #     if timesteps is None:
    #         timesteps = self.num_timesteps

    #     if start_T is not None:
    #         timesteps = min(timesteps, start_T)
    #     iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
    #         range(0, timesteps))

    #     if mask is not None:
    #         assert x0 is not None
    #         assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

    #     for i in iterator:
    #         ts = torch.full((b,), i, device=device, dtype=torch.long)
    #         if self.shorten_cond_schedule:
    #             assert self.model.conditioning_key != 'hybrid'
    #             tc = self.cond_ids[ts].to(cond.device)
    #             cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

    #         img = self.p_sample(img, x_prev, cond, ts,
    #                             clip_denoised=self.clip_denoised,
    #                             quantize_denoised=quantize_denoised, **kwargs)
    #         if mask is not None:
    #             img_orig = self.q_sample(x0, x_prev, ts)
    #             img = img_orig * mask + (1. - mask) * img

    #         if i % log_every_t == 0 or i == timesteps - 1:
    #             intermediates.append(img)
    #         if callback: callback(i)
    #         if img_callback: img_callback(img, i)
        
    #     if self.ensemble:
    #         _, img = img.chunk(2, dim=1)
    #         # for idx in range(len(intermediates)):
    #         #     _, intermediates[idx] = intermediates[idx].chunk(2, dim=1)

    #     if return_intermediates:
    #         return img, intermediates
    #     return img

    # @torch.no_grad()
    # def sample(self, x_prev, cond, batch_size=16, return_intermediates=False, x_T=None,
    #            verbose=True, timesteps=None, quantize_denoised=False,
    #            mask=None, x0=None, shape=None, **kwargs):
    #     if shape is None:
    #         shape = (batch_size, self.channels, self.image_size, self.image_size)
    #     if cond is not None:
    #         if isinstance(cond, dict):
    #             cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
    #             list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
    #         else:
    #             cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
    #     return self.p_sample_loop(cond,
    #                               shape,
    #                               x_prev,
    #                               return_intermediates=return_intermediates, x_T=x_T,
    #                               verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
    #                               mask=mask, x0=x0, **kwargs)

    @torch.no_grad()
    def sample_log(self, x_prev, cond, batch_size, ddim, ddim_steps, unconditional_guidance_scale=1.,
                   unconditional_conditioning=None, **kwargs):
        """Draw samples with the DDIM sampler and return them with the logged trajectories.

        Args:
            x_prev: handed unchanged to the sampler; its last two dimensions
                are used as the spatial size of the tensor to sample.
            cond: conditioning handed unchanged to the sampler.
            batch_size: number of samples to draw.
            ddim: must be truthy; the non-DDIM path is not implemented.
            ddim_steps: number of DDIM steps.
            unconditional_guidance_scale: forwarded to the sampler.
            unconditional_conditioning: forwarded to the sampler.
            **kwargs: accepted so existing call sites keep working, but NOT
                forwarded to the sampler (e.g. ``eta``, ``mask``, ``x0`` and
                ``quantize_denoised`` currently have no effect here).

        Returns:
            ``(samples, (intermediates, pred_x0))`` where the inner pair holds
            the ``"x_inter"`` and ``"pred_x0"`` entries of the sampler's
            second return value.

        Raises:
            NotImplementedError: if ``ddim`` is falsy.
        """
        if not ddim:
            # The ancestral sampling loop is not available in this class, so fail
            # up front instead of keeping unreachable code behind the raise.
            raise NotImplementedError("sample_log only supports DDIM sampling (ddim=True)")

        ddim_sampler = DDIMSampler(self)
        # All frames are stacked along the channel axis: (T * z_channels, H, W).
        shape = (self.z_channels * self.model.T, x_prev.shape[-2], x_prev.shape[-1])
        samples, intermediates_pred_x0 = ddim_sampler.sample(ddim_steps, batch_size,
                                                             shape, x_prev, cond, verbose=False,
                                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                                             unconditional_conditioning=unconditional_conditioning,
                                                             log_every_t=self.log_every_t)
        intermediates = intermediates_pred_x0["x_inter"]
        pred_x0 = intermediates_pred_x0["pred_x0"]

        return samples, (intermediates, pred_x0)

    @torch.no_grad()
    def get_unconditional_conditioning(self, batch_size, null_label=None):
        """Build the conditioning used for the unconditional branch of guidance.

        Args:
            batch_size: number of entries the returned conditioning is expanded to.
            null_label: optional explicit "null" input. When omitted, the
                conditioning stage model is asked for its own unconditional
                input, which is only supported for class-label conditioning.

        Returns:
            The learned conditioning. When ``null_label`` is given, the result
            (or every entry of a list result) is expanded from a leading
            dimension of 1 to ``batch_size`` and moved to ``self.device``.

        Raises:
            NotImplementedError: if ``null_label`` is ``None`` and
                ``self.cond_stage_key`` is neither "class_label" nor "cls".
        """
        if null_label is None:
            if self.cond_stage_key not in ("class_label", "cls"):
                raise NotImplementedError("todo")
            # The stage model already produces a full batch; no expansion needed.
            null_input = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
            return self.get_learned_conditioning(null_input)

        null_input = list(null_label) if isinstance(null_label, ListConfig) else null_label
        is_container = isinstance(null_input, dict) or isinstance(null_input, list)
        if not is_container and hasattr(null_input, "to"):
            # tensor-like inputs are moved to the model device before encoding
            null_input = null_input.to(self.device)
        encoded = self.get_learned_conditioning(null_input)

        if not isinstance(encoded, list):
            return repeat(encoded, '1 ... -> b ...', b=batch_size).to(self.device)

        # in case the encoder gives us a list: expand each entry in place
        for idx, entry in enumerate(encoded):
            encoded[idx] = repeat(entry, '1 ... -> b ...', b=batch_size).to(self.device)
        return encoded

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        """Collect tensors for visual logging of one batch.

        Args:
            batch: batch handed to ``self.get_input``.
            N: upper bound on the number of samples that are logged.
            n_row: upper bound on the number of rows in the diffusion grid.
            sample: if true, log "samples" (and optionally "denoise_row" and
                "samples_x0_quantized").
            ddim_steps: number of DDIM steps; ``None`` selects the non-DDIM
                path of ``sample_log``.
            ddim_eta: passed to ``sample_log`` as ``eta``.
            return_keys: optional list of keys to restrict the result to.
            quantize_denoised: if true and the first stage is neither
                ``AutoencoderKL`` nor ``IdentityFirstStage``, additionally
                log "samples_x0_quantized".
            inpaint: if true, log "samples_inpainting", "mask" and
                "samples_outpainting" using a centered square mask.
            plot_denoise_rows: if true, log "denoise_row".
            plot_progressive_rows: if true, log "progressive_row".
            plot_diffusion_rows: if true, log "diffusion_row".
            unconditional_guidance_scale: values above 1.0 additionally log
                classifier-free-guidance samples.
            unconditional_guidance_label: null label used to build the
                unconditional conditioning.
            use_ema_scope: if true, sampling runs inside ``self.ema_scope``.
            **kwargs: ignored.

        Returns:
            dict mapping log names to tensors. If ``return_keys`` is given and
            shares at least one key with the log, only ``return_keys`` are
            returned; otherwise the complete log is returned.
        """
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, z_prev, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, self.first_stage_key_prev,
                                                   return_first_stage_outputs=True,
                                                   force_c_encode=True,
                                                   return_original_cond=True,
                                                   bs=N)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
                log["conditioning"] = xc
            elif self.cond_stage_key in ['class_label', "cls"]:
                try:
                    xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
                    log['conditioning'] = xc
                except KeyError:
                    # probably no "human_label" in batch
                    pass
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, (z_denoise_row, _) = self.sample_log(x_prev=z_prev, cond=c, batch_size=N, ddim=use_ddim,
                                                              ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                    self.first_stage_model, IdentityFirstStage):
                # also display when quantizing x0 while sampling
                with ema_scope("Plotting Quantized Denoised"):
                    # x_prev is a required argument of sample_log; omitting it raised a TypeError.
                    samples, (z_denoise_row, _) = self.sample_log(x_prev=z_prev, cond=c, batch_size=N,
                                                                  ddim=use_ddim,
                                                                  ddim_steps=ddim_steps, eta=ddim_eta,
                                                                  quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

        if unconditional_guidance_scale > 1.0:
            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            if self.model.conditioning_key == "crossattn-adm":
                uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
            with ema_scope("Sampling with classifier-free guidance"):
                # x_prev is a required argument of sample_log; omitting it raised a TypeError.
                samples_cfg, _ = self.sample_log(x_prev=z_prev, cond=c, batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        if inpaint:
            # make a simple center square
            h, w = z.shape[2], z.shape[3]
            mask = torch.ones(N, h, w).to(self.device)
            # zeros will be filled in
            mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
            mask = mask[:, None, ...]
            # NOTE(review): whether mask/x0 take effect depends on sample_log
            # forwarding its extra keyword arguments to the sampler - confirm.
            with ema_scope("Plotting Inpaint"):
                samples, _ = self.sample_log(x_prev=z_prev, cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_inpainting"] = x_samples
            log["mask"] = mask

            # outpaint
            mask = 1. - mask
            with ema_scope("Plotting Outpaint"):
                samples, _ = self.sample_log(x_prev=z_prev, cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            # NOTE(review): progressive_denoising must be provided elsewhere
            # (e.g. by a parent class) for this branch to work - confirm.
            with ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        """Create the AdamW optimizer and, if configured, a per-step LambdaLR schedule.

        The optimized parameters are those of ``self.model``, plus those of
        ``self.cond_stage_model`` when ``self.cond_stage_trainable`` is set,
        plus ``self.logvar`` when ``self.learn_logvar`` is set.

        Returns:
            The optimizer alone when ``self.use_scheduler`` is false, otherwise
            ``([optimizer], [scheduler_dict])``.
        """
        base_lr = self.learning_rate
        trainable = [*self.model.parameters()]
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            trainable.extend(self.cond_stage_model.parameters())
        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            trainable.append(self.logvar)

        optimizer = torch.optim.AdamW(trainable, lr=base_lr)
        if not self.use_scheduler:
            return optimizer

        assert 'target' in self.scheduler_config
        schedule_obj = instantiate_from_config(self.scheduler_config)

        print("Setting up LambdaLR scheduler...")
        # the schedule is stepped after every optimizer step
        lr_schedulers = [{
            'scheduler': LambdaLR(optimizer, lr_lambda=schedule_obj.schedule),
            'interval': 'step',
            'frequency': 1,
        }]
        return [optimizer], lr_schedulers

    @torch.no_grad()
    def to_rgb(self, x):
        """Project ``x`` to three channels with a fixed 1x1 convolution and rescale to [-1, 1].

        The projection weights are drawn randomly on first use and cached on
        the instance as ``self.colorize``, so later calls reuse the same
        projection. Rescaling uses the global minimum and maximum of the
        projected tensor.
        """
        features = x.float()
        if not hasattr(self, "colorize"):
            # one random weight per (output channel, input channel) pair
            self.colorize = torch.randn(3, features.shape[1], 1, 1).to(features)
        projected = nn.functional.conv2d(features, weight=self.colorize)
        lowest, highest = projected.min(), projected.max()
        return 2. * (projected - lowest) / (highest - lowest) - 1.


class DiffusionWrapper(pl.LightningModule):
    """Feed conditioning into the wrapped diffusion model according to ``conditioning_key``.

    Keys handled by ``forward``: ``None``, any key starting with
    'concat-video-mask' (optionally containing '1st' and/or 'action'),
    'concat-video', 'concat', 'crossattn', 'hybrid', 'hybrid-adm',
    'crossattn-adm', 'adm' and 'mcvd'.
    """

    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        # "sequential_crossattn" is consumed here and removed from the config
        # before the diffusion model is instantiated from it.
        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        self.diffusion_model = instantiate_from_config(diff_model_config)
        # Number of video frames of the wrapped model; None if it has no such attribute.
        self.T = getattr(self.diffusion_model, 'num_video_frames', None)
        print(f"Diffusion model: {type(self.diffusion_model).__name__}")
        self.conditioning_key = conditioning_key

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_mcvd: list = None, c_adm=None):
        """Run the wrapped model on ``x`` at timestep ``t`` with the configured conditioning.

        For the video keys, ``x`` is treated as frames stacked along the
        channel axis ('B (T C) H W'); frames are moved into the batch axis for
        the model call and stacked back afterwards.

        Raises:
            NotImplementedError: for an unrecognized ``conditioning_key``.
        """
        key = self.conditioning_key
        net = self.diffusion_model

        if key is None:
            return net(x, t)

        if key.startswith('concat-video-mask'):
            n_frames = self.T
            batch, stacked_ch, height, width = x.shape
            frame_ch = stacked_ch // n_frames
            # 1 marks frames that are given as conditioning, 0 the remaining ones
            given_mask = torch.zeros(batch, n_frames, height, width).to(x.device)
            first_cond = c_concat[0]
            batch, cond_ch, height, width = first_cond.shape
            n_cond_frames = cond_ch // frame_ch
            given_mask[:, :n_cond_frames] = 1.
            # conditioning frames, zero-padded to the full stacked-channel size of x
            padded_cond = torch.zeros_like(x)
            padded_cond[:, :cond_ch] = first_cond

            x = rearrange(x, 'B (T C) H W -> (B T) C H W', T=n_frames)
            padded_cond = rearrange(padded_cond, 'B (T C) H W -> (B T) C H W', T=n_frames)
            given_mask = rearrange(given_mask, 'B T H W -> (B T) 1 H W', T=n_frames)
            model_in = torch.cat([x, padded_cond, given_mask], dim=1)

            if '1st' in key:
                # extra channel flagging the frame right after the conditioning frames
                next_mask = torch.zeros(batch, n_frames, height, width).to(x.device)
                next_mask[:, n_cond_frames] = 1.
                next_mask = rearrange(next_mask, 'B T H W -> (B T) 1 H W', T=n_frames)
                model_in = torch.cat([model_in, next_mask], dim=1)

            if 'action' in key:
                # broadcast the per-frame action vector over the spatial grid
                action_map = rearrange(c_concat[1], 'B T C -> (B T) C')
                action_map = action_map.unsqueeze(-1).unsqueeze(-1)
                action_map = action_map.repeat(1, 1, height, width)
                model_in = torch.cat([model_in, action_map], dim=1)

            frame_t = torch.repeat_interleave(t, n_frames, dim=0)
            per_frame = net(model_in, frame_t)
            return rearrange(per_frame, '(B T) C H W -> B (T C) H W', T=n_frames)

        if key == 'concat-video':
            n_frames = self.T
            frames = rearrange(x, 'B (T C) H W -> (B T) C H W', T=n_frames)
            # the same conditioning is repeated for every frame of a sample
            frame_cond = torch.repeat_interleave(c_concat[0], n_frames, dim=0)
            model_in = torch.cat([frames, frame_cond], dim=1)
            frame_t = torch.repeat_interleave(t, n_frames, dim=0)
            per_frame = net(model_in, frame_t)
            return rearrange(per_frame, '(B T) C H W -> B (T C) H W', T=n_frames)

        if key == 'concat':
            return net(torch.cat([x] + c_concat, dim=1), t)

        if key == 'crossattn':
            if self.sequential_cross_attn:
                context = c_crossattn
            else:
                context = torch.cat(c_crossattn, 1)
            return net(x, t, context=context)

        if key == 'hybrid':
            model_in = torch.cat([x] + c_concat, dim=1)
            context = torch.cat(c_crossattn, 1)
            return net(model_in, t, context=context)

        if key == 'hybrid-adm':
            assert c_adm is not None
            model_in = torch.cat([x] + c_concat, dim=1)
            context = torch.cat(c_crossattn, 1)
            return net(model_in, t, context=context, y=c_adm)

        if key == 'crossattn-adm':
            assert c_adm is not None
            context = torch.cat(c_crossattn, 1)
            return net(x, t, context=context, y=c_adm)

        if key == 'adm':
            # the class embedding input arrives through the cross-attention slot
            return net(x, t, y=c_crossattn[0])

        if key == 'mcvd':   # only for tensor-based conditioning
            return net(x, t, cond=c_mcvd[0])

        raise NotImplementedError()


class DynamicalLDMCleanedWithEncoderCondition(DynamicalLatentDiffusion):
    """Latent diffusion variant whose conditioning is encoded by the first stage.

    Inputs, previous inputs and conditioning are all tensors with frames
    stacked along the channel axis; each frame is encoded/decoded separately
    by the first-stage model.

    Args:
        random_uncond: if true, each training step uses the unconditional
            conditioning with probability 0.1.
        x_channels: channels per frame in data space.
        z_channels: channels per frame in latent space.
        uc_with_prev: if true, the unconditional conditioning keeps the last
            ``z_channels`` channels of the encoded conditioning and replaces
            only the rest by noise.
    """

    def __init__(self, *args,
                 random_uncond=False,
                 x_channels=1,
                 z_channels=4,
                 uc_with_prev=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.x_channels = x_channels
        self.z_channels = z_channels
        self.random_uncond = random_uncond
        self.uc_with_prev = uc_with_prev

    def batched_encode(self, x):
        """Encode every frame of ``x`` and stack the latents along the channel axis.

        Args:
            x: tensor laid out as (b, t * x_channels, h, w).

        Returns:
            Detached latents with all per-frame latent channels stacked along
            dimension 1.
        """
        # x: (b, tc, h, w)
        t = x.shape[1] // self.x_channels
        encoder_posterior = self.encode_first_stage(rearrange(x, 'b (t c) h w -> (b t) c h w', t=t, c=self.x_channels))
        z = rearrange(self.get_first_stage_encoding(encoder_posterior).detach(), '(b c1) c2 h w -> b (c1 c2) h w', b=x.shape[0])
        return z

    def batched_decode(self, z):
        """Decode every frame of ``z`` and stack the results along the channel axis.

        Args:
            z: tensor laid out as (b, t * z_channels, h, w).

        Returns:
            Tensor laid out as (b, t * x_channels, h', w').
        """
        # z: (b, tc, h, w)
        t = z.shape[1] // self.z_channels
        x_rec = self.decode_first_stage(rearrange(z, 'b (t c) h w -> (b t) c h w', t=t, c=self.z_channels))
        x_rec = rearrange(x_rec, '(b t) c h w -> b (t c) h w', b=z.shape[0], t=t, c=self.x_channels)
        return x_rec

    def shared_step(self, batch, uncond=False, **kwargs):
        """Compute the loss for one batch.

        Args:
            batch: batch handed to ``get_input``.
            uncond: use the unconditional conditioning for this step. This
                value is overridden by a random draw when
                ``self.random_uncond`` is set.
            **kwargs: ignored.
        """
        x, x_prev, c = self.get_input(batch, self.first_stage_key, self.first_stage_key_prev)
        if self.random_uncond:
            # uncond with probability 0.1
            uncond = torch.rand(1) < 0.1
        if uncond:
            c = self.get_unconditional_conditioning(batch)
        loss = self(x, x_prev, c)
        return loss

    @torch.no_grad()
    def get_unconditional_conditioning(self, batch):
        """Return noise shaped like the encoded conditioning of ``batch``.

        With ``self.uc_with_prev`` the last ``z_channels`` channels of the
        encoded conditioning are kept and only the remaining channels are
        replaced by noise.
        """
        c = super(DynamicalLatentDiffusion, self).get_input(batch, self.cond_stage_key).to(self.device)
        c = self.batched_encode(c)
        # uc is noise with the same shape as c
        if self.uc_with_prev:
            c_non_prev, c_prev = c[:, :-self.z_channels], c[:, -self.z_channels:]
            uc = torch.cat([torch.randn_like(c_non_prev), c_prev], dim=1)
        else:
            uc = torch.randn_like(c)
        return uc

    @torch.no_grad()
    def get_input(self, batch, k, k_prev, log_mode=False):
        """Fetch and encode the target, the previous input and the conditioning.

        Args:
            batch: data batch.
            k: batch key of the target.
            k_prev: batch key of the previous input.
            log_mode: if true, the data-space tensors and their first-stage
                reconstructions are appended to the result.

        Returns:
            ``[z, z_prev, z_c]``, extended by
            ``[x, x_prev, x_rec, x_prev_rec, x_c, x_c_rec]`` in log mode.
            ``z_c``, ``x_c`` and ``x_c_rec`` are ``None`` when the model has no
            conditioning key or null conditioning is forced.
        """
        x = super(DynamicalLatentDiffusion, self).get_input(batch, k).to(self.device)
        x_prev = super(DynamicalLatentDiffusion, self).get_input(batch, k_prev).to(self.device)

        z, z_prev = self.batched_encode(x), self.batched_encode(x_prev)

        if self.model.conditioning_key is not None and not self.force_null_conditioning:
            x_c = super(DynamicalLatentDiffusion, self).get_input(batch, self.cond_stage_key).to(self.device)
            z_c = self.batched_encode(x_c)
        else:
            x_c, z_c = None, None

        out = [z, z_prev, z_c]

        if log_mode:
            x_rec = self.batched_decode(z)
            x_prev_rec = self.batched_decode(z_prev)
            # Without conditioning there is nothing to decode; decoding None would crash.
            x_c_rec = self.batched_decode(z_c) if z_c is not None else None
            out.extend([x, x_prev, x_rec, x_prev_rec, x_c, x_c_rec])
        return out

    @torch.no_grad()
    def log_images(self, batch, ddim_steps=50, ddim_eta=0., use_ema_scope=True, **kwargs):
        """Collect inputs, reconstructions and samples for logging.

        Args:
            batch: data batch.
            ddim_steps: number of DDIM steps; ``None`` selects the non-DDIM
                path of ``sample_log``.
            ddim_eta: passed to ``sample_log`` as ``eta``.
            use_ema_scope: if true, sampling runs inside ``self.ema_scope``.
            **kwargs: ignored.

        Returns:
            dict with "inputs", "prev", "inputs_rec", "prev_rec", "samples",
            "z_intermediates" and "z_pred_x0"; "cond" and "cond_rec" are
            included only when conditioning is available.
        """
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, z_prev, z_c, x, x_prev, x_rec, x_prev_rec, x_c, x_c_rec = self.get_input(
            batch, self.first_stage_key, self.first_stage_key_prev, log_mode=True)

        log["inputs"] = x
        log["prev"] = x_prev
        log["inputs_rec"] = x_rec
        log["prev_rec"] = x_prev_rec
        if x_c is not None:
            # only log conditioning images when there is conditioning at all
            log["cond"] = x_c
            log["cond_rec"] = x_c_rec

        # get denoise row
        with ema_scope("Sampling"):
            samples, (intermediates, pred_x0) = self.sample_log(x_prev=z_prev, cond=z_c, batch_size=z.shape[0],
                                                                ddim=use_ddim,
                                                                ddim_steps=ddim_steps, eta=ddim_eta)
        x_samples = self.batched_decode(samples)
        log["samples"] = x_samples
        log["z_intermediates"] = intermediates
        log["z_pred_x0"] = pred_x0

        return log