import math
from os import pipe
from typing import Callable, Dict, List, Optional, Tuple,Union

import numpy as np
import PIL
import torch
import cv2
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel,LMSDiscreteScheduler
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
import torchvision.transforms as T
from ptp_utils import AttentionStore,aggregate_attention,view_images,text_under_image
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
import json
from gaussian_smoothing import GaussianSmoothing

class MLP(nn.Module):
    """Two-layer perceptron: ``Linear -> ReLU -> Linear``."""

    def __init__(self, input_size, hidden_size, output_size):
        """Create the two linear layers and the activation between them.

        Args:
            input_size: Number of input features of the first linear layer.
            hidden_size: Number of features between the two linear layers.
            output_size: Number of output features of the second linear layer.
        """
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # parameter initialisation and state-dict key order do not change.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply ``fc1``, the ReLU activation and ``fc2`` in sequence."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
def always_round(x):
    """Round ``x`` to the nearest integer, resolving ``.5`` ties upwards.

    Unlike the built-in ``round`` (which rounds ties to the nearest even
    integer), ``always_round(2.5)`` is ``3``.  The lower neighbour is taken
    with ``math.floor`` rather than ``int`` so that negative inputs are
    rounded to the nearest integer as well (``int`` truncates toward zero,
    which previously turned ``-2.7`` into ``-2``).  Results for ``x >= 0`` are
    the same as before.

    Args:
        x: A real number (``int`` or ``float``).

    Returns:
        The nearest ``int``; exact halves go to the larger neighbour.
    """
    lower = math.floor(x)
    if x < lower + 0.5:
        return lower
    return lower + 1

def preprocess(image):
    """Turn an image into a batched, channel-first tensor scaled to [-1, 1].

    The image is first resized (Lanczos resampling) so that both sides are
    integer multiples of 32, rounding each side down.

    Args:
        image: Object exposing ``size`` and ``resize`` like a PIL image.
            NOTE(review): the transpose below assumes the array form is
            ``(H, W, C)``, i.e. a multi-channel image - confirm for callers
            passing grayscale images.

    Returns:
        A float32 ``torch.Tensor`` with a leading batch axis of size 1.
    """
    width, height = image.size
    # Snap both sides down to an integer multiple of 32.
    width -= width % 32
    height -= height % 32
    resized = image.resize((width, height), resample=PIL.Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0
    # Add the batch axis, then move the last axis to position 1.
    batch = pixels[None].transpose(0, 3, 1, 2)
    # Map [0, 1] to [-1, 1].
    return 2.0 * torch.from_numpy(batch) - 1.0


def _img_importance_flatten(img: torch.tensor, w: int, h: int) -> torch.tensor:
    return F.interpolate(
        img.unsqueeze(0).unsqueeze(1),
        # scale_factor=1 / ratio,
        size=(w, h),
        mode="bilinear",
        align_corners=True,
    ).squeeze()


def _pil_from_latents(vae, latents):
    _latents = 1 / 0.18215 * latents.clone()
    image = vae.decode(_latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()

    return image


def _image_context_seperator(
    img: Image.Image, color_context: dict, _tokenizer,neg:float
) -> List[Tuple[List[int], torch.Tensor]]:

    ret_lists = []

    if img is not None:
        w, h = img.size
        matrix = np.zeros((768, 768))
        for color, v in color_context.items():
            color=tuple(color)
            if len(color)>3:
                color=color[:3]
            if isinstance(color, str):
                r, g, b = color[1:3], color[3:5], color[5:7]
                color = (int(r, 16), int(g, 16), int(b, 16))
            img_where_color = (np.array(img) == color).all(axis=-1)
            matrix[img_where_color]=1

        for color, v in color_context.items():
            if len(color)>3:
                color=color[:3]
            f = v.split(",")[-1]
            v = ",".join(v.split(",")[:-1])
            f = float(f)
            print(v)
            v_input = _tokenizer(
                v,
                max_length=_tokenizer.model_max_length,
                truncation=True,
            )
            v_as_tokens = v_input["input_ids"][1:-1]
            if isinstance(color, str):
                r, g, b = color[1:3], color[3:5], color[5:7]
                color = (int(r, 16), int(g, 16), int(b, 16))
            img_where_color = (np.array(img) == color).all(axis=-1)
            matrix[img_where_color]=1
            if not img_where_color.sum() > 0:
                print(f"Warning : not a single color {color} not found in image")
            img_where_color_init=torch.where(torch.tensor(img_where_color,dtype=torch.bool), f, neg)

            img_where_color = torch.where(torch.from_numpy(matrix == 1) & (img_where_color_init == 0.0), torch.tensor(neg), img_where_color_init)

            ret_lists.append((v_as_tokens, img_where_color))
    else:
        w, h = 512,512

    if len(ret_lists) == 0:
        ret_lists.append(([-1], torch.zeros((w, h), dtype=torch.float32)))
    return ret_lists, w, h


def _tokens_img_attention_weight(
    img_context_seperated, tokenized_texts, ratio: int = 8, original_shape=False
):
    """Build a per-position, per-token weight table from region masks.

    For each ``(token_ids, mask)`` pair, every place where ``token_ids``
    occurs as a contiguous run in the prompt's ids gets the mask (resampled by
    ``1 / ratio``) added to the matching token columns.

    Args:
        img_context_seperated: List of ``(token_ids, mask)`` pairs; the shape
            of the first mask defines the base resolution.
        tokenized_texts: Mapping whose ``"input_ids"[0]`` is the prompt's ids.
        ratio: Down-sampling factor applied to both mask sides.
        original_shape: When true, return ``(rows, cols, tokens)`` instead of
            ``(rows * cols, tokens)``.

    Returns:
        A float32 tensor of accumulated weights.
    """
    prompt_ids = tokenized_texts["input_ids"][0].tolist()
    n_tokens = len(prompt_ids)
    w, h = img_context_seperated[0][1].shape
    w_r = always_round(w / ratio)
    h_r = always_round(h / ratio)

    weights = torch.zeros((w_r * h_r, n_tokens), dtype=torch.float32)
    for phrase_ids, region_map in img_context_seperated:
        span = len(phrase_ids)
        # Every start offset at which the phrase occurs in the prompt.
        starts = [
            start
            for start in range(n_tokens)
            if prompt_ids[start : start + span] == phrase_ids
        ]
        if not starts:
            print(f"Warning ratio {ratio} : tokens {phrase_ids} not found in text")
            continue

        # The resampled mask does not depend on the match position.
        columns = (
            _img_importance_flatten(region_map, w_r, h_r)
            .reshape(-1, 1)
            .repeat(1, span)
        )
        for start in starts:
            weights[:, start : start + span] += columns

    if original_shape:
        weights = weights.reshape((w_r, h_r, n_tokens))
    return weights


def _extract_seed_and_sigma_from_context(color_context, ignore_seed = -1):
    # Split seed and sigma from color_context if provided
    extra_seeds = {}
    extra_sigmas = {}
    for i, (k, _context) in enumerate(color_context.items()):
        _context_split = _context.split(',')
        if len(_context_split) > 2:
            try:
                seed = int(_context_split[-2])
                sigma = float(_context_split[-1])
                _context_split = _context_split[:-2]
                extra_sigmas[i] = sigma
            except ValueError:
                seed = int(_context_split[-1])
                _context_split = _context_split[:-1]
            if seed != ignore_seed:
                extra_seeds[i] = seed
      
        color_context[k] = ','.join(_context_split)
    return color_context, extra_seeds, extra_sigmas


def _get_binary_mask(seperated_word_contexts, extra_seeds, dtype, size):
    img_where_color_mask = [(seperated_word_contexts[k][1] > 0).type(dtype) for k in extra_seeds.keys()]
    img_where_color_mask = [F.interpolate(mask.unsqueeze(0).unsqueeze(1), 
        size=size, mode='bilinear') for mask in img_where_color_mask]
    return img_where_color_mask


def _blur_image_mask(seperated_word_contexts, extra_sigmas):
    """Gaussian-blur the masks of the entries listed in ``extra_sigmas``.

    Args:
        seperated_word_contexts: Indexable, mutable collection of
            ``(token_ids, mask)``; blurred entries are replaced in place.
        extra_sigmas: Maps an entry index to the blur sigma for that entry.

    Returns:
        The same ``seperated_word_contexts`` object.
    """
    for idx, sigma in extra_sigmas.items():
        # Fixed 39x39 kernel; only sigma varies per entry.
        blur = T.GaussianBlur(kernel_size=(39, 39), sigma=(sigma, sigma))
        tokens, region = seperated_word_contexts[idx]
        # Add two leading axes for the blur, then drop them again.
        blurred = blur(region[None, None])[0, 0]
        seperated_word_contexts[idx] = (tokens, blurred)
    return seperated_word_contexts
def construct_direction(embs_source: torch.Tensor, embs_target: torch.Tensor):
    """Return the difference of the two embedding means as a 1-row batch.

    Both inputs are averaged over their first axis; the result is
    ``mean(target) - mean(source)`` with a new leading axis of size 1.
    """
    source_mean = embs_source.mean(0)
    target_mean = embs_target.mean(0)
    return (target_mean - source_mean).unsqueeze(0)

def _encode_text_color_inputs(
        text_encoder, tokenizer, device, 
        color_map_image, color_context, 
        input_prompt, unconditional_input_prompt,neg,edit_list):
    """Encode the prompts and build the region-weight tables for sampling.

    Steps:
      1. tokenize ``input_prompt``;
      2. strip seed/sigma suffixes from ``color_context`` (in place);
      3. turn the colour map into per-phrase masks, optionally blurred;
      4. build weight tables at ratios 1, 8, 16, 32 and 64;
      5. encode the prompt and add the requested edit directions;
      6. encode the unconditional prompt.

    Fixes over the previous version:
      * edit directions are looked up lazily, and only for the names in
        ``edit_list``.  Before, a dict literal referenced every
        ``edit_embed_*`` module global up front, so the function raised
        ``NameError`` whenever any of them was missing - even with an empty
        ``edit_list``;
      * the bare ``except:`` around the edit is narrowed to the shape-related
        errors the two-token fallback is meant for.

    Args:
        text_encoder: Callable returning a sequence whose first item is the
            embedding tensor.
        tokenizer: Tokenizer used for both prompts and the region phrases.
        device: Device the tensors are moved to.
        color_map_image: Colour-coded region map (or ``None``).
        color_context: Maps colours to ``"<prompt>,<strength>[,<seed>[,<sigma>]]"``.
        input_prompt: Conditional prompt.
        unconditional_input_prompt: Unconditional prompt.
        neg: Weight used outside a colour's region.
        edit_list: Iterable of ``(name, token_index)`` pairs.

    Returns:
        ``(extra_seeds, seperated_word_contexts, encoder_hidden_states,
        uncond_encoder_hidden_states)``.

    Raises:
        KeyError: ``edit_list`` names an unknown edit direction.
        NameError: The module global backing a requested edit is not defined.
    """
    # Process input prompt text
    text_input = tokenizer(
        [input_prompt],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    color_context, extra_seeds, extra_sigmas = _extract_seed_and_sigma_from_context(color_context)

    # Process color map image and context
    seperated_word_contexts, width, height = _image_context_seperator(
        color_map_image, color_context, tokenizer, neg
    )

    if extra_sigmas:
        print('Use extra sigma to smooth mask', extra_sigmas)
        seperated_word_contexts = _blur_image_mask(seperated_word_contexts, extra_sigmas)

    cross_attention_weight_orig = _tokens_img_attention_weight(
        seperated_word_contexts, text_input, ratio=1, original_shape=True
    ).to(device)
    # One flattened table per down-sampling ratio, in ascending order.
    weights_by_ratio = {
        ratio: _tokens_img_attention_weight(
            seperated_word_contexts, text_input, ratio=ratio
        ).to(device)
        for ratio in (8, 16, 32, 64)
    }

    cond_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Edit name -> name of the module-level tensor holding its direction.
    edit_direction_globals = {
        "butterfly": "edit_embed_butterfly",
        "candle": "edit_embed_candle",
        "cat": "edit_embed_cat2",
        "dog": "edit_embed_dog",
        "flower": "edit_embed_flower",
        "duck toy": "edit_embed_duck",
        "teapot": "edit_embed_teapot",
        "mug": "edit_embed_mug",
        "chair": "edit_embed_chair",
        "lake": "edit_embed_lake",
        "carbin": "edit_embed_dog7_2",
        "chow chow": "edit_embed_dog2_2",
        "sunglasses": "edit_embed_sunglasses2",
        "barn": "edit_embed_barn",
        "glasses": "edit_embed_glasses",
        "puppy": "edit_embed_puppy2",
        "hat": "edit_embed_hat",
        "dog3": "edit_embed_dog3",
        "monster toy": "edit_embed_monster_toy",
        "party hat": "edit_embed_party_hat",
        "puppy3": "edit_embed_puppy3",
    }

    for name, token in edit_list:
        global_name = edit_direction_globals[name]
        try:
            direction = globals()[global_name]
        except KeyError:
            raise NameError(f"name '{global_name}' is not defined") from None

        try:
            # Single-token edit: the direction flattens to 1024 values.
            cond_embeddings[0][token] += direction.reshape(1024)
        except (RuntimeError, ValueError):
            # Otherwise apply it to two consecutive token positions.
            cond_embeddings[0][token:token + 2] += direction

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [unconditional_input_prompt],
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    encoder_hidden_states = {
        "CONTEXT_TENSOR": cond_embeddings,
        "CROSS_ATTENTION_WEIGHT_ORIG": cross_attention_weight_orig,
    }
    uncond_encoder_hidden_states = {
        "CONTEXT_TENSOR": uncond_embeddings,
        "CROSS_ATTENTION_WEIGHT_ORIG": 0,
    }
    for ratio, weight in weights_by_ratio.items():
        # Tables are keyed by the number of spatial positions at that ratio.
        positions = always_round(height / ratio) * always_round(width / ratio)
        key = f"CROSS_ATTENTION_WEIGHT_{positions}"
        encoder_hidden_states[key] = weight
        uncond_encoder_hidden_states[key] = 0

    return extra_seeds, seperated_word_contexts, encoder_hidden_states, uncond_encoder_hidden_states
def numpy_to_pil(images):
    """Convert one image array, or a batch of them, into a list of PIL images.

    A 3-D input is treated as a single image and given a batch axis.  Values
    are multiplied by 255, rounded and cast to ``uint8``.  When the last axis
    has size 1 the images are created in mode ``"L"``.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")

    single_channel = images.shape[-1] == 1
    return [
        # Single-channel images drop the channel axis and use mode "L".
        Image.fromarray(frame.squeeze(), mode="L") if single_channel else Image.fromarray(frame)
        for frame in images
    ]
def run_safety_checker(image, device, dtype):
    """Placeholder safety check: pass the image through unchanged.

    ``device`` and ``dtype`` are accepted but not used.

    Returns:
        ``(image, None)`` - the second item stands for "no NSFW result".
    """
    return image, None
def compute_max_attention_per_index(     tokenizer,
                                        encoder_hidden_states,
                                         input_prompt,
                                         attention_maps: torch.Tensor,
                                         indices_to_alter: List[int],
                                         smooth_attentions: bool = False,
                                         sigma: float = 0.5,
                                         kernel_size: int = 3,
                                         normalize_eot: bool = False,
                                         ) -> List[torch.Tensor]:
    """Max attention of each selected token OUTSIDE its painted region.

    The attention over the text tokens (first token dropped, and everything
    from ``last_idx`` on) is scaled by 100 and soft-maxed over the token axis.
    For every index in ``indices_to_alter`` the token's map is optionally
    smoothed, and the maximum over the positions whose region weight is
    negative is collected.

    Changes over the previous version:
      * the scaling no longer happens in place, so ``attention_maps`` is left
        untouched for the caller;
      * the smoothing module is built once and placed on the attention map's
        device instead of being rebuilt per token with a hard-coded
        ``.cuda()``;
      * the weight table key and the mask shape are derived from the
        attention map's resolution instead of the hard-coded ``576`` / 24x24
        (identical for 24x24 maps).

    Args:
        tokenizer: Only called when ``normalize_eot`` is true.
        encoder_hidden_states: Mapping holding
            ``"CROSS_ATTENTION_WEIGHT_<rows*cols>"`` with one column per token.
        input_prompt: Prompt used to locate the end-of-text position.
        attention_maps: Tensor indexed as ``[row, col, token]``.
        indices_to_alter: Token indices (in prompt numbering) to evaluate.
        smooth_attentions: Apply Gaussian smoothing to each token map.
        sigma: Sigma of the smoothing kernel.
        kernel_size: Size of the smoothing kernel.
        normalize_eot: Cut the token axis at the prompt's last token.

    Returns:
        One scalar tensor per entry of ``indices_to_alter``.
        NOTE(review): ``torch.max`` is expected to fail if a token's region
        table has no negative entry - confirm callers always pass ``neg < 0``.
    """
    last_idx = -1
    if normalize_eot:
        last_idx = len(tokenizer(input_prompt)['input_ids']) - 1
        print(last_idx)

    # Out-of-place scaling: the slice is a view of the caller's tensor.
    attention_for_text = attention_maps[:, :, 1:last_idx] * 100
    attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)

    rows, cols = attention_for_text.shape[:2]
    region_weights = encoder_hidden_states[f"CROSS_ATTENTION_WEIGHT_{rows * cols}"]

    smoothing = None
    if smooth_attentions:
        smoothing = GaussianSmoothing(
            channels=1, kernel_size=kernel_size, sigma=sigma, dim=2
        ).to(attention_for_text.device)

    max_indices_list = []
    for token_index in indices_to_alter:
        # The first token was sliced off above, hence the shift by one.
        image = attention_for_text[:, :, token_index - 1]
        if smoothing is not None:
            padded = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
            image = smoothing(padded).squeeze(0).squeeze(0)
        # Region tables keep the prompt's own token numbering.
        outside_region = region_weights[:, token_index].reshape(rows, cols) < 0
        max_indices_list.append(torch.max(image[outside_region]))
    return max_indices_list

def aggregate_and_get_max_attention_per_token(tokenizer,
                                                   encoder_hidden_states,
                                                   input_prompt,
                                                   attention_store: AttentionStore,
                                                   indices_to_alter: List[int],
                                                   attention_res: int = 16,
                                                   smooth_attentions: bool = False,
                                                   sigma: float = 0.5,
                                                   kernel_size: int = 3,
                                                   normalize_eot: bool = False,
                                                   ):
    """Aggregate stored cross-attention and score the selected tokens.

    The cross-attention collected in ``attention_store`` is aggregated at
    resolution ``attention_res`` over the "up", "down" and "mid" locations
    (selection 0); the result and all remaining arguments are forwarded to
    ``compute_max_attention_per_index``, whose return value is returned.
    """
    aggregated = aggregate_attention(
        attention_store=attention_store,
        res=attention_res,
        from_where=("up", "down", "mid"),
        is_cross=True,
        select=0,
    )
    return compute_max_attention_per_index(
        tokenizer=tokenizer,
        encoder_hidden_states=encoder_hidden_states,
        input_prompt=input_prompt,
        attention_maps=aggregated,
        indices_to_alter=indices_to_alter,
        smooth_attentions=smooth_attentions,
        sigma=sigma,
        kernel_size=kernel_size,
        normalize_eot=normalize_eot,
    )

def compute_loss(max_attention_per_index: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor:
    """Return the largest per-token value, floored at zero.

    Each entry is clamped from below with ``max(0, value)``; the loss is the
    maximum of the clamped entries.  An empty input now yields a loss of ``0``
    instead of raising ``ValueError`` from ``max()`` on an empty list.

    Args:
        max_attention_per_index: Per-token scores (scalar tensors or numbers).
        return_losses: Also return the list of clamped per-token values.

    Returns:
        ``loss``, or ``(loss, losses)`` when ``return_losses`` is true.
    """
    losses = [max(0, curr_max) for curr_max in max_attention_per_index]
    loss = max(losses, default=0)
    if return_losses:
        return loss, losses
    return loss

# @staticmethod
def update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor:
    """Take one gradient-descent step on ``latents`` with respect to ``loss``.

    The graph is retained, so ``loss`` can be differentiated again afterwards.

    Returns:
        ``latents - step_size * d(loss)/d(latents)`` as a new tensor.
    """
    tracked_loss = loss.requires_grad_(True)
    (grad_cond,) = torch.autograd.grad(tracked_loss, [latents], retain_graph=True)
    return latents - step_size * grad_cond
def show_cross_attention(prompt: str,
                         attention_store: AttentionStore,
                         tokenizer,
                         indices_to_alter: List[int],
                         res: int,
                         from_where: List[str],
                         select: int = 0,
                         orig_image=None):
    """Display the cross-attention map of each selected prompt token.

    For every token position listed in ``indices_to_alter`` the aggregated
    attention map is overlaid on ``orig_image``, resized to
    ``res ** 2`` x ``res ** 2``, captioned with the decoded token and finally
    shown together with the others via ``view_images``.
    """
    token_ids = tokenizer.encode(prompt)
    decode = tokenizer.decode
    attention_maps = aggregate_attention(attention_store, res, from_where, True, select).detach().cpu()
    side = res ** 2

    panels = []
    for position, token_id in enumerate(token_ids):
        token_map = attention_maps[:, :, position]
        # Only the tokens selected for strengthening are visualised.
        if position not in indices_to_alter:
            continue
        overlay = show_image_relevance(token_map, orig_image).astype(np.uint8)
        overlay = np.array(Image.fromarray(overlay).resize((side, side)))
        panels.append(text_under_image(overlay, decode(int(token_id))))

    view_images(np.stack(panels, axis=0))


def show_image_relevance(image_relevance, image: Image.Image, relevnace_res=16):
    """Overlay a relevance map on an image as a JET heat map.

    Args:
        image_relevance: Square relevance tensor; it is moved to CUDA for the
            bilinear up-sampling and back to the CPU afterwards.
        image: Indexed with ``[0]`` before use, so a sequence of images is
            expected despite the annotation.
        relevnace_res: The output side length is ``relevnace_res ** 2``.

    Returns:
        ``uint8`` array of the blended picture after an RGB -> BGR conversion.
    """
    side = relevnace_res ** 2

    def overlay_heatmap(rgb, weights):
        # Colour the weights with the JET map and add them onto the picture.
        colored = cv2.applyColorMap(np.uint8(255 * weights), cv2.COLORMAP_JET)
        blended = np.float32(colored) / 255 + np.float32(rgb)
        return blended / np.max(blended)

    background = np.array(image[0].resize((side, side)))

    map_side = image_relevance.shape[-1]
    relevance = image_relevance.reshape(1, 1, map_side, map_side)
    relevance = relevance.cuda() # because float16 precision interpolation is not supported on cpu
    relevance = torch.nn.functional.interpolate(relevance, size=side, mode='bilinear')
    relevance = relevance.cpu() # send it back to cpu

    # Min-max normalise both the relevance map and the picture to [0, 1].
    relevance = (relevance - relevance.min()) / (relevance.max() - relevance.min())
    relevance = relevance.reshape(side, side)
    background = (background - background.min()) / (background.max() - background.min())

    vis = np.uint8(255 * overlay_heatmap(background, relevance))
    return cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)



@torch.no_grad()
def atten(
    color_context: Dict[Tuple[int, int, int], str] = {},
    color_map_image: Optional[Image.Image] = None,
    input_prompt: str = "",
    num_inference_steps: int = 30,
    guidance_scale: float = 7.5,
    seed: int = 0,
    scheduler_type = LMSDiscreteScheduler,
    device: str = "cuda:0",
    weight_function: Callable = lambda w, sigma, qk: 0.1
    * w
    * math.log(sigma + 1)
    * qk.max(),
    local_model_path: Optional[str] = None,
    hf_model_path: Optional[str] = "CompVis/stable-diffusion-v1-4",
    preloaded_utils: Optional[Tuple] = None,
    unconditional_input_prompt: str = "",
    model_token: Optional[str] = None,
    init_image: Optional[Image.Image] = None,
    strength: float = 0.5,
    pipe=pipe,
    aff: int =10,
    neg: float=-1.0,
    token: List=[],
    controller=AttentionStore,
    attend:int=25,
    fix_lr=20.0,
    edit_list: List=[]


):
    """Sample an image with region-weighted attention and an attention loss.

    The denoising loop has three phases, selected by the loop position ``i``:

      * ``i < aff``: the conditional pass receives the full dict of region
        weight tables, and every step is executed twice - after the first
        pass noise is added back before the step is repeated;
      * ``aff <= i < attend``: plain text conditioning, but the latent input
        is first moved down the gradient of the attention loss;
      * otherwise: plain classifier-free guidance.

    Fixes over the previous version:
      * in the second phase the loss helper was handed the bare context
        tensor, although it looks up a ``CROSS_ATTENTION_WEIGHT_*`` table by
        string key; it now receives the dict that holds those tables;
      * the re-noising amount indexes ``scheduler.sigmas`` by the timestep's
        position in the scheduler (``step_index``) rather than by the loop
        position.  Both agree unless ``init_image`` slices the timesteps.

    Args:
        color_context: Maps a region colour to its context string.
        color_map_image: Colour-coded region map; its ``size`` defines the
            output size, so it must not be ``None``.
        input_prompt: Conditional prompt.
        num_inference_steps: Only used to pick the start step when
            ``init_image`` is given; the scheduler itself always runs 50 steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Seed of the initial latent noise.
        scheduler_type: Unused.
        device: Device used for sampling.
        weight_function: Stored under ``"WEIGHT_FUNCTION"`` for the
            conditional pass of the first phase.
        local_model_path: Unused.
        hf_model_path: Unused.
        preloaded_utils: Unused.
        unconditional_input_prompt: Unconditional prompt.
        model_token: Unused.
        init_image: Optional start image (image-to-image).
        strength: Image-to-image strength.
        pipe: Pipeline providing ``vae``, ``unet``, ``text_encoder``,
            ``tokenizer`` and ``scheduler``.  NOTE(review): the default is
            the module-level name ``pipe``, which this file imports from
            ``os`` - callers have to pass a real pipeline.
        aff: Number of leading steps that use the region weight tables.
        neg: Weight used outside a region.
        token: Token indices scored by the attention loss.
        controller: Attention store read by the attention loss.
        attend: Loop position at which the attention loss stops being applied.
        fix_lr: Base step size of the latent update.
        edit_list: ``(name, token_index)`` embedding edits.

    Returns:
        List of PIL images.
    """
    width, height = color_map_image.size
    vae = pipe.vae
    unet = pipe.unet
    text_encoder = pipe.text_encoder
    tokenizer = pipe.tokenizer
    scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
    # Second scheduler: only its sigmas are read, as the "SIGMA" value handed
    # to the weight function during the first phase.
    scheduler2 = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )
    extra_seeds, seperated_word_contexts, encoder_hidden_states, uncond_encoder_hidden_states = \
        _encode_text_color_inputs(text_encoder, tokenizer, device, color_map_image, color_context, input_prompt, unconditional_input_prompt,neg,edit_list)
    is_extra_seed = len(extra_seeds) > 0

    scheduler.set_timesteps(50, device=device)

    scheduler2.set_timesteps(aff, device=device)
    if init_image is None:
        timesteps = scheduler.timesteps
    else:
        offset = scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        t_start = max(num_inference_steps - init_timestep + offset, 0)
        timesteps = scheduler.timesteps[t_start:]
        num_inference_steps = num_inference_steps - t_start
        latent_timestep = timesteps[:1]
    if init_image is None:
        latent_size = (1, unet.in_channels, height // 8, width // 8)
        latents = torch.randn(latent_size, generator=torch.manual_seed(seed))
        if is_extra_seed:
            # Region-based seeding: inside every seeded region the noise is
            # replaced by noise drawn from that region's own seed.
            print('Use region based seeding: ', extra_seeds)
            multi_latents = [torch.randn(latent_size,
                generator=torch.manual_seed(_seed)) for _seed in extra_seeds.values()]
            img_where_color_mask = _get_binary_mask(seperated_word_contexts, extra_seeds, dtype=latents[0].dtype, size=latent_size[-2:])
            foreground = (sum(img_where_color_mask) > 0).squeeze()
            summed_multi_latents = sum(_latents * _mask for _latents, _mask in zip(multi_latents, img_where_color_mask))
            latents[:,:,foreground] = summed_multi_latents[:,:,foreground]
        latents = latents.to(device)
        print(device)
        print(scheduler.init_noise_sigma)
        latents = latents * scheduler.init_noise_sigma
    else:
        init_image = preprocess(init_image)
        image = init_image.to(device=device)
        init_latent_dist = vae.encode(image).latent_dist
        init_latents = init_latent_dist.sample()
        init_latents = 0.18215 * init_latents
        noise = torch.randn(init_latents.shape).to(device)

        # get latents
        init_latents = scheduler.add_noise(init_latents, noise, latent_timestep)
        latents = init_latents

    # Plain text embeddings, used once the region weight tables are dropped.
    encoder_hidden_states_25 = encoder_hidden_states["CONTEXT_TENSOR"]
    uncond_encoder_hidden_states_25 = uncond_encoder_hidden_states["CONTEXT_TENSOR"]

    # Step-size decay for the latent update, indexed by loop position.
    scale_range = np.linspace(1.0,0.5, 50)
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
        # During the first phase every step is executed twice.
        if i<aff:
            loop=2
        else:
            loop=1
        for k in range(loop):
            if i<aff:
                # Position of ``t`` in the scheduler's own timestep list.
                step_index = (scheduler.timesteps == t).nonzero().item()
                # sigma for pww
                sigma = scheduler2.sigmas[i]
                latent_model_input = scheduler.scale_model_input(latents, t)
                _t = t
                encoder_hidden_states.update({
                        "SIGMA": sigma,
                        "WEIGHT_FUNCTION": weight_function,
                    })
                with torch.enable_grad():
                    latent_model_input = latent_model_input.clone().detach().requires_grad_(True)
                    _t = t
                    noise_pred_text = unet(
                        latent_model_input,
                        _t,
                        encoder_hidden_states=encoder_hidden_states,
                    ).sample
                    unet.zero_grad()
                    max_attention_per_index=aggregate_and_get_max_attention_per_token(
                            tokenizer=tokenizer,
                            encoder_hidden_states=encoder_hidden_states,
                            input_prompt=input_prompt,
                            attention_store=controller,
                            indices_to_alter=token,
                            attention_res=24,
                            smooth_attentions=True,
                            sigma=0.5,
                            kernel_size=3,
                            normalize_eot=True
                            )
                    # The loss is computed but not applied in this phase.
                    loss = compute_loss(max_attention_per_index=max_attention_per_index)
                noise_pred_text = unet(
                    latent_model_input,
                    _t,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample

                uncond_encoder_hidden_states.update({
                        "SIGMA": sigma,
                        "WEIGHT_FUNCTION": lambda w, sigma, qk: 0.0,
                    })

                noise_pred_uncond = unet(
                    latent_model_input,
                    _t,
                    encoder_hidden_states=uncond_encoder_hidden_states,
                ).sample

                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
                latents = scheduler.step(noise_pred, t, latents,1).prev_sample
                if k<1 and loop>1:
                    # After the first pass, add back the noise removed by the
                    # step so that the same step can be repeated.
                    noise_recurent = torch.randn(latents.shape).to(device)
                    latents =  latents+noise_recurent*((scheduler.sigmas[step_index]**2-scheduler.sigmas[step_index+1]**2)**0.5)

            elif i >=aff and i <attend:
                with torch.enable_grad():
                    latent_model_input = scheduler.scale_model_input(latents, t)
                    latent_model_input = latent_model_input.clone().detach().requires_grad_(True)
                    _t = t
                    noise_pred_text = unet(
                        latent_model_input,
                        _t,
                        encoder_hidden_states=encoder_hidden_states_25,
                    ).sample
                    unet.zero_grad()
                    # The loss helper reads a region weight table by key, so
                    # it needs the dict and not just the context tensor.
                    max_attention_per_index=aggregate_and_get_max_attention_per_token(
                            tokenizer=tokenizer,
                            encoder_hidden_states=encoder_hidden_states,
                            input_prompt=input_prompt,
                            attention_store=controller,
                            indices_to_alter=token,
                            attention_res=24,
                            smooth_attentions=True,
                            sigma=0.5,
                            kernel_size=3,
                            normalize_eot=True
                            )
                    loss = compute_loss(max_attention_per_index=max_attention_per_index)
                    if loss != 0:

                        latent_model_input = update_latent(latents=latent_model_input, loss=loss,
                                                                step_size=fix_lr * np.sqrt(scale_range[i]))
                        print(f'Iteration {i} | Loss: {loss:0.4f}')

                noise_pred_text = unet(
                    latent_model_input,
                    _t,
                    encoder_hidden_states=encoder_hidden_states_25,
                ).sample

                noise_pred_uncond = unet(
                    latent_model_input,
                    _t,
                    encoder_hidden_states=uncond_encoder_hidden_states_25,
                ).sample

                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

                # compute the previous noisy sample x_t -> x_t-1
                latents = scheduler.step(noise_pred, t, latents,1).prev_sample
            else:
                latent_model_input = scheduler.scale_model_input(latents, t)

                _t = t
                print(_t)
                noise_pred_text = unet(
                    latent_model_input,
                    _t,
                    encoder_hidden_states=encoder_hidden_states_25,
                ).sample

                latent_model_input = scheduler.scale_model_input(latents, t)

                noise_pred_uncond = unet(
                    latent_model_input,
                    _t,
                    encoder_hidden_states=uncond_encoder_hidden_states_25,
                ).sample

                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

                latents = scheduler.step(noise_pred, t, latents,1).prev_sample

    edited_image = _pil_from_latents(vae, latents)
    edited_image, has_nsfw_concept = run_safety_checker(edited_image, device, encoder_hidden_states_25.dtype)
    edited_image = numpy_to_pil(edited_image)

    return StableDiffusionPipelineOutput(images=edited_image, nsfw_content_detected=has_nsfw_concept).images















