# This code is taken from diffusers. All our modifications are noted within ######

import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import copy
import os

import numpy as np
import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
    T5TokenizerFast,
)

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, FluxTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import FluxPipeline

        >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")
        >>> prompt = "A cat holding a sign that says hello world"
        >>> # Depending on the variant being used, the pipeline call will slightly vary.
        >>> # Refer to the pipeline documentation for more details.
        >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
        >>> image.save("flux.png")
        ```
""" 

######
def combine_with_segmentation_maps(bounding_boxes, latents_reshaped, batch_size):
    """Paste each object's latents into the canvas wherever its mask is set.

    Args:
        bounding_boxes: Iterable of dicts. Each must provide ``'latents'``
            (packed latents of shape ``[batch, seq_len, channels]``) and
            ``'mask'`` (one entry per token; used as the condition of
            ``torch.where`` -- presumably a boolean tensor, confirm with the
            code that produces it).
        latents_reshaped: Canvas of shape ``[batch, channels, grid_h, grid_w]``.
        batch_size: Leading dimension used when reshaping the object latents.

    Returns:
        A tensor with the canvas shape in which masked positions come from the
        object latents (later objects override earlier ones). When
        ``bounding_boxes`` is empty the canvas is returned unchanged.
    """
    # Take the spatial layout from the canvas instead of hard-coding 64x32x32,
    # so any resolution whose canvas was reshaped consistently is supported.
    channels, grid_h, grid_w = latents_reshaped.shape[1:]
    for bb in bounding_boxes:
        # [batch, seq_len, channels] -> [batch, channels, grid_h, grid_w]
        bb_latents_reshaped = bb['latents'].permute(0, 2, 1)
        bb_latents_reshaped = bb_latents_reshaped.reshape(batch_size, channels, grid_h, grid_w)
        # The mask is broadcast over the batch and channel dimensions.
        mask_reshaped = bb['mask'].reshape(1, 1, grid_h, grid_w)
        latents_reshaped = torch.where(mask_reshaped, bb_latents_reshaped, latents_reshaped)
    return latents_reshaped


def combine_with_bounding_boxes(bounding_boxes, latents_reshaped, batch_size):
    """Copy each object's latents into the canvas inside its bounding box.

    Args:
        bounding_boxes: Iterable of dicts. Each must provide ``'latents'``
            (packed latents of shape ``[batch, seq_len, channels]``) and the
            inclusive box limits ``'x_min'``, ``'x_max'``, ``'y_min'``,
            ``'y_max'``, expressed in cells of the canvas grid.
        latents_reshaped: Canvas of shape ``[batch, channels, grid_h, grid_w]``.
            It is written to in place.
        batch_size: Leading dimension used when reshaping the object latents.

    Returns:
        The same ``latents_reshaped`` tensor, after the boxes were pasted
        (later boxes override earlier ones where they overlap).
    """
    # Take the spatial layout from the canvas instead of hard-coding 64x32x32,
    # so any resolution whose canvas was reshaped consistently is supported.
    channels, grid_h, grid_w = latents_reshaped.shape[1:]
    for bb in bounding_boxes:
        # [batch, seq_len, channels] -> [batch, channels, grid_h, grid_w]
        bb_latents_reshaped = bb['latents'].permute(0, 2, 1)
        bb_latents_reshaped = bb_latents_reshaped.reshape(batch_size, channels, grid_h, grid_w)
        # The stored limits are inclusive; slices need an exclusive upper end.
        rows = slice(bb['y_min'], bb['y_max'] + 1)
        cols = slice(bb['x_min'], bb['x_max'] + 1)
        latents_reshaped[:, :, rows, cols] = bb_latents_reshaped[:, :, rows, cols]
    return latents_reshaped


def combine_latents(latents, bounding_boxes, use_segmentation_dict, grid_height=None, grid_width=None):
    """Merge the per-object latents into a single packed latent tensor.

    The packed latents are viewed as a 2D token grid, the background entry
    (``bounding_boxes[0]``) is pasted through its bounding box, and then the
    remaining objects are pasted either through their masks or through their
    bounding boxes.

    Args:
        latents: Packed latents of shape ``[batch, seq_len, channels]``. The
            reshaped canvas may share storage with this tensor, so it can be
            modified in place by the bounding-box paste.
        bounding_boxes: Non-empty list of per-object dicts; the background is
            assumed to be the first entry.
        use_segmentation_dict: If true, objects after the background are
            combined with ``combine_with_segmentation_maps``; otherwise with
            ``combine_with_bounding_boxes``.
        grid_height: Optional number of token rows. Inferred when omitted.
        grid_width: Optional number of token columns. Inferred when omitted.

    Returns:
        The combined latents with shape ``[batch, seq_len, channels]``.

    Raises:
        ValueError: If the token grid cannot be inferred (``seq_len`` is not a
            perfect square and no grid size was given) or does not match
            ``seq_len``.
    """
    batch_size, seq_len, channels = latents.shape

    # Resolve the token grid; without hints a square grid is assumed.
    if grid_height is None and grid_width is None:
        grid_height = grid_width = round(seq_len ** 0.5)
    elif grid_height is None:
        grid_height = seq_len // grid_width
    elif grid_width is None:
        grid_width = seq_len // grid_height
    if grid_height * grid_width != seq_len:
        raise ValueError(
            f"Cannot arrange {seq_len} latent tokens on a {grid_height}x{grid_width} grid; "
            "pass matching `grid_height` and `grid_width`."
        )

    # [batch, seq_len, channels] -> [batch, channels, grid_height, grid_width]
    latents_reshaped = latents.permute(0, 2, 1)
    latents_reshaped = latents_reshaped.reshape(batch_size, channels, grid_height, grid_width)

    # assume background first in bounding_boxes
    latents_reshaped = combine_with_bounding_boxes([bounding_boxes[0]], latents_reshaped, batch_size)

    if use_segmentation_dict:
        latents_reshaped = combine_with_segmentation_maps(bounding_boxes[1:], latents_reshaped, batch_size)
    else:
        latents_reshaped = combine_with_bounding_boxes(bounding_boxes[1:], latents_reshaped, batch_size)

    # [batch, channels, grid_height, grid_width] -> [batch, seq_len, channels]
    latents_reshaped = latents_reshaped.reshape(batch_size, channels, seq_len)
    return latents_reshaped.permute(0, 2, 1)

# This function is adapted from the per-step body of the denoising loop in the original diffusers Flux pipeline.
def get_noise_pred(
    model, latents, timestep, guidance, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids,
    do_true_cfg, negative_image_embeds, negative_pooled_prompt_embeds, negative_text_ids, true_cfg_scale,
    t, callback_on_step_end, callback_on_step_end_tensor_inputs, scheduler,
    negative_prompt_embeds=None, step_index=None,
):
    """Run one denoising step: predict the noise and advance `latents` with `scheduler`.

    Args:
        model: Pipeline object providing `transformer`, `joint_attention_kwargs` and
            `_joint_attention_kwargs`.
        latents: Latents passed to the transformer as `hidden_states`.
        timestep: Batched timestep tensor; it is divided by 1000 before being passed to the transformer.
        guidance: Guidance tensor forwarded to the transformer (may be `None`).
        pooled_prompt_embeds: Pooled prompt embeddings for the positive pass.
        prompt_embeds: Prompt embeddings for the positive pass.
        text_ids: Text ids for the positive pass.
        latent_image_ids: Image ids forwarded to the transformer.
        do_true_cfg: Whether to run a second (negative) pass and mix both predictions.
        negative_image_embeds: IP-Adapter embeddings used during the negative pass, or `None`.
        negative_pooled_prompt_embeds: Pooled prompt embeddings for the negative pass.
        negative_text_ids: Text ids for the negative pass.
        true_cfg_scale: Weight of the positive/negative difference in the mixed prediction.
        t: Scheduler timestep passed to `scheduler.step` and to the callback.
        callback_on_step_end: Optional callable invoked as
            `callback_on_step_end(model, step_index, t, callback_kwargs)`.
        callback_on_step_end_tensor_inputs: Names of local variables of this function that are handed to the
            callback through `callback_kwargs`.
        scheduler: Scheduler whose `step` method advances the latents.
        negative_prompt_embeds: Prompt embeddings for the negative pass. Required when `do_true_cfg` is true.
        step_index: Index of the current denoising step, forwarded to the callback.

    Returns:
        A tuple `(latents, prompt_embeds)`: the latents after the scheduler step and the prompt embeddings, both
        possibly replaced by the callback.

    Raises:
        ValueError: If `do_true_cfg` is true but `negative_prompt_embeds` is `None`.
    """
    if do_true_cfg and negative_prompt_embeds is None:
        raise ValueError("`negative_prompt_embeds` must be provided when `do_true_cfg` is True.")

    noise_pred = model.transformer(
        hidden_states=latents,
        timestep=timestep / 1000,
        guidance=guidance,
        pooled_projections=pooled_prompt_embeds,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=latent_image_ids,
        joint_attention_kwargs=model.joint_attention_kwargs,
        return_dict=False,
    )[0]

    if do_true_cfg:
        positive_image_embeds = None
        if negative_image_embeds is not None:
            # Remember the positive IP-Adapter embeddings so they can be restored after the negative pass.
            positive_image_embeds = model._joint_attention_kwargs.get("ip_adapter_image_embeds")
            model._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
        neg_noise_pred = model.transformer(
            hidden_states=latents,
            timestep=timestep / 1000,
            guidance=guidance,
            pooled_projections=negative_pooled_prompt_embeds,
            encoder_hidden_states=negative_prompt_embeds,
            txt_ids=negative_text_ids,
            img_ids=latent_image_ids,
            joint_attention_kwargs=model.joint_attention_kwargs,
            return_dict=False,
        )[0]
        if positive_image_embeds is not None:
            # This function can run several times per timestep (once per object); without restoring, every
            # call after the first would use the negative embeddings for its positive pass.
            model._joint_attention_kwargs["ip_adapter_image_embeds"] = positive_image_embeds
        noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

    # compute the previous noisy sample x_t -> x_t-1
    latents_dtype = latents.dtype
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    if latents.dtype != latents_dtype:
        if torch.backends.mps.is_available():
            # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
            latents = latents.to(latents_dtype)

    if callback_on_step_end is not None:
        callback_kwargs = {}
        # Keep this an explicit loop: `locals()` must be evaluated in this function's scope.
        for k in callback_on_step_end_tensor_inputs:
            callback_kwargs[k] = locals()[k]
        callback_outputs = callback_on_step_end(model, step_index, t, callback_kwargs)

        latents = callback_outputs.pop("latents", latents)
        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

    return latents, prompt_embeds
######

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def modified_flux(
    model,
    combine_timestep: int,
    bounding_boxes: List[Dict[str, Any]],
    use_segmentation_dict: bool,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt: Union[str, List[str]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    true_cfg_scale: float = 1.0,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_ip_adapter_image: Optional[PipelineImageInput] = None,
    negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
):
    r"""
    Run Flux generation with a constrained per-object phase followed by regular denoising.

    For every step index `i < combine_timestep`, each entry of `bounding_boxes` is denoised separately with its
    own prompt, latents and scheduler copy while the attention processors of the transformer are pointed at
    that entry. At `i == combine_timestep` the per-object latents are merged with `combine_latents`, the
    constraint is removed from the attention processors, and the remaining steps denoise the merged latents
    with the global `prompt`.

    Args:
        model:
            The Flux pipeline instance to run. Its `scheduler` is replaced by the scheduler copy of the first
            bounding box once the latents are combined.
        combine_timestep (`int`):
            Index of the denoising step at which the per-object latents are combined. If it is never reached,
            the latents are never combined.
        bounding_boxes (`List[dict]`):
            One dict per object, background first. Each dict must contain `'prompt'`; the combination step reads
            `'x_min'`, `'x_max'`, `'y_min'`, `'y_max'` and, when `use_segmentation_dict` is set, `'mask'` for
            every entry after the first (presumably written by the attention processors once `save_maps` is
            enabled -- confirm against the processor implementation). This function stores `'prompt_embeds'`,
            `'pooled_prompt_embeds'`, `'text_ids'`, `'latents'`, `'latent_image_ids'` and `'scheduler'` in each
            dict.
        use_segmentation_dict (`bool`):
            Whether objects are merged through their masks instead of their bounding boxes.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
            will be used instead. The same `prompt_2` is also passed when encoding every bounding-box prompt.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
            not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
        true_cfg_scale (`float`, *optional*, defaults to 1.0):
            When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
        height (`int`, *optional*, defaults to `model.default_sample_size * model.vae_scale_factor`):
            The height in pixels of the generated image. The resulting latent token grid must be compatible with
            `combine_latents`.
        width (`int`, *optional*, defaults to `model.default_sample_size * model.vae_scale_factor`):
            The width in pixels of the generated image. The resulting latent token grid must be compatible with
            `combine_latents`.
        num_inference_steps (`int`, *optional*, defaults to 28):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 3.5):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
            of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
            the text `prompt`, usually at the expense of lower image quality.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        negative_ip_adapter_image:
            (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
            provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
            packed latents are returned without decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `model.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during the inference (once per object
            during the constrained phase). The function is called with the following arguments:
            `callback_on_step_end(model: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
            `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
        is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
        images.
    """

    height = height or model.default_sample_size * model.vae_scale_factor
    width = width or model.default_sample_size * model.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    model.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    model._guidance_scale = guidance_scale
    model._joint_attention_kwargs = joint_attention_kwargs
    model._current_timestep = None
    model._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = model._execution_device

    lora_scale = (
        model.joint_attention_kwargs.get("scale", None) if model.joint_attention_kwargs is not None else None
    )
    has_neg_prompt = negative_prompt is not None or (
        negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
    )
    do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = model.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )
    ######
    # get a prompt embedding for every object
    # NOTE(review): the global `prompt_2` is reused for every object -- confirm this is intended.
    for bb in bounding_boxes:
        (
            bb_prompt_embeds,
            bb_pooled_prompt_embeds,
            bb_text_ids,
        ) = model.encode_prompt(
            prompt=bb['prompt'],
            prompt_2=prompt_2,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        bb['prompt_embeds'] = bb_prompt_embeds
        bb['pooled_prompt_embeds'] = bb_pooled_prompt_embeds
        bb['text_ids'] = bb_text_ids
    ######
    if do_true_cfg:
        (
            negative_prompt_embeds,
            negative_pooled_prompt_embeds,
            negative_text_ids,
        ) = model.encode_prompt(
            prompt=negative_prompt,
            prompt_2=negative_prompt_2,
            prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

    # 4. Prepare latent variables
    num_channels_latents = model.transformer.config.in_channels // 4
    latents, latent_image_ids = model.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    ######
    # copy the latents for each object, so every object starts from the same noise
    for bb in bounding_boxes:
        bb['latents'] = latents.clone()
        bb['latent_image_ids'] = latent_image_ids.clone()
    ######

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        model.scheduler.config.get("base_image_seq_len", 256),
        model.scheduler.config.get("max_image_seq_len", 4096),
        model.scheduler.config.get("base_shift", 0.5),
        model.scheduler.config.get("max_shift", 1.15),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        model.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * model.scheduler.order, 0)
    model._num_timesteps = len(timesteps)

    # handle guidance
    if model.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None

    # If only one side of the IP-Adapter inputs is given, use an all-zero image for the other side.
    if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
        negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
    ):
        negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        negative_ip_adapter_image = [negative_ip_adapter_image] * model.transformer.encoder_hid_proj.num_ip_adapters

    elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
        negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
    ):
        ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        ip_adapter_image = [ip_adapter_image] * model.transformer.encoder_hid_proj.num_ip_adapters

    if model.joint_attention_kwargs is None:
        model._joint_attention_kwargs = {}

    image_embeds = None
    negative_image_embeds = None
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = model.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
        )
    if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
        negative_image_embeds = model.prepare_ip_adapter_image_embeds(
            negative_ip_adapter_image,
            negative_ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
        )

    ######
    # `negative_text_ids` only exists when true CFG is enabled, so pass `None` otherwise
    negative_text_ids_to_pass = negative_text_ids if do_true_cfg else None
    ######

    # 6. Denoising loop
    model.scheduler.set_begin_index(0)
    ######
    # make copies of the schedulers for each object, so each object tracks its own step state
    for bb in bounding_boxes:
        bb['scheduler'] = copy.deepcopy(model.scheduler)
    ######
    with model.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if model.interrupt:
                continue

            model._current_timestep = t
            if image_embeds is not None:
                model._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            ######
            if i == combine_timestep:
                latents = combine_latents(latents, bounding_boxes, use_segmentation_dict)
                # update the information within the transformers, to disable constraints
                for block in model.transformer.transformer_blocks:
                    block.attn.processor.bb = None
                for block in model.transformer.single_transformer_blocks:
                    block.attn.processor.bb = None
                # The pipeline scheduler was not stepped during the constrained phase; continue with the
                # scheduler copy of the first bounding box, which was stepped once per constrained step.
                model.scheduler = bounding_boxes[0]['scheduler']

            if i < combine_timestep:
                for bb in bounding_boxes:
                    # update the information within the transformers, to enable correct constraints for this object
                    for block in model.transformer.transformer_blocks:
                        block.attn.processor.bb = bb
                        if i == combine_timestep - 1 and use_segmentation_dict:
                            block.attn.processor.save_maps = True
                    for block in model.transformer.single_transformer_blocks:
                        block.attn.processor.bb = bb

                    single_obj_latent, single_obj_prompt_embeds = get_noise_pred(
                        model=model, latents=bb['latents'], timestep=timestep, guidance=guidance,
                        pooled_prompt_embeds=bb['pooled_prompt_embeds'], prompt_embeds=bb['prompt_embeds'],
                        text_ids=bb['text_ids'], latent_image_ids=bb['latent_image_ids'], do_true_cfg=do_true_cfg,
                        negative_image_embeds=negative_image_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                        negative_text_ids=negative_text_ids_to_pass, true_cfg_scale=true_cfg_scale,
                        t=t, callback_on_step_end=callback_on_step_end,
                        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
                        scheduler=bb['scheduler'],
                        negative_prompt_embeds=negative_prompt_embeds, step_index=i,
                    )
                    bb['latents'] = single_obj_latent
                    bb['prompt_embeds'] = single_obj_prompt_embeds
            else:
                # unconstrained generation
                latents, prompt_embeds = get_noise_pred(
                    model=model, latents=latents, timestep=timestep, guidance=guidance,
                    pooled_prompt_embeds=pooled_prompt_embeds, prompt_embeds=prompt_embeds,
                    text_ids=text_ids, latent_image_ids=latent_image_ids, do_true_cfg=do_true_cfg,
                    negative_image_embeds=negative_image_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                    negative_text_ids=negative_text_ids_to_pass, true_cfg_scale=true_cfg_scale,
                    t=t, callback_on_step_end=callback_on_step_end,
                    callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
                    scheduler=model.scheduler,
                    negative_prompt_embeds=negative_prompt_embeds, step_index=i,
                )
            ######
            # advance the progress bar (the step-end callback itself runs inside `get_noise_pred`)
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % model.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    model._current_timestep = None

    if output_type == "latent":
        image = latents
    else:
        latents = model._unpack_latents(latents, height, width, model.vae_scale_factor)
        latents = (latents / model.vae.config.scaling_factor) + model.vae.config.shift_factor
        image = model.vae.decode(latents, return_dict=False)[0]
        image = model.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    model.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return FluxPipelineOutput(images=image)