# This code is taken from diffusers. All our modifications are enclosed between ###### marker lines.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
import copy

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import QwenImageLoraLoaderMixin
from diffusers.models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
from diffusers.pipelines.qwenimage.pipeline_qwenimage import calculate_shift, retrieve_timesteps

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Injected into `modified_qwen`'s docstring (replacing its "Examples:" line) by
# `@replace_example_docstring(EXAMPLE_DOC_STRING)`.
# NOTE(review): this example demonstrates the stock `QwenImagePipeline.__call__`, not
# `modified_qwen` with bounding boxes — consider adding a `modified_qwen` example.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import QwenImagePipeline

        >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")
        >>> prompt = "A cat holding a sign that says hello world"
        >>> # Depending on the variant being used, the pipeline call will slightly vary.
        >>> # Refer to the pipeline documentation for more details.
        >>> image = pipe(prompt, num_inference_steps=50).images[0]
        >>> image.save("qwenimage.png")
        ```
"""
######
def combine_with_segmentation_maps(bounding_boxes, latents_reshaped, batch_size):
    """Paste each object's latents into the latent grid wherever its mask is set.

    Args:
        bounding_boxes: iterable of object dicts. Each needs ``'latents'`` (packed, ``[batch, H*W, C]``)
            and ``'mask'`` (``H*W`` elements). ``torch.where`` requires the mask to be a bool tensor.
        latents_reshaped: latent grid of shape ``[batch, C, H, W]``. Its shape determines how each
            object's packed latents are unpacked.
        batch_size: batch size used when unpacking each object's latents.

    Returns:
        A new ``[batch, C, H, W]`` tensor. Later objects win where masks overlap.
    """
    _, channels, grid_h, grid_w = latents_reshaped.shape
    for bb in bounding_boxes:
        # packed [batch, H*W, C] -> grid [batch, C, H, W]
        obj_grid = bb['latents'].permute(0, 2, 1).reshape(batch_size, channels, grid_h, grid_w)
        # [1, 1, H, W] broadcasts the mask over batch and channels
        mask = bb['mask'].reshape(1, 1, grid_h, grid_w)
        latents_reshaped = torch.where(mask, obj_grid, latents_reshaped)
    return latents_reshaped


def combine_with_bounding_boxes(bounding_boxes, latents_reshaped, batch_size):
    """Copy each object's latents into the latent grid inside its bounding box.

    Args:
        bounding_boxes: iterable of object dicts. Each needs ``'latents'`` (packed, ``[batch, H*W, C]``)
            and integer ``'x_min'``, ``'x_max'``, ``'y_min'``, ``'y_max'`` in latent-grid coordinates.
            Bounds are inclusive on both ends.
        latents_reshaped: latent grid of shape ``[batch, C, H, W]``. It is modified in place.
        batch_size: batch size used when unpacking each object's latents.

    Returns:
        ``latents_reshaped`` (the same tensor object) after the boxes have been written.
        Later objects overwrite earlier ones where boxes overlap.
    """
    _, channels, grid_h, grid_w = latents_reshaped.shape
    for bb in bounding_boxes:
        # packed [batch, H*W, C] -> grid [batch, C, H, W]
        obj_grid = bb['latents'].permute(0, 2, 1).reshape(batch_size, channels, grid_h, grid_w)
        # + 1 because the stored max coordinates are inclusive
        rows = slice(bb['y_min'], bb['y_max'] + 1)
        cols = slice(bb['x_min'], bb['x_max'] + 1)
        latents_reshaped[:, :, rows, cols] = obj_grid[:, :, rows, cols]
    return latents_reshaped


def combine_latents(latents, bounding_boxes, use_segmentation_dict):
    """Fuse the per-object latents into a single packed latent tensor.

    The first entry of ``bounding_boxes`` is treated as the background and is always pasted
    using its bounding box. The remaining entries are pasted with their segmentation masks when
    ``use_segmentation_dict`` is true, otherwise with their bounding boxes.

    Args:
        latents: packed latents of shape ``[batch, seq_len, channels]``. The token sequence is
            assumed to form a square ``sqrt(seq_len) x sqrt(seq_len)`` grid.
        bounding_boxes: list of object dicts, background first. Each holds ``'latents'`` plus
            box coordinates (or a ``'mask'`` for foreground objects in segmentation mode).
        use_segmentation_dict: whether foreground objects are merged via masks instead of boxes.

    Returns:
        A packed tensor of shape ``[batch, seq_len, channels]``. ``latents`` itself is not modified.

    Raises:
        ValueError: if ``seq_len`` is not a perfect square.
        IndexError: if ``bounding_boxes`` is empty.
    """
    batch_size, seq_len, channels = latents.shape
    side = round(seq_len ** 0.5)
    if side * side != seq_len:
        raise ValueError(f"combine_latents expects a square latent grid, got {seq_len} tokens.")

    # permute + reshape here is a *view* of `latents`. Clone it so the in-place box writes
    # below do not silently modify the caller's tensor.
    grid = latents.permute(0, 2, 1).reshape(batch_size, channels, side, side).clone()

    # assume background first in bounding_boxes
    grid = combine_with_bounding_boxes([bounding_boxes[0]], grid, batch_size)

    if use_segmentation_dict:
        grid = combine_with_segmentation_maps(bounding_boxes[1:], grid, batch_size)
    else:
        grid = combine_with_bounding_boxes(bounding_boxes[1:], grid, batch_size)

    # grid [batch, C, H, W] -> packed [batch, H*W, C]
    return grid.reshape(batch_size, channels, seq_len).permute(0, 2, 1)

# This function is adapted from the denoising-loop body of the original diffusers QwenImage pipeline.
def get_noise_pred(
    model, latents, timestep, guidance, prompt_embeds_mask, prompt_embeds, img_shapes, txt_seq_lens,
    negative_prompt_embeds_mask, negative_prompt_embeds, negative_txt_seq_lens, true_cfg_scale, t,
    callback_on_step_end, callback_on_step_end_tensor_inputs, do_true_cfg, scheduler,
    step_index=None,
):
    """Run one transformer prediction (with optional true CFG) and one scheduler step.

    Args:
        model: the QwenImage pipeline. Its ``transformer`` and ``attention_kwargs`` are used, and it
            is passed as the pipeline argument to ``callback_on_step_end``.
        latents: packed latents to denoise.
        timestep: timestep broadcast to the batch; divided by 1000 before the transformer call.
        guidance: guidance-embedding input, or ``None`` for non-distilled models.
        prompt_embeds_mask, prompt_embeds, txt_seq_lens: conditional text inputs.
        img_shapes: image-shape list forwarded to the transformer.
        negative_prompt_embeds_mask, negative_prompt_embeds, negative_txt_seq_lens: unconditional
            text inputs, used only when ``do_true_cfg`` is true.
        true_cfg_scale: classifier-free guidance scale.
        t: raw scheduler timestep passed to ``scheduler.step`` and to the callback.
        callback_on_step_end: optional ``callback(pipeline, step, t, callback_kwargs)``. It may return
            replacement ``"latents"`` / ``"prompt_embeds"`` entries.
        callback_on_step_end_tensor_inputs: names of local variables placed in ``callback_kwargs``.
        do_true_cfg: whether to run the unconditional pass and apply CFG.
        scheduler: scheduler whose ``step`` advances ``latents``.
        step_index: denoising step index reported to the callback. If omitted, it is derived from
            ``scheduler.step_index`` when the scheduler exposes one (FlowMatchEulerDiscreteScheduler
            increments it inside ``step()``, so the step just taken is ``step_index - 1``).

    Returns:
        Tuple ``(latents, prompt_embeds)`` after the step and any callback overrides.
    """
    with model.transformer.cache_context("cond"):
        noise_pred = model.transformer(
            hidden_states=latents,
            timestep=timestep / 1000,
            guidance=guidance,
            encoder_hidden_states_mask=prompt_embeds_mask,
            encoder_hidden_states=prompt_embeds,
            img_shapes=img_shapes,
            txt_seq_lens=txt_seq_lens,
            attention_kwargs=model.attention_kwargs,
            return_dict=False,
        )[0]

    if do_true_cfg:
        with model.transformer.cache_context("uncond"):
            neg_noise_pred = model.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                guidance=guidance,
                encoder_hidden_states_mask=negative_prompt_embeds_mask,
                encoder_hidden_states=negative_prompt_embeds,
                img_shapes=img_shapes,
                txt_seq_lens=negative_txt_seq_lens,
                attention_kwargs=model.attention_kwargs,
                return_dict=False,
            )[0]
        comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

        # rescale the guided prediction to the per-token norm of the conditional prediction
        cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
        noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
        noise_pred = comb_pred * (cond_norm / noise_norm)

    # compute the previous noisy sample x_t -> x_t-1
    latents_dtype = latents.dtype
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    if latents.dtype != latents_dtype:
        if torch.backends.mps.is_available():
            # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
            latents = latents.to(latents_dtype)

    if callback_on_step_end is not None:
        if step_index is None:
            current = getattr(scheduler, "step_index", None)
            step_index = current - 1 if current is not None else None
        step_locals = locals()
        callback_kwargs = {k: step_locals[k] for k in callback_on_step_end_tensor_inputs}
        # `model` plays the role of `self` (the pipeline) from the original diffusers loop
        callback_outputs = callback_on_step_end(model, step_index, t, callback_kwargs)

        latents = callback_outputs.pop("latents", latents)
        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

    return latents, prompt_embeds
######

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def modified_qwen(
    model,
    combine_timestep,
    bounding_boxes,
    use_segmentation_dict,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    true_cfg_scale: float = 4.0,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    sigmas: Optional[List[float]] = None,
    guidance_scale: Optional[float] = None,
    num_images_per_prompt: int = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
):
    r"""
    Generate images with a per-object denoising phase followed by a joint denoising phase.

    This is a modified copy of the diffusers QwenImage pipeline call. For every denoising step
    `i < combine_timestep`, each entry of `bounding_boxes` is denoised independently with its own prompt,
    latents and scheduler copy. At step `i == combine_timestep` the per-object latents are merged with
    `combine_latents`, and the remaining steps denoise the merged latents using the global `prompt`.

    Args:
        model:
            A QwenImage pipeline instance, used where the original `__call__` uses `self`. It must provide
            `check_inputs`, `encode_prompt`, `prepare_latents`, `transformer`, `scheduler`, `vae`,
            `image_processor`, `progress_bar`, `_unpack_latents`, `maybe_free_model_hooks`, etc.
        combine_timestep (`int`):
            Index of the denoising step at which the per-object latents are merged. Earlier steps run
            per object; this step and later ones run on the merged latents.
        bounding_boxes (`List[dict]`):
            Object descriptions, background first. Each dict needs a `'prompt'` and, at least for the
            background, inclusive latent-grid box coordinates `'x_min'`, `'x_max'`, `'y_min'`, `'y_max'`.
            When `use_segmentation_dict` is true, foreground entries must carry a `'mask'` by the time
            the merge happens. The dicts are mutated in place: `'prompt_embeds'`, `'prompt_embeds_mask'`,
            `'latents'`, `'txt_seq_lens'` and `'scheduler'` are added.
        use_segmentation_dict (`bool`):
            If true, foreground objects are merged through their masks instead of their boxes, and the
            attention processors get `save_maps = True` on the step before `combine_timestep`.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the joint (post-merge) generation. If not defined, one has to
            pass `prompt_embeds` instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
            not greater than `1`). The same negative prompt is used for every object.
        true_cfg_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
            of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
            setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to
            generate images that are closely linked to the text `prompt`, usually at the expense of lower image
            quality.
        height (`int`, *optional*, defaults to `model.default_sample_size * model.vae_scale_factor`):
            The height in pixels of the generated image.
            NOTE(review): the merge helpers in this file reshape latents to a fixed/square token grid;
            confirm `height`/`width` produce a compatible grid.
        width (`int`, *optional*, defaults to `model.default_sample_size * model.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to None):
            A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
            where the guidance scale is applied during inference through noise prediction rescaling, guidance
            distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
            scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
            that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
            parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
            ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
            please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
            enable classifier-free guidance computations).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`. Every object starts from
            a clone of these latents.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask for `prompt_embeds`.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask for `negative_prompt_embeds`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. `"latent"` returns
            the packed latents without decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `model.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, *optional*):
            A function called at the end of each denoising step. It is forwarded to `get_noise_pred`; during the
            per-object phase it is therefore invoked once per object per step. `callback_kwargs` will include the
            tensors listed in `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
        [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is a list with the generated images.
    """

    height = height or model.default_sample_size * model.vae_scale_factor
    width = width or model.default_sample_size * model.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    model.check_inputs(
        prompt,
        height,
        width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_embeds_mask=prompt_embeds_mask,
        negative_prompt_embeds_mask=negative_prompt_embeds_mask,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    model._guidance_scale = guidance_scale
    model._attention_kwargs = attention_kwargs
    model._current_timestep = None
    model._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = model._execution_device

    has_neg_prompt = negative_prompt is not None or (
        negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
    )

    if true_cfg_scale > 1 and not has_neg_prompt:
        logger.warning(
            f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
        )
    elif true_cfg_scale <= 1 and has_neg_prompt:
        logger.warning(
            " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
        )

    do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    prompt_embeds, prompt_embeds_mask = model.encode_prompt(
        prompt=prompt,
        prompt_embeds=prompt_embeds,
        prompt_embeds_mask=prompt_embeds_mask,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
    )
    ######
    # get a prompt embedding for every object
    # (stored on the caller's dicts: bounding_boxes is mutated in place)
    for bb in bounding_boxes:
        (
            bb_prompt_embeds,
            bb_prompt_embeds_mask,
        ) = model.encode_prompt(
            prompt=bb['prompt'],
            prompt_embeds=None,
            prompt_embeds_mask=None,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
        )
        bb['prompt_embeds'] = bb_prompt_embeds
        bb['prompt_embeds_mask'] = bb_prompt_embeds_mask
    ######
    if do_true_cfg:
        negative_prompt_embeds, negative_prompt_embeds_mask = model.encode_prompt(
            prompt=negative_prompt,
            prompt_embeds=negative_prompt_embeds,
            prompt_embeds_mask=negative_prompt_embeds_mask,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
        )

    # 4. Prepare latent variables
    num_channels_latents = model.transformer.config.in_channels // 4
    latents = model.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    ######
    # copy the latents for each object
    # (all objects start from the same initial noise)
    for bb in bounding_boxes:
        bb['latents'] = latents.clone()
    ######
    img_shapes = [[(1, height // model.vae_scale_factor // 2, width // model.vae_scale_factor // 2)]] * batch_size

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        model.scheduler.config.get("base_image_seq_len", 256),
        model.scheduler.config.get("max_image_seq_len", 4096),
        model.scheduler.config.get("base_shift", 0.5),
        model.scheduler.config.get("max_shift", 1.15),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        model.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * model.scheduler.order, 0)
    model._num_timesteps = len(timesteps)

    # handle guidance
    if model.transformer.config.guidance_embeds and guidance_scale is None:
        raise ValueError("guidance_scale is required for guidance-distilled model.")
    elif model.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    elif not model.transformer.config.guidance_embeds and guidance_scale is not None:
        logger.warning(
            f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
        )
        guidance = None
    elif not model.transformer.config.guidance_embeds and guidance_scale is None:
        guidance = None

    if model.attention_kwargs is None:
        model._attention_kwargs = {}

    txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
    ######
    # get txt_seq_lens for each sub-prompt
    for bb in bounding_boxes:
        bb['txt_seq_lens'] = bb['prompt_embeds_mask'].sum(dim=1).tolist() if bb['prompt_embeds_mask'] is not None else None
    ######
    negative_txt_seq_lens = (
        negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
    )

    # 6. Denoising loop
    model.scheduler.set_begin_index(0)
    ######
    # copy the scheduler for each object
    # (deep copies are taken after set_begin_index(0), so each copy tracks its own step index)
    for bb in bounding_boxes:
        bb['scheduler'] = copy.deepcopy(model.scheduler)
    ######
    with model.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if model.interrupt:
                continue

            model._current_timestep = t
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            ######
            # NOTE(review): the shared `latents` is not advanced before combine_timestep (only
            # bb['latents'] are). If combine_timestep >= len(timesteps) the merge never happens and the
            # initial noise is decoded — confirm callers always pass combine_timestep < num steps.
            if i == combine_timestep:
                # merge per-object latents (background first) into one tensor
                latents = combine_latents(latents, bounding_boxes, use_segmentation_dict)
                # give up-to-date information to the transformer blocks
                for block in model.transformer.transformer_blocks:
                    block.attn.processor.bb = None
                # continue with the background object's scheduler, which has already taken
                # combine_timestep steps. NOTE: this replacement persists on `model` after the call.
                model.scheduler = bounding_boxes[0]['scheduler']

            if i < combine_timestep:
                for j, bb in enumerate(bounding_boxes):
                    # give object-level information to the transformer blocks
                    for block in model.transformer.transformer_blocks:
                        block.attn.processor.bb = bb
                        # NOTE(review): nothing in this file sets bb['mask']; presumably the attention
                        # processor records it when save_maps is True on the last per-object step — confirm.
                        if i == combine_timestep - 1 and use_segmentation_dict:
                            block.attn.processor.save_maps = True
                    single_obj_latent, single_obj_prompt_embeds = get_noise_pred(
                        model=model, latents=bb['latents'], timestep=timestep, guidance=guidance,
                        prompt_embeds_mask=bb['prompt_embeds_mask'], prompt_embeds=bb['prompt_embeds'], 
                        img_shapes=img_shapes, txt_seq_lens=bb['txt_seq_lens'], negative_prompt_embeds_mask=negative_prompt_embeds_mask,
                        negative_prompt_embeds=negative_prompt_embeds, negative_txt_seq_lens=negative_txt_seq_lens, 
                        true_cfg_scale=true_cfg_scale, t=t, callback_on_step_end=callback_on_step_end, 
                        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, do_true_cfg=do_true_cfg,
                        scheduler=bb['scheduler']
                        )
                    bb['latents'] = single_obj_latent
                    bb['prompt_embeds'] = single_obj_prompt_embeds
            else:
                # unconstrained generation
                latents, prompt_embeds = get_noise_pred(
                    model, latents, timestep, guidance, prompt_embeds_mask, prompt_embeds, img_shapes, txt_seq_lens,
                    negative_prompt_embeds_mask, negative_prompt_embeds, negative_txt_seq_lens, true_cfg_scale, t, 
                    callback_on_step_end, callback_on_step_end_tensor_inputs, do_true_cfg, model.scheduler
                    )
            ######

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % model.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    model._current_timestep = None
    if output_type == "latent":
        image = latents
    else:
        latents = model._unpack_latents(latents, height, width, model.vae_scale_factor)
        latents = latents.to(model.vae.dtype)
        latents_mean = (
            torch.tensor(model.vae.config.latents_mean)
            .view(1, model.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(model.vae.config.latents_std).view(1, model.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        # latents_std holds 1/std, so this de-normalizes as latents * std + mean
        latents = latents / latents_std + latents_mean
        # decode and drop the singleton frame dimension
        image = model.vae.decode(latents, return_dict=False)[0][:, :, 0]
        image = model.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    model.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return QwenImagePipelineOutput(images=image)
