"""
This code contains minor edits from the original code at
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
and
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py
to support sampling from the middle of the diffusion process with start_t and
start_image arguments.
"""

import torch as th
from improved_diffusion.respace import SpacedDiffusion
from improved_diffusion.respace import space_timesteps
from improved_diffusion.gaussian_diffusion import _extract_into_tensor
from improved_diffusion import gaussian_diffusion as gd


def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
):
    """
    Build a SpacedDiffusion whose samplers can start part-way through the reverse process.

    The configuration logic matches improved_diffusion's
    ``script_util.create_gaussian_diffusion``. The difference is that the returned
    object's ``p_sample_loop*`` and ``ddim_sample_loop*`` methods also accept
    ``start_t`` and ``start_image``.

    :param steps: number of steps in the base beta schedule.
    :param learn_sigma: if True, use ``ModelVarType.LEARNED_RANGE``
        (``sigma_small`` is then ignored).
    :param sigma_small: when variances are not learned, use ``FIXED_SMALL``
        instead of ``FIXED_LARGE``.
    :param noise_schedule: schedule name passed to ``gd.get_named_beta_schedule``.
    :param use_kl: if True, use ``LossType.RESCALED_KL``.
    :param predict_xstart: if True, use ``ModelMeanType.START_X``;
        otherwise use ``ModelMeanType.EPSILON``.
    :param rescale_timesteps: forwarded unchanged to ``SpacedDiffusion``.
    :param rescale_learned_sigmas: if True (and ``use_kl`` is False), use
        ``LossType.RESCALED_MSE``; otherwise use ``LossType.MSE``.
    :param timestep_respacing: spec passed to ``space_timesteps``. An empty
        value keeps all ``steps`` timesteps.
    :return: a ``SkippedSpacedDiffusion`` instance (a ``SpacedDiffusion`` subclass).
    """
    # Putting the class inside the function to avoid raising errors when the optional dependency improved_diffusion is
    # not installed. Waiting for generalimport to support this subclassing use case:
    # https://github.com/ManderaGeneral/generalimport/pull/28
    class SkippedSpacedDiffusion(SpacedDiffusion):
        """
        SpacedDiffusion whose sampling loops can skip the first ``start_t`` steps
        and optionally continue from a forward-noised ``start_image``.
        """

        def _prepare_sampling(self, model, shape, *, noise, device, progress, start_t, start_image):
            """
            Shared setup for ``p_sample_loop_progressive`` and ``ddim_sample_loop_progressive``.

            :return: ``(img, indices, device)``.
                - ``img`` is the initial sample.
                - ``indices`` iterates the remaining timesteps from noisiest to
                  cleanest, wrapped in tqdm if ``progress`` is set.
                - ``device`` is the resolved device.
            :raises ValueError: if ``start_t`` leaves no timesteps to run.
            """
            if device is None:
                device = next(model.parameters()).device
            assert isinstance(shape, (tuple, list))
            img = noise if noise is not None else th.randn(*shape, device=device)

            # Timesteps run from num_timesteps - 1 down to 0; drop the first start_t of them.
            indices = list(range(self.num_timesteps))[::-1][start_t:]
            if not indices:
                raise ValueError(
                    f"start_t={start_t} leaves none of the {self.num_timesteps} timesteps to sample."
                )

            if start_image is not None:
                # Forward-diffuse start_image to the first remaining timestep, using the
                # initial noise (given or freshly drawn) as the Gaussian noise term.
                t_batch = th.tensor([indices[0]] * img.shape[0], device=device)
                img = self.q_sample(start_image.to(device), t=t_batch, noise=img)

            if progress:
                # Lazy import so that we don't depend on tqdm.
                from tqdm.auto import tqdm

                indices = tqdm(indices)
            return img, indices, device

        def p_sample_loop(
            self,
            model,
            shape,
            noise=None,
            clip_denoised=True,
            denoised_fn=None,
            model_kwargs=None,
            device=None,
            progress=False,
            start_t=0,
            start_image=None,
        ):
            """
            Generate samples from the model.

            :param model: the model module.
            :param shape: the shape of the samples, (N, C, H, W).
            :param noise: if specified, the initial noise to start sampling from.
                        Should be of the same shape as `shape`.
            :param clip_denoised: if True, clip x_start predictions to [-1, 1].
            :param denoised_fn: if not None, a function which applies to the
                x_start prediction before it is used to sample.
            :param model_kwargs: if not None, a dict of extra keyword arguments to
                pass to the model. This can be used for conditioning.
            :param device: if specified, the device to create the samples on.
                        If not specified, use a model parameter's device.
            :param progress: if True, show a tqdm progress bar.
            :param start_t: number of leading (noisiest) timesteps to skip. Sampling
                starts at timestep ``num_timesteps - 1 - start_t``. Python slice
                semantics apply to the schedule.
            :param start_image: if given, it is noised with ``q_sample`` to the first
                remaining timestep and sampling continues from there. If None,
                sampling starts from the raw noise at that timestep.
            :return: a non-differentiable batch of samples.
            :raises ValueError: if ``start_t`` leaves no timesteps to run.
            """
            final = None
            for sample in self.p_sample_loop_progressive(
                model,
                shape,
                noise=noise,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
                device=device,
                progress=progress,
                start_t=start_t,
                start_image=start_image,
            ):
                final = sample
            return final["sample"]

        def p_sample_loop_progressive(
            self,
            model,
            shape,
            noise=None,
            clip_denoised=True,
            denoised_fn=None,
            model_kwargs=None,
            device=None,
            progress=False,
            start_t=0,
            start_image=None,
        ):
            """
            Generate samples from the model and yield intermediate samples from
            each timestep of diffusion.

            Arguments are the same as p_sample_loop().
            Returns a generator over dicts, where each dict is the return value of
            p_sample().
            """
            img, indices, device = self._prepare_sampling(
                model,
                shape,
                noise=noise,
                device=device,
                progress=progress,
                start_t=start_t,
                start_image=start_image,
            )
            for i in indices:
                t = th.tensor([i] * img.shape[0], device=device)
                with th.no_grad():
                    out = self.p_sample(
                        model,
                        img,
                        t,
                        clip_denoised=clip_denoised,
                        denoised_fn=denoised_fn,
                        model_kwargs=model_kwargs,
                    )
                # Yield outside no_grad. Otherwise grad mode would stay disabled in the
                # consumer between steps, and would be reset whenever an abandoned
                # generator eventually gets closed.
                yield out
                img = out["sample"]

        def ddim_sample(
            self,
            model,
            x,
            t,
            clip_denoised=True,
            denoised_fn=None,
            model_kwargs=None,
            eta=0.0,
        ):
            """
            Sample x_{t-1} from the model using DDIM.

            Same usage as p_sample(). ``eta`` scales the injected noise; ``eta=0``
            gives a deterministic update.
            """
            out = self.p_mean_variance(
                model,
                x,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
            )
            # Usually our model outputs epsilon, but we re-derive it
            # in case we used x_start or x_prev prediction.
            eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
            alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
            alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
            sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
            # Equation 12 of the DDIM paper. Noise is drawn even when eta == 0 so the
            # RNG stream stays identical regardless of eta.
            noise = th.randn_like(x)
            mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))  # no noise when t == 0
            sample = mean_pred + nonzero_mask * sigma * noise
            return {"sample": sample, "pred_xstart": out["pred_xstart"]}

        def ddim_reverse_sample(
            self,
            model,
            x,
            t,
            clip_denoised=True,
            denoised_fn=None,
            model_kwargs=None,
            eta=0.0,
        ):
            """
            Sample x_{t+1} from the model using DDIM reverse ODE.

            Only the deterministic path is supported, so ``eta`` must be 0.
            """
            assert eta == 0.0, "Reverse ODE only for deterministic path"
            out = self.p_mean_variance(
                model,
                x,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
            )
            # Usually our model outputs epsilon, but we re-derive it
            # in case we used x_start or x_prev prediction.
            eps = (
                _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
            ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
            alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

            # Equation 12. reversed
            mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

            return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

        def ddim_sample_loop(
            self,
            model,
            shape,
            noise=None,
            clip_denoised=True,
            denoised_fn=None,
            model_kwargs=None,
            device=None,
            progress=False,
            eta=0.0,
            start_t=0,
            start_image=None,
        ):
            """
            Generate samples from the model using DDIM.

            Same usage as p_sample_loop(), including ``start_t`` and
            ``start_image``, plus ``eta`` forwarded to ddim_sample().
            """
            final = None
            for sample in self.ddim_sample_loop_progressive(
                model,
                shape,
                noise=noise,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
                device=device,
                progress=progress,
                eta=eta,
                start_t=start_t,
                start_image=start_image,
            ):
                final = sample
            return final["sample"]

        def ddim_sample_loop_progressive(
            self,
            model,
            shape,
            noise=None,
            clip_denoised=True,
            denoised_fn=None,
            model_kwargs=None,
            device=None,
            progress=False,
            eta=0.0,
            start_t=0,
            start_image=None,
        ):
            """
            Use DDIM to sample from the model and yield intermediate samples from
            each timestep of DDIM.

            Same usage as p_sample_loop_progressive().
            """
            img, indices, device = self._prepare_sampling(
                model,
                shape,
                noise=noise,
                device=device,
                progress=progress,
                start_t=start_t,
                start_image=start_image,
            )
            for i in indices:
                t = th.tensor([i] * img.shape[0], device=device)
                with th.no_grad():
                    out = self.ddim_sample(
                        model,
                        img,
                        t,
                        clip_denoised=clip_denoised,
                        denoised_fn=denoised_fn,
                        model_kwargs=model_kwargs,
                        eta=eta,
                    )
                # Yield outside no_grad so the consumer's grad mode is not affected.
                yield out
                img = out["sample"]

    betas = gd.get_named_beta_schedule(noise_schedule, steps)

    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE

    if learn_sigma:
        model_var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        model_var_type = gd.ModelVarType.FIXED_SMALL
    else:
        model_var_type = gd.ModelVarType.FIXED_LARGE

    if not timestep_respacing:
        timestep_respacing = [steps]

    return SkippedSpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON,
        model_var_type=model_var_type,
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
    )
