import torch
import torch.nn.functional as F
import numpy as np
import gc

from diffusers import StableDiffusionPipeline
# from diffusers.utils import randn_tensor
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput


class EstimationErrPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline instrumented to study per-step estimation error.

    Entry points:

    * ``__call__`` -- classifier-free-guided sampling that can optionally record,
      for every denoising step, the visited latents and several norm statistics.
    * ``mit_cfg_gen`` -- sampling that, for a limited number of steps, jumps from
      the current latents back to the timestep used ``skip_size`` steps earlier,
      measures the MSE against the latents actually visited there and, when it
      exceeds ``threshold``, corrects the noise prediction with the gradient of
      that MSE.
    * ``mit_grad_prune_gen`` -- accumulates the same MSE over several steps,
      back-propagates it into the UNet and returns the gradients captured at the
      down-block ``to_q`` / ``to_k`` / ``to_v`` modules.

    NOTE(review): ``self.scheduler.step`` is called as
    ``step(model_output, timestep, target_timestep, sample)``, i.e. with an
    explicit target timestep, so this class presumably expects a custom
    DDIM-like scheduler -- confirm against the scheduler in use.
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
        requires_safety_checker: bool = True,
    ):
        # Every component is forwarded unchanged to the stock pipeline.
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

    # ------------------------------------------------------------------ helpers

    def _setup_sampling(
        self,
        prompt,
        height,
        width,
        num_inference_steps,
        guidance_scale,
        negative_prompt,
        num_images_per_prompt,
        eta,
        generator,
        latents,
        prompt_embeds,
        negative_prompt_embeds,
    ):
        """Run the setup shared by every sampler (steps 0-6 of the stock pipeline).

        ``check_inputs`` is not called, so invalid arguments surface as errors
        further down.

        Returns:
            ``(prompt_embeds, timesteps, latents, extra_step_kwargs, device,
            do_classifier_free_guidance, latent_args)``. ``latent_args`` holds the
            leading positional arguments ``(batch, channels, height, width)`` of
            ``prepare_latents`` so callers can draw fresh noise of the same shape.
        """
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # `guidance_scale` is analogous to the guidance weight `w` of eq. (2) of the
        # Imagen paper (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        latent_args = (
            batch_size * num_images_per_prompt,
            self.unet.config.in_channels,
            height,
            width,
        )
        latents = self.prepare_latents(
            *latent_args, prompt_embeds.dtype, device, generator, latents
        )

        # TODO: Logic should ideally just be moved out of the pipeline.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        return (
            prompt_embeds,
            timesteps,
            latents,
            extra_step_kwargs,
            device,
            do_classifier_free_guidance,
            latent_args,
        )

    def _guided_noise_pred(
        self,
        latent_model_input,
        t,
        prompt_embeds,
        guidance_scale,
        do_classifier_free_guidance,
        cross_attention_kwargs,
    ):
        """Scale the model input, run the UNet once and apply classifier-free guidance.

        When guidance is enabled, ``latent_model_input`` must already be duplicated
        along the batch axis. The first half of the UNet output is treated as the
        unconditional prediction and the second half as the text-conditioned one.

        Returns:
            ``(noise_pred, noise_pred_uncond, noise_pred_text)``. ``noise_pred`` is the
            guided prediction. The other two are the raw halves, and both are ``None``
            when guidance is disabled.
        """
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]
        if not do_classifier_free_guidance:
            return noise_pred, None, None

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        return guided, noise_pred_uncond, noise_pred_text

    def _advance_progress(
        self,
        progress_bar,
        i,
        t,
        num_timesteps,
        num_warmup_steps,
        callback,
        callback_steps,
        latents,
    ):
        """Tick the progress bar and fire ``callback`` with the same rules as the stock pipeline."""
        if i == num_timesteps - 1 or (
            (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
        ):
            progress_bar.update()
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

    def _decode_to_output(self, latents, device, dtype, output_type):
        """Decode final latents, run the safety checker, post-process, then offload.

        With ``output_type == "latent"`` the latents are not decoded, no safety
        check runs, and they are handed to ``image_processor.postprocess`` as-is.

        Returns:
            ``(image, has_nsfw_concept)``.
        """
        if output_type != "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image.detach(), output_type=output_type, do_denormalize=do_denormalize
        )

        # Offload last model to CPU
        if getattr(self, "final_offload_hook", None) is not None:
            self.final_offload_hook.offload()

        return image, has_nsfw_concept

    # ------------------------------------------------------------ entry points

    @torch.no_grad()
    def __call__(
        self,
        prompt=None,
        prompt2=None,
        height=None,
        width=None,
        num_inference_steps=50,
        guidance_scale=7.5,
        negative_prompt=None,
        num_images_per_prompt=1,
        eta=0.0,
        generator=None,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        output_type="pil",
        return_dict=True,
        callback=None,
        callback_steps=1,
        cross_attention_kwargs=None,
        track_loss=False,
        skip_size=1,
        lp=2,
        prompt2_start_step=50,
    ):
        """Generate images, optionally recording per-step estimation statistics.

        Args:
            prompt2: Optional second prompt. From step index ``prompt2_start_step``
                onwards its embeddings replace those of ``prompt``.
            track_loss: When true, record latents and the statistics listed under
                *Returns*.
            skip_size: How many steps back the "estimated" latents jump. Only used
                when ``track_loss`` is true.
            lp: Order ``p`` of the vector norm used for every statistic.
            prompt2_start_step: First (0-based) step that uses ``prompt2``. The
                default of 50 matches the previously hard-coded value.
            Remaining arguments follow ``StableDiffusionPipeline.__call__``.

        Returns:
            ``(image, has_nsfw_concept)`` when ``return_dict`` is false. Any tracking
            statistics are discarded in this case.

            Otherwise a ``StableDiffusionPipelineOutput``. When ``track_loss`` is true
            it is instead a tuple ``(output, track_stats)``, where ``track_stats``
            holds the CPU latent lists plus, per batch sample:

            * ``loss_res`` -- ``||x_after_i - x_before_i||`` for every step;
            * ``loss_est`` -- distance between the jump from step ``i`` back to
              ``timesteps[i - skip_size]`` and the latents fed to step
              ``i - skip_size``;
            * ``loss_ori`` -- distance between predicted clean latents
              ``skip_size`` steps apart;
            * ``loss_cond`` -- distance between the clean-latent predictions of
              the text and unconditional branches (guidance only).
        """
        (
            prompt_embeds,
            timesteps,
            latents,
            extra_step_kwargs,
            device,
            do_classifier_free_guidance,
            _,
        ) = self._setup_sampling(
            prompt,
            height,
            width,
            num_inference_steps,
            guidance_scale,
            negative_prompt,
            num_images_per_prompt,
            eta,
            generator,
            latents,
            prompt_embeds,
            negative_prompt_embeds,
        )

        prompt_embeds2 = None
        if prompt2 is not None:
            prompt_embeds2 = self._encode_prompt(
                prompt2,
                device,
                num_images_per_prompt,
                do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=None,
                negative_prompt_embeds=negative_prompt_embeds,
            )

        num_samples = len(latents)
        num_timesteps = len(timesteps)

        if track_loss:
            # prev_latents[i] is the input of step i; post_latents[i] is its output.
            prev_latents = [latents.clone().cpu()]
            post_latents = []
            est_latents = []
            ori_latents = []
            loss_res = [[] for _ in range(num_samples)]
            loss_est = [[] for _ in range(num_samples)]
            loss_ori = [[] for _ in range(num_samples)]
            loss_cond = [[] for _ in range(num_samples)]

        step_embeds = prompt_embeds
        num_warmup_steps = num_timesteps - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if prompt_embeds2 is not None and i >= prompt2_start_step:
                    step_embeds = prompt_embeds2

                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                noise_pred, noise_pred_uncond, noise_pred_text = self._guided_noise_pred(
                    latent_model_input,
                    t,
                    step_embeds,
                    guidance_scale,
                    do_classifier_free_guidance,
                    cross_attention_kwargs,
                )

                if track_loss and do_classifier_free_guidance:
                    # `step(..., t, t, ...)` is used only for its clean-latent estimate.
                    pred_text_ori = self.scheduler.step(
                        noise_pred_text, t, t, latents, **extra_step_kwargs
                    ).pred_original_sample
                    pred_uncond_ori = self.scheduler.step(
                        noise_pred_uncond, t, t, latents, **extra_step_kwargs
                    ).pred_original_sample
                    for k in range(num_samples):
                        loss_cond[k].append(
                            (pred_text_ori[k] - pred_uncond_ori[k]).norm(p=lp).item()
                        )
                    del pred_text_ori, pred_uncond_ori

                # compute the previous noisy sample x_t -> x_t-1
                # (-1 presumably tells the scheduler this is the final step).
                next_ts = timesteps[i + 1] if i < num_timesteps - 1 else -1
                step_output = self.scheduler.step(
                    noise_pred, t, next_ts, latents, **extra_step_kwargs
                )
                next_latents = step_output.prev_sample

                self._advance_progress(
                    progress_bar,
                    i,
                    t,
                    num_timesteps,
                    num_warmup_steps,
                    callback,
                    callback_steps,
                    next_latents,
                )

                if track_loss:
                    if i >= skip_size:
                        ddim_step_latents = self.scheduler.step(
                            noise_pred,
                            t,
                            timesteps[i - skip_size],
                            latents,
                            **extra_step_kwargs,
                        ).prev_sample
                        est_latents.append(ddim_step_latents.detach().clone().cpu())

                    # Reuse the clean-latent estimate of the step above rather than
                    # calling `scheduler.step` again with identical arguments.
                    ori_latents.append(
                        step_output.pred_original_sample.detach().clone().cpu()
                    )
                    post_latents.append(next_latents.detach().clone().cpu())
                    if i != num_timesteps - 1:
                        prev_latents.append(next_latents.detach().clone().cpu())

                latents = next_latents

        if track_loss:
            for post, prev in zip(post_latents, prev_latents):
                for k in range(num_samples):
                    loss_res[k].append((post[k] - prev[k]).norm(p=lp).item())
            # est_latents[j] comes from step j + skip_size and targets the input of step j.
            for est, prev in zip(est_latents, prev_latents):
                for k in range(num_samples):
                    loss_est[k].append((est[k] - prev[k]).norm(p=lp).item())
            # Pairs (ori_latents[i], ori_latents[i - skip_size]).
            for later, earlier in zip(ori_latents[skip_size:], ori_latents):
                for k in range(num_samples):
                    loss_ori[k].append((later[k] - earlier[k]).norm(p=lp).item())

        image, has_nsfw_concept = self._decode_to_output(
            latents, device, prompt_embeds.dtype, output_type
        )

        if not return_dict:
            return (image, has_nsfw_concept)

        output = StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )
        if not track_loss:
            return output

        track_stats = {
            "post_latents": post_latents,
            "prev_latents": prev_latents,
            "est_latents": est_latents,
            "ori_latents": ori_latents,
            "loss_res": loss_res,
            "loss_est": loss_est,
            "loss_ori": loss_ori,
            "loss_cond": loss_cond,
        }
        return (output, track_stats)

    @torch.no_grad()
    def mit_cfg_gen(
        self,
        prompt=None,
        skip_size=1,
        max_guided_iters=1,
        max_cg_iters=1,
        cfg_lr=1,
        cg_lr=1,
        max_cfg=10,
        threshold=0.5,
        num_inference_steps=50,
        guidance_scale=7.5,
        height=None,
        width=None,
        negative_prompt=None,
        num_images_per_prompt=1,
        eta=0.0,
        generator=None,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        output_type="pil",
        return_dict=True,
        callback=None,
        callback_steps=1,
        cross_attention_kwargs=None,
    ):
        """Sample with estimation-error-driven correction of the noise prediction.

        Guided steps are the ``max_guided_iters`` consecutive steps starting at
        ``skip_size``. On each one, the method jumps from the current latents back
        to ``timesteps[i - skip_size]`` and takes the MSE against the latents
        visited there. When that MSE is ``>= threshold``, its gradient ``g`` with
        respect to the current latents is used in two ways:

        * With guidance enabled, the effective guidance scale is reduced by
          ``clamp(|mean(cfg_lr * sqrt(1 - alpha_bar_t) * g)|, max=max_cfg)``.
        * While ``i < max_cg_iters``, ``cg_lr * sqrt(1 - alpha_bar_t) * g`` is
          subtracted from the noise prediction.

        Returns:
            ``(image, has_nsfw_concept)`` if ``return_dict`` is false, else a
            ``StableDiffusionPipelineOutput``.
        """
        (
            prompt_embeds,
            timesteps,
            latents,
            extra_step_kwargs,
            device,
            do_classifier_free_guidance,
            _,
        ) = self._setup_sampling(
            prompt,
            height,
            width,
            num_inference_steps,
            guidance_scale,
            negative_prompt,
            num_images_per_prompt,
            eta,
            generator,
            latents,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # `.cpu()` lets `.numpy()` work even if the table lives on an accelerator.
        alphas_cumprod = self.scheduler.alphas_cumprod.cpu().numpy()
        # FIFO of visited latents; its head is the target of the next backward jump.
        prev_latents = [latents.clone().detach().cpu()]
        guided_iters = 0
        num_timesteps = len(timesteps)
        num_warmup_steps = num_timesteps - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if i >= skip_size and guided_iters < max_guided_iters:
                    with torch.enable_grad():
                        # Fresh leaf, so neither the caller's tensor nor the
                        # scheduler's output is mutated by `requires_grad_`.
                        latents = latents.detach().requires_grad_(True)
                        uncond_latents = latents.clone().detach()
                        # NOTE(review): the grad-tracking latents occupy the first
                        # half, which `_guided_noise_pred` treats as the
                        # unconditional branch, so the gradient flows only through
                        # that branch -- confirm this is intended.
                        latent_model_input = (
                            torch.cat([latents, uncond_latents])
                            if do_classifier_free_guidance
                            else latents
                        )

                        self.unet.zero_grad()
                        noise_pred, noise_pred_uncond, noise_pred_text = (
                            self._guided_noise_pred(
                                latent_model_input,
                                t,
                                prompt_embeds,
                                guidance_scale,
                                do_classifier_free_guidance,
                                cross_attention_kwargs,
                            )
                        )

                        ddim_step_latents = self.scheduler.step(
                            noise_pred,
                            t,
                            timesteps[i - skip_size],
                            latents,
                            **extra_step_kwargs,
                        ).prev_sample
                        est_loss = F.mse_loss(
                            ddim_step_latents,
                            prev_latents.pop(0).to(device),
                            reduction="mean",
                        )

                        if est_loss >= threshold:
                            (gradient,) = torch.autograd.grad(est_loss, latents)
                            noise_scale = (1 - alphas_cumprod[int(t)]) ** 0.5
                            # The CFG correction needs the guidance direction, which
                            # only exists when guidance is enabled.
                            if do_classifier_free_guidance:
                                added_cfg = (cfg_lr * noise_scale * gradient).mean()
                                clipped_added_cfg = torch.clamp(
                                    abs(added_cfg), max=max_cfg
                                )
                                noise_pred -= clipped_added_cfg * (
                                    noise_pred_text - noise_pred_uncond
                                )
                            # NOTE(review): `i` is the absolute step index, so with
                            # skip_size >= max_cg_iters (e.g. the defaults) this never
                            # runs; `guided_iters < max_cg_iters` may have been meant.
                            if i < max_cg_iters:
                                noise_pred -= cg_lr * noise_scale * gradient

                    guided_iters += 1
                else:
                    latent_model_input = (
                        torch.cat([latents] * 2)
                        if do_classifier_free_guidance
                        else latents
                    )
                    noise_pred, _, _ = self._guided_noise_pred(
                        latent_model_input,
                        t,
                        prompt_embeds,
                        guidance_scale,
                        do_classifier_free_guidance,
                        cross_attention_kwargs,
                    )

                # compute the previous noisy sample x_t -> x_t-1
                next_ts = timesteps[i + 1] if i < num_timesteps - 1 else -1
                latents = self.scheduler.step(
                    noise_pred, t, next_ts, latents, **extra_step_kwargs
                ).prev_sample
                if guided_iters < max_guided_iters:
                    prev_latents.append(latents.detach().cpu())

                self._advance_progress(
                    progress_bar,
                    i,
                    t,
                    num_timesteps,
                    num_warmup_steps,
                    callback,
                    callback_steps,
                    latents,
                )

        image, has_nsfw_concept = self._decode_to_output(
            latents, device, prompt_embeds.dtype, output_type
        )

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )

    @torch.no_grad()
    def mit_grad_prune_gen(
        self,
        prompt=None,
        skip_size=1,
        cal_gradients_iters=5,
        prune_rat=1e-6,
        threshold=0.5,
        num_inference_steps=50,
        guidance_scale=7.5,
        height=None,
        width=None,
        negative_prompt=None,
        num_images_per_prompt=1,
        eta=0.0,
        generator=None,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        output_type="pil",
        return_dict=True,
        callback=None,
        callback_steps=1,
        cross_attention_kwargs=None,
    ):
        """Back-propagate the estimation error into the UNet and return attention gradients.

        The error is accumulated over ``cal_gradients_iters`` consecutive steps
        starting at ``skip_size``. On each step it is the MSE between a jump back
        to ``timesteps[i - skip_size]`` and the latents visited there. Its mean is
        back-propagated once.

        Full backward hooks on every UNet module whose name contains
        ``down_blocks`` and one of ``to_q`` / ``to_v`` / ``to_k`` record
        ``grad_input[0]`` and ``grad_output[0]``. Each hook overwrites its dict
        entry, so a module that ran in several forward passes keeps only the
        gradient from the last hook invocation. The hooks are removed before
        returning.

        ``prune_rat`` and ``threshold`` are accepted but currently unused.

        Returns:
            ``(all_grads_input, all_grads_output)``, dicts keyed by module name,
            once the calibration steps have run.

            If the schedule is too short for that, the method instead samples
            images from freshly drawn noise. It then returns ``(image,
            has_nsfw_concept)`` or a ``StableDiffusionPipelineOutput``, depending on
            ``return_dict``.

        Raises:
            ValueError: if ``cal_gradients_iters`` is not positive.
        """
        if cal_gradients_iters <= 0:
            raise ValueError(
                f"cal_gradients_iters must be positive, got {cal_gradients_iters}"
            )

        (
            prompt_embeds,
            timesteps,
            latents,
            extra_step_kwargs,
            device,
            do_classifier_free_guidance,
            latent_args,
        ) = self._setup_sampling(
            prompt,
            height,
            width,
            num_inference_steps,
            guidance_scale,
            negative_prompt,
            num_images_per_prompt,
            eta,
            generator,
            latents,
            prompt_embeds,
            negative_prompt_embeds,
        )

        all_grads_input = {}
        all_grads_output = {}

        def get_grad(name):
            def get_output_hook(module, grad_input, grad_output):
                all_grads_input[name] = grad_input[0]
                all_grads_output[name] = grad_output[0]

            return get_output_hook

        # Keep the handles: without removal every call would stack another set of
        # hooks on the shared UNet, each pinning its own dicts of tensors.
        hook_handles = [
            module.register_full_backward_hook(get_grad(name))
            for name, module in self.unet.named_modules()
            if any(s in name for s in ["to_q", "to_v", "to_k"]) and "down_blocks" in name
        ]

        num_timesteps = len(timesteps)
        num_warmup_steps = num_timesteps - num_inference_steps * self.scheduler.order
        try:
            # FIFO of visited latents; its head is the target of the next backward jump.
            prev_latents = [latents.clone().detach().cpu()]
            guided_iters = 0
            est_loss = torch.zeros(1, device=device)
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    latent_model_input = (
                        torch.cat([latents] * 2)
                        if do_classifier_free_guidance
                        else latents
                    )

                    if i < skip_size:
                        # Warm-up steps only produce reference latents; their
                        # predictions never reach `est_loss`, so no graph is needed.
                        noise_pred, _, _ = self._guided_noise_pred(
                            latent_model_input,
                            t,
                            prompt_embeds,
                            guidance_scale,
                            do_classifier_free_guidance,
                            cross_attention_kwargs,
                        )
                    else:
                        with torch.enable_grad():
                            self.unet.zero_grad()
                            noise_pred, _, _ = self._guided_noise_pred(
                                latent_model_input,
                                t,
                                prompt_embeds,
                                guidance_scale,
                                do_classifier_free_guidance,
                                cross_attention_kwargs,
                            )
                            ddim_step_latents = self.scheduler.step(
                                noise_pred,
                                t,
                                timesteps[i - skip_size],
                                latents,
                                **extra_step_kwargs,
                            ).prev_sample
                            est_loss = est_loss + F.mse_loss(
                                ddim_step_latents,
                                prev_latents.pop(0).to(device),
                                reduction="mean",
                            )
                        guided_iters += 1

                    # compute the previous noisy sample x_t -> x_t-1
                    next_ts = timesteps[i + 1] if i < num_timesteps - 1 else -1
                    latents = self.scheduler.step(
                        noise_pred, t, next_ts, latents, **extra_step_kwargs
                    ).prev_sample
                    if guided_iters < cal_gradients_iters:
                        prev_latents.append(latents.detach().cpu())

                    self._advance_progress(
                        progress_bar,
                        i,
                        t,
                        num_timesteps,
                        num_warmup_steps,
                        callback,
                        callback_steps,
                        latents,
                    )

                    # Checked at the end of the iteration so that the final
                    # calibration step is honoured even when it is the last timestep.
                    if guided_iters == cal_gradients_iters:
                        break

            if guided_iters == cal_gradients_iters:
                # Average inside `enable_grad` so autograd records the division;
                # under the method's `no_grad` it would not scale the gradients.
                with torch.enable_grad():
                    (est_loss / guided_iters).backward()
                print(f"len grad: {len(all_grads_input)}")
                return all_grads_input, all_grads_output
        finally:
            for handle in hook_handles:
                handle.remove()

        # Not enough denoising steps to finish the calibration: fall back to
        # ordinary sampling from freshly drawn noise.
        latents = self.prepare_latents(
            *latent_args, prompt_embeds.dtype, device, generator, latents=None
        )

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                noise_pred, _, _ = self._guided_noise_pred(
                    latent_model_input,
                    t,
                    prompt_embeds,
                    guidance_scale,
                    do_classifier_free_guidance,
                    cross_attention_kwargs,
                )

                # compute the previous noisy sample x_t -> x_t-1
                next_ts = timesteps[i + 1] if i < num_timesteps - 1 else -1
                latents = self.scheduler.step(
                    noise_pred, t, next_ts, latents, **extra_step_kwargs
                ).prev_sample

                self._advance_progress(
                    progress_bar,
                    i,
                    t,
                    num_timesteps,
                    num_warmup_steps,
                    callback,
                    callback_steps,
                    latents,
                )

        image, has_nsfw_concept = self._decode_to_output(
            latents, device, prompt_embeds.dtype, output_type
        )

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )
