import time
import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.utils import scale_lora_layers, USE_PEFT_BACKEND, logging, unscale_lora_layers, replace_example_docstring, deprecate
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps, rescale_noise_cfg
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from torch.func import jvp 
from contextlib import contextmanager
torch.set_printoptions(precision=5, sci_mode=False)

def _jvp_mode(flag: bool, device: torch.device):
    """
    Flags that need to be set for jvp to work with attention layers.

    NOTE: This has been tested on torch version 2.1.1, hopefully,
    this issue will be resolved in a future version of torch
    as jvp mode reduces the speed of JVP computation.
    """
    if device.type == "cuda":
        torch.backends.cuda.enable_flash_sdp(not flag)
        torch.backends.cuda.enable_mem_efficient_sdp(not flag)
        # torch.backends.cuda.enable_math_sdp(flag)


@contextmanager
def _jvp_mode_enabled(device: torch.device):
    _jvp_mode(True, device)
    try:
        yield
    finally:
        _jvp_mode(False, device)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline

        >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        ```
"""

class LocalStableDiffusionPipeline(StableDiffusionPipeline):

    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
                         scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, requires_safety_checker=requires_safety_checker)

    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
        args = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: procecss multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            temp_inputs_ids = self.tokenizer(
                prompt,
                # padding="max_length", # This is only for print. It will not be used for generation (RJ)
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            # print(temp_inputs_ids['input_ids'].shape)
            if args is not None:
                args.prompt_length = temp_inputs_ids['input_ids'].shape[1]

            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None
                # attention_mask = text_inputs.attention_mask.to(device)

            # import pdb ; pdb.set_trace()

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: procecss multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
    
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 7.5,
        mode="x,c|x",
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        normalization: Optional[str] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        save_prefix="heatmap/test",
        args = None,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        # to deal with lora scaling and other possible forward hooks

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False
        self.mode=mode

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
            args=args,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image, device, batch_size * num_images_per_prompt
            )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        
        sd_ver, exp_type = args.sd_ver, args.exp_type
        
        if exp_type == 'det':
            hvp_sampling_num = args.hvp_sampling_num
            assert hvp_sampling_num <= num_inference_steps, 'hvp sampling should not exceed total inference steps'
            hvp_cont = torch.zeros(hvp_sampling_num)
            cosine_cont = torch.zeros(hvp_sampling_num)

        ipp = num_images_per_prompt

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        ######## MITIGATION CODE #################################################################################
        if exp_type == 'miti':                 
            thres, seed = args.miti_thres, args.gen_seed
            
            with torch.enable_grad():
                p_uncond, p_cond = prompt_embeds[0].unsqueeze(dim=0), prompt_embeds[ipp].unsqueeze(dim=0)    
                p_tot = torch.cat([p_uncond.detach()] * args.miti_budget + [p_cond.detach()] * args.miti_budget)
                
                t = timesteps[0]
                beta_prod = torch.sqrt(1 - self.scheduler.alphas_cumprod[t])
                beta = torch.sqrt(self.scheduler.alphas_cumprod[t])

                lat_lst = []
                counter = 0 #counter for updating threshold
                while True: 
                    torch.manual_seed(seed)
                    lat = torch.randn((args.miti_budget, *latents.shape[1:]), device=latents.device, requires_grad=True)
                    lat_out = self.scheduler.scale_model_input(lat, t)
                    lat_single = torch.cat([lat_out] * 2) 
                    optimizer = torch.optim.Adam([lat], lr=args.miti_lr)

                    step_cnt = 0
                    indice_record = torch.tensor([], device=lat.device, dtype=torch.long) 
                    while step_cnt < args.miti_max_steps:
                        
                        noise = self.unet(lat_single, t, encoder_hidden_states=p_tot)[0]     
                        uc_pred, c_pred = noise.chunk(2) if sd_ver == 1 else (-(beta * noise - lat_single) / beta_prod).chunk(2)
                        diff_pred = (c_pred - uc_pred)
                        diff_pred = diff_pred / diff_pred.view(args.miti_budget, -1).norm(dim=1)[ :, None, None, None]
                        
                        lat_modi = torch.cat([lat + diff_pred]*2)
                        noise_modi = self.unet(lat_modi, t, encoder_hidden_states=p_tot)
                        
                        #calculate proxy for SAIL loss (refer to Algorithm2 of https://arxiv.org/pdf/2412.04140)
                        if sd_ver == 1:
                            uc_modi, c_modi = noise_modi[0].chunk(2)
                            hvp_loss = (c_modi - uc_modi).view(args.miti_budget, -1).norm(dim=1) 
                            gaussianity = lat.view(args.miti_budget, -1).norm(dim=1)
                            loss = hvp_loss + 0.05 * gaussianity
                        else:
                            uc_modi, c_modi = (- (beta * noise_modi[0] - lat_modi) / beta_prod).chunk(2)
                            hvp_loss = (c_modi - uc_modi).view(args.miti_budget, -1).norm(dim=1)
                            gaussianity = lat.view(args.miti_budget, -1).norm(dim=1)
                            loss = hvp_loss + 0.01 * gaussianity
                        
                        indices = torch.where(loss <= thres)[0] 
                        updated_indices = torch.cat([indice_record, indices]).unique()
                        
                        if len(updated_indices) > len(indice_record):
                            new_indices = updated_indices[~torch.isin(updated_indices, indice_record)]
                            lat_lst.extend([lat[i].detach().unsqueeze(0) for i in new_indices])
                            indice_record = updated_indices
                            print(f"{len(new_indices)} Latents Added")
                            
                            if len(lat_lst) >= ipp:
                                break
                        
                        loss = loss.sum()
                        loss.backward()
                        optimizer.step()
                        optimizer.zero_grad()
                        step_cnt += 1
                        
                    if len(lat_lst) >= ipp:
                        break                                                
                    else:
                        counter += 1
                        
                    #This code is to increase l_thres & change 'seed' to prevent model
                    #when model is not able to find satisfiable latents
                    if counter//3 != (counter-1)//3:
                        thres += 0.1
                        seed = torch.randint(0, 50000, (1,)).item()
                        
                        print(f'Thres Updated: {thres:.2f}')
                        print(f'Seed Changed: {seed}')

                torch.cuda.empty_cache()
                latents = torch.cat(lat_lst)[:ipp]
        #################################################################################################################


        with self.progress_bar(total=num_inference_steps) as progress_bar:
            start_time=time.time()
            for i, t in enumerate(timesteps):
                t_cosine=timesteps[-3]
                print(t_cosine)
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input_cosine = self.scheduler.scale_model_input(latent_model_input, t_cosine)

                noise_pred_dict = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        return_dict=False,
                    )
                noise_pred_dict_cosine = self.unet(
                    latent_model_input_cosine,
                    t_cosine,
                    encoder_hidden_states=prompt_embeds,
                    return_dict=False,
                )

                ######## DETECTION CODE #################################################################################
                if args.exp_type == 'det':
                    if i >= hvp_sampling_num:
                        end_time=time.time()
                        print(f"Avg step time: {(end_time-start_time):.2f} seconds")
                        return hvp_cont , cosine_cont

                    beta_prod = torch.sqrt(1 - self.scheduler.alphas_cumprod[t])
                    beta = torch.sqrt(self.scheduler.alphas_cumprod[t])
                    uc_eps, c_eps = noise_pred_dict[0].chunk(2)
                    beta_prod_cosine = torch.sqrt(1 - self.scheduler.alphas_cumprod[t_cosine])
                    beta_cosine = torch.sqrt(self.scheduler.alphas_cumprod[t_cosine])
                    uc_eps_cosine, c_eps_cosine = noise_pred_dict_cosine[0].chunk(2) 
                    if normalization=="L2":
                        print("Using L2 layer normalized scores")
                        #uc_eps_cosine = uc_eps_cosine / uc_eps_cosine.view(ipp, -1).norm(dim=1)[:, None, None, None]
                        #c_eps_cosine = c_eps_cosine / c_eps_cosine.view(ipp, -1).norm(dim=1)[:, None, None, None]
                        #uc_eps=uc_eps / uc_eps.view(ipp, -1).norm(dim=1)[:, None, None, None]
                        #c_eps=c_eps / c_eps.view(ipp, -1).norm(dim=1)[:, None, None, None]
                        uc_eps_cosine = uc_eps_cosine / (uc_eps_cosine.norm(dim=(2,3), keepdim=True) + 1e-6)
                        c_eps_cosine = c_eps_cosine / (c_eps_cosine.norm(dim=(2,3), keepdim=True) + 1e-6)
                        uc_eps = uc_eps / (uc_eps.norm(dim=(2,3), keepdim=True) + 1e-6)
                        c_eps = c_eps / (c_eps.norm(dim=(2,3), keepdim=True) + 1e-6)
                    elif normalization=="L1":
                        print("Using L1 normalized scores")
                        uc_eps_cosine = uc_eps_cosine / uc_eps_cosine.view(ipp, -1).abs().sum(dim=1)[:, None, None, None]
                        c_eps_cosine = c_eps_cosine / c_eps_cosine.view(ipp, -1).abs().sum(dim=1)[:, None, None, None]
                        uc_eps=uc_eps / uc_eps.view(ipp, -1).abs().sum(dim=1)[:, None, None, None]
                        c_eps=c_eps / c_eps.view(ipp, -1).abs().sum(dim=1)[:, None, None, None]
                    else:
                        pass
                            
                    if sd_ver == 1:
                        eps_diff = (c_eps - uc_eps)
                        diff_score = - eps_diff / beta_prod
                        eps_diff_cosine = (c_eps_cosine - uc_eps_cosine)
                        diff_score_cosine = - eps_diff_cosine / beta_prod_cosine

                        def score_func (x):
                            eps_pred = self.unet(
                                torch.cat([x, x]),
                                t,
                                encoder_hidden_states=prompt_embeds, 
                                return_dict=False,
                            )[0]
                            return - eps_pred / beta_prod
                                    
                        def hvp_comb(vec):
                            vec = vec.view_as(latent_model_input[:ipp])
                            _, vjp_result = jvp(score_func, (latent_model_input[:ipp],), (vec,))
                            return vjp_result

                        with _jvp_mode_enabled(noise_pred_dict[0].device), torch.amp.autocast("cuda", dtype=torch.float32):
                            #eps_uc, eps_c = hvp_comb(diff_score).chunk(2)
                            #comb_hvp = eps_c - eps_uc
                            comb_hvp=eps_diff
                        

                        ######COSINE########
                        if args.mode=="x,c|x":
                            cosine_sim=torch.nn.functional.cosine_similarity(uc_eps_cosine, (c_eps_cosine - uc_eps_cosine), dim=1)
                        elif args.mode=="x,x|c":
                            cosine_sim=torch.nn.functional.cosine_similarity(c_eps_cosine, uc_eps_cosine, dim=1)
                        elif args.mode=="x|c,c|x":
                            cosine_sim=torch.nn.functional.cosine_similarity(c_eps_cosine, (c_eps_cosine - uc_eps_cosine), dim=1)
                        else:
                            raise NotImplementedError

                    else: 
                        score_uc, score_c = ((beta * noise_pred_dict[0] - latent_model_input)/(beta_prod**2)).chunk(2)
                        diff_score = (score_c - score_uc)
                        score_uc_cosine, score_c_cosine = ((beta_cosine * noise_pred_dict_cosine[0] - latent_model_input_cosine)/(beta_prod_cosine**2)).chunk(2)
                        
                        if normalization=="L2":
                            print("Using L2 layer normalized scores")
                            #score_uc_cosine = score_uc_cosine / uc_eps_cosine.view(ipp, -1).norm(dim=1)[:, None, None, None]
                            #score_c_cosine = score_c_cosine / c_eps_cosine.view(ipp, -1).norm(dim=1)[:, None, None, None]
                            score_uc_cosine = score_uc_cosine / (score_uc_cosine.norm(dim=(2,3), keepdim=True) + 1e-6)
                            score_c_cosine = score_c_cosine / (score_c_cosine.norm(dim=(2,3), keepdim=True) + 1e-6)

                        elif normalization=="L1":
                            print("Using L1 normalized scores")
                            score_uc_cosine = score_uc_cosine / uc_eps_cosine.view(ipp, -1).abs().sum(dim=1)[:, None, None, None]
                            score_c_cosine = score_c_cosine / c_eps_cosine.view(ipp, -1).abs().sum(dim=1)[:, None, None, None]
                        else:
                            pass
                        diff_score_cosine = (score_c_cosine - score_uc_cosine)

                        def score_func (x):
                            v_pred = self.unet(
                                torch.cat([x, x]), 
                                t,
                                encoder_hidden_states=prompt_embeds, 
                                return_dict=False,
                            )[0]
                            return (beta * v_pred - latent_model_input) / (beta_prod**2)
                        
                        def hvp_comb (vec):
                            vec = vec.view_as(latent_model_input[:ipp])
                            _, jvp_res = jvp(score_func, (latent_model_input[:ipp],), (vec,))
                            return jvp_res

                        with _jvp_mode_enabled(noise_pred_dict[0].device), torch.amp.autocast("cuda", dtype=torch.float32):
                            #h_uc, h_c = hvp_comb(diff_score).chunk(2)
                            #comb_hvp = h_c - h_uc
                            comb_hvp=diff_score

                        #####COSINE#####
                        uc_eps_cosine, c_eps_cosine = noise_pred_dict_cosine[0].chunk(2)
                        diff_score_cosine = (c_eps_cosine - uc_eps_cosine)

                        if args.mode=="x,c|x":
                            cosine_sim=torch.nn.functional.cosine_similarity(uc_eps_cosine, diff_score_cosine, dim=1)
                        elif args.mode=="x,x|c":
                            cosine_sim=torch.nn.functional.cosine_similarity(uc_eps_cosine, c_eps_cosine,  dim=1)
                        elif args.mode=="x|c,c|x":
                            cosine_sim=torch.nn.functional.cosine_similarity(c_eps_cosine, diff_score_cosine, dim=1)
                        else:
                            raise NotImplementedError

                    hvp_metric = torch.norm(comb_hvp.view(ipp, -1), dim=1)
                    cosine_sim_gens=[]
                    #print(cosine_sim.shape)
                    for j in cosine_sim:
                        cosine_sim_gens.append(j.mean())

                    hvp_cont[i] = hvp_metric.mean().item()
                    print("cosine",cosine_sim_gens)
                    cosine_cont[i]=torch.tensor(cosine_sim_gens).mean().item()
                #############################################################################################################
                noise_pred = noise_pred_dict[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if exp_type in ['orig', 'miti']:
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            if has_nsfw_concept is None:
                do_denormalize = [True] * image.shape[0]
            else:
                do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

            image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
            images = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
            return images

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)