

import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from PIL import Image

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL

from models.camctrl_transformer import CamCtrlFluxTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput


# Record once whether torch_xla can be used; the XLA module is imported only in that case.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm


# Module-level logger (diffusers logging wrapper).
logger = logging.get_logger(__name__)


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Map an image token count to the timestep-shift parameter ``mu``.

    The mapping is the straight line through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. Values outside that interval are
    extrapolated, not clamped.

    Args:
        image_seq_len: number of image tokens.
        base_seq_len, max_seq_len: sequence lengths anchoring the line.
        base_shift, max_shift: shift values at those anchors.

    Returns:
        The interpolated shift ``mu``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return its timestep schedule.

    Exactly one source of the schedule is used:
    - custom ``timesteps``;
    - custom ``sigmas``;
    - otherwise ``num_inference_steps``.

    Extra ``kwargs`` are forwarded to ``scheduler.set_timesteps``.

    Returns:
        ``(scheduler.timesteps, num_inference_steps)``. For a custom schedule the
        step count is the length of the resulting timesteps.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are given, or if the
            scheduler's ``set_timesteps`` does not accept the requested keyword.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(keyword):
        # Introspect lazily so schedulers are only inspected when a custom schedule is requested.
        return keyword in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    is_custom = timesteps is not None or sigmas is not None

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    schedule = scheduler.timesteps
    if is_custom:
        num_inference_steps = len(schedule)
    return schedule, num_inference_steps


class FluxControlSingleScaleRemovalPipeline(
    DiffusionPipeline,
    FluxLoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
):

    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: CamCtrlFluxTransformer2DModel,
    ):
        """Register the pipeline components and derive helper settings.

        When no VAE is registered, the VAE scale factor falls back to 8 and the
        latent channel count to 16. When no CLIP tokenizer is registered, the
        tokenizer length falls back to 77.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
        )

        registered_vae = getattr(self, "vae", None)
        if registered_vae:
            self.vae_scale_factor = 2 ** (len(registered_vae.config.block_out_channels) - 1)
            self.vae_latent_channels = registered_vae.config.latent_channels
        else:
            self.vae_scale_factor = 8
            self.vae_latent_channels = 16

        # RGB images are preprocessed with the default normalisation.
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae_latent_channels
        )
        # Masks are converted to grayscale, kept un-normalised and binarised.
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_resize=True,
            do_convert_grayscale=True,
            do_normalize=False,
            do_binarize=True,
        )

        clip_tokenizer = getattr(self, "tokenizer", None)
        self.tokenizer_max_length = clip_tokenizer.model_max_length if clip_tokenizer is not None else 77
        self.default_sample_size = 128


    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Encode prompt(s) with the T5 text encoder (``text_encoder_2``).

        Args:
            prompt: a single prompt or a list of prompts.
            num_images_per_prompt: number of times each prompt's embedding is repeated.
            max_sequence_length: token length the prompts are padded/truncated to.
            device: target device; defaults to ``self._execution_device``.
            dtype: target dtype; defaults to ``self.text_encoder_2.dtype``.

        Returns:
            Tensor of shape ``(len(prompt) * num_images_per_prompt, seq_len, hidden)``.
            ``seq_len`` is taken from the encoder output.
        """
        device = device or self._execution_device
        if dtype is None:
            # Default to the T5 encoder's own dtype. An explicit ``dtype`` is honoured.
            # Previously it was overwritten after encoding and therefore ignored.
            dtype = self.text_encoder_2.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            # Report the tokens cut off by T5 truncation. The slice must start at the T5
            # limit (max_sequence_length), not at the CLIP tokenizer's limit.
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # Duplicate each prompt's embedding num_images_per_prompt times, contiguously per prompt.
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        """Return the CLIP pooled text embedding for prompt(s).

        Prompts are padded/truncated to ``self.tokenizer_max_length``. A warning
        lists any truncated text. The pooled output is repeated
        ``num_images_per_prompt`` times per prompt.

        Returns:
            Tensor of shape ``(len(prompt) * num_images_per_prompt, hidden)``.
        """
        device = device or self._execution_device

        prompts = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompts)

        if isinstance(self, TextualInversionLoaderMixin):
            prompts = self.maybe_convert_prompt(prompts, self.tokenizer)

        limit = self.tokenizer_max_length
        input_ids = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=limit,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        ).input_ids
        full_ids = self.tokenizer(prompts, padding="longest", return_tensors="pt").input_ids

        was_truncated = full_ids.shape[-1] >= input_ids.shape[-1] and not torch.equal(input_ids, full_ids)
        if was_truncated:
            removed_text = self.tokenizer.batch_decode(full_ids[:, limit - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )

        encoder_output = self.text_encoder(input_ids.to(device), output_hidden_states=False)
        pooled = encoder_output.pooler_output.to(dtype=self.text_encoder.dtype, device=device)

        pooled = pooled.repeat(1, num_images_per_prompt)
        return pooled.view(batch_size * num_images_per_prompt, -1)

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        """Produce the T5 sequence embeddings, CLIP pooled embeddings and text ids.

        If ``prompt_embeds`` is given, it and ``pooled_prompt_embeds`` are
        returned as-is and no encoding happens. Otherwise:
        - CLIP encodes ``prompt``;
        - T5 encodes ``prompt_2``, falling back to ``prompt``.

        When ``lora_scale`` is given and the PEFT backend is active, text-encoder
        LoRA layers are scaled before encoding and unscaled afterwards.

        Returns:
            ``(prompt_embeds, pooled_prompt_embeds, text_ids)``. ``text_ids`` is a
            zero tensor of shape ``(prompt_embeds.shape[1], 3)``.
        """
        device = device or self._execution_device
        has_flux_lora = isinstance(self, FluxLoraLoaderMixin)

        if lora_scale is not None and has_flux_lora:
            self._lora_scale = lora_scale
            if USE_PEFT_BACKEND:
                for encoder in (self.text_encoder, self.text_encoder_2):
                    if encoder is not None:
                        scale_lora_layers(encoder, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            t5_prompt = prompt_2 or prompt
            if isinstance(t5_prompt, str):
                t5_prompt = [t5_prompt]

            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=t5_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        # Restore the original LoRA scale on both text encoders.
        if has_flux_lora and USE_PEFT_BACKEND:
            for encoder in (self.text_encoder, self.text_encoder_2):
                if encoder is not None:
                    unscale_lora_layers(encoder, lora_scale)

        if self.text_encoder is not None:
            ids_dtype = self.text_encoder.dtype
        else:
            ids_dtype = self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=ids_dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        """Validate ``__call__`` arguments.

        Only a warning is logged when ``height``/``width`` are not multiples of
        ``2 * vae_scale_factor``.

        Raises:
            ValueError: on unknown callback tensor names, conflicting or missing
                prompt inputs, wrongly typed prompts, ``prompt_embeds`` without
                ``pooled_prompt_embeds``, or ``max_sequence_length`` above 512.
        """
        required_multiple = self.vae_scale_factor * 2
        if height % required_multiple != 0 or width % required_multiple != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )

        if callback_on_step_end_tensor_inputs is not None:
            unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
            if unknown:
                raise ValueError(
                    f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
                )

        # Each check raises, so they are evaluated in the same order as an elif chain.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        if prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        if prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        if prompt is not None and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        if prompt_2 is not None and not isinstance(prompt_2, (str, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        """Return positional ids of shape ``(height * width, 3)`` in row-major order.

        - Channel 0 is zero.
        - Channel 1 holds the row index.
        - Channel 2 holds the column index.

        ``batch_size`` is accepted but unused.
        """
        zero_plane = torch.zeros(height, width)
        row_plane = torch.arange(height)[:, None].expand(height, width).to(zero_plane.dtype)
        col_plane = torch.arange(width)[None, :].expand(height, width).to(zero_plane.dtype)
        grid = torch.stack((zero_plane, row_plane, col_plane), dim=-1)
        return grid.reshape(height * width, 3).to(device=device, dtype=dtype)

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        """Fold 2x2 spatial patches into tokens.

        Maps ``(B, C, H, W)`` to ``(B, (H//2)*(W//2), C*4)``.
        """
        half_h = height // 2
        half_w = width // 2
        patches = latents.view(batch_size, num_channels_latents, half_h, 2, half_w, 2)
        # Reorder to (batch, patch_row, patch_col, channel, 2, 2) before flattening.
        patches = patches.permute(0, 2, 4, 1, 3, 5)
        return patches.reshape(batch_size, half_h * half_w, num_channels_latents * 4)

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        """Inverse of ``_pack_latents``.

        Maps ``(B, N, C*4)`` tokens back to ``(B, C, h, w)``. ``h`` and ``w`` are
        the pixel size divided by ``vae_scale_factor``, rounded down to an even
        number.
        """
        batch_size, _, packed_channels = latents.shape
        patch_rows = int(height) // (vae_scale_factor * 2)
        patch_cols = int(width) // (vae_scale_factor * 2)

        grid = latents.view(batch_size, patch_rows, patch_cols, packed_channels // 4, 2, 2)
        grid = grid.permute(0, 3, 1, 4, 2, 5)
        return grid.reshape(batch_size, packed_channels // 4, patch_rows * 2, patch_cols * 2)

    def enable_vae_slicing(self):
        """Turn on sliced VAE processing (delegates to ``self.vae.enable_slicing``)."""
        vae = self.vae
        vae.enable_slicing()

    def disable_vae_slicing(self):
        """Turn off sliced VAE processing (delegates to ``self.vae.disable_slicing``)."""
        vae = self.vae
        vae.disable_slicing()

    def enable_vae_tiling(self):
        """Turn on tiled VAE processing (delegates to ``self.vae.enable_tiling``)."""
        vae = self.vae
        vae.enable_tiling()

    def disable_vae_tiling(self):
        """Turn off tiled VAE processing (delegates to ``self.vae.disable_tiling``)."""
        vae = self.vae
        vae.disable_tiling()


    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Create (or pass through) the initial latents and their positional ids.

        The pixel size is converted to a latent size and rounded down to an even
        number.
        - Supplied ``latents`` are only moved to ``device``/``dtype`` (not packed).
        - Otherwise Gaussian noise is drawn and packed with ``_pack_latents``.

        Returns:
            ``(latents, latent_image_ids)``.

        Raises:
            ValueError: if ``generator`` is a list whose length differs from ``batch_size``.
        """
        step = self.vae_scale_factor * 2
        latent_h = 2 * (int(height) // step)
        latent_w = 2 * (int(width) // step)

        image_ids = self._prepare_latent_image_ids(batch_size, latent_h // 2, latent_w // 2, device, dtype)

        if latents is not None:
            return latents.to(device=device, dtype=dtype), image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        noise = randn_tensor(
            (batch_size, num_channels_latents, latent_h, latent_w), generator=generator, device=device, dtype=dtype
        )
        packed = self._pack_latents(noise, batch_size, num_channels_latents, latent_h, latent_w)
        return packed, image_ids
    
    def crop_fov(
        self,
        image,
        ratio,
    ):
        """Centre-crop the first two axes of an array to ``ratio`` of their extent.

        Start indices are rounded up and stop indices are rounded down.
        Trailing axes (e.g. channels) are kept.
        """
        rows, cols = image.shape[:2]

        row_start = int(np.ceil((rows - rows * ratio) / 2.))
        row_stop = int(np.floor((rows + rows * ratio) / 2.))
        col_start = int(np.ceil((cols - cols * ratio) / 2.))
        col_stop = int(np.floor((cols + cols * ratio) / 2.))

        return image[row_start:row_stop, col_start:col_stop, ...]
    
    def simulate_zoom_image(
        self,
        input_image: Image.Image,
        zoom: int,
        output_size: tuple = None,
    ) -> Image.Image:
        """Simulate an optical zoom by centre-cropping and upscaling a PIL image.

        The central ``1 / zoom`` region (via ``crop_fov``) is resized to
        ``output_size`` with Lanczos resampling.

        Args:
            input_image: source image.
            zoom: zoom factor, must be >= 1.
            output_size: ``(width, height)`` of the result. Defaults to
                ``input_image.size``.

        Returns:
            A new image of size ``output_size``. For ``zoom == 1`` this is a copy of
            the input, resized only when ``output_size`` differs from the input size.
        """
        assert zoom >= 1, f"zoom must be >= 1, got {zoom}"

        if output_size is None:
            output_size = input_image.size

        if zoom == 1:
            # No crop needed, but an explicit output_size must still be honoured.
            if tuple(output_size) == input_image.size:
                return input_image.copy()
            return input_image.resize(output_size, Image.LANCZOS)

        # Crop the central region and enlarge it back to the output size.
        cropped_np = self.crop_fov(np.array(input_image), 1.0 / zoom)
        cropped_img = Image.fromarray(cropped_np)
        return cropped_img.resize(output_size, Image.LANCZOS)
    
    
    def check_mask_and_generate_focal_versions(
        self,
        mask_pil: Image.Image,
        input_image_pil: Image.Image,
        tolerance: float = 0.05
    ):
        """Build 1x-4x zoomed copies of a mask/image pair and check mask containment.

        Each zoom level is produced by ``simulate_zoom_image``, which crops the
        central ``1/zoom`` region. The check compares white pixels (> 127 in
        grayscale) in the 4x mask with 16x the count at 1x. These agree
        (approximately, given resampling) when the white region lies inside the
        central quarter crop.

        Args:
            mask_pil: mask image.
            input_image_pil: image to zoom alongside the mask.
            tolerance: maximum relative difference for the check to pass.

        Returns:
            ``(mask_versions, image_versions, is_complete)``. Both dicts map
            ``"1x"``..``"4x"`` to PIL images; ``"1x"`` is the input object itself.

        Raises:
            AssertionError: if either input is not a PIL image.
        """
        assert isinstance(mask_pil, Image.Image), "mask_pil must be a PIL.Image"
        assert isinstance(input_image_pil, Image.Image), "input_image_pil must be a PIL.Image"

        mask_versions = {"1x": mask_pil}
        for zoom in (2, 3, 4):
            mask_versions[f"{zoom}x"] = self.simulate_zoom_image(mask_pil, zoom=zoom)

        def get_white_pixel_count(pil_img):
            return int(np.count_nonzero(np.array(pil_img.convert("L")) > 127))

        base_count = get_white_pixel_count(mask_versions["1x"])
        actual_4x_count = get_white_pixel_count(mask_versions["4x"])
        expected_4x_count = base_count * (4 ** 2)
        if expected_4x_count > 0:
            diff_ratio = abs(actual_4x_count - expected_4x_count) / expected_4x_count
        else:
            # Empty 1x mask: there is no baseline. Treat it as incomplete, as the old
            # numpy 0-division (nan/inf) did, but without the RuntimeWarning.
            diff_ratio = float("inf")

        print(f"[Mask Check] base={base_count}, actual_4x={actual_4x_count}, "
            f"expected_4x={expected_4x_count}, diff={diff_ratio:.2%}")

        is_complete = diff_ratio <= tolerance
        if is_complete:
            print("mask_complete....")

        # The image versions do not depend on is_complete (both former branches were identical).
        image_versions = {"1x": input_image_pil}
        for zoom in (2, 3, 4):
            image_versions[f"{zoom}x"] = self.simulate_zoom_image(input_image_pil, zoom=zoom)

        return mask_versions, image_versions, is_complete



    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        """Preprocess a control image and tile it to the generation batch.

        Tensors are used as-is; other inputs go through ``image_processor``.
        - A single-image batch is repeated ``batch_size`` times.
        - Larger batches are repeated ``num_images_per_prompt`` times.

        With classifier-free guidance (and not ``guess_mode``) the result is
        duplicated along dim 0.
        """
        if not isinstance(image, torch.Tensor):
            image = self.image_processor.preprocess(image, height=height, width=width)

        repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0).to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image, image])

        return image


    def prepare_image_with_mask_camera(
        self,
        image,
        mask,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        camera_model=None,
    ):
        """Build the packed inpainting control tensor for one image/mask pair.

        The VAE encodes three images; only the last two are used:
        - the full image;
        - a "masked" copy (pixels with mask > 0.5 set to -1);
        - a "foreground" copy (pixels with mask < 0.5 set to -1).

        The two used latents are concatenated channel-wise with the inverted,
        latent-resolution mask and packed into 2x2 patch tokens.

        Args:
            image: input accepted by ``image_processor.preprocess``, or an already
                preprocessed tensor (assumed ``(B, 3, H, W)`` — TODO confirm).
            mask: input accepted by ``mask_processor.preprocess``, or an already
                preprocessed tensor (assumed ``(B, 1, H, W)`` with values in [0, 1] — TODO confirm).
            width, height: size used for preprocessing; returned unchanged.
            batch_size: repeat factor when the image batch has a single element.
            num_images_per_prompt: repeat factor when the image batch is larger.
            device, dtype: placement of the tensors and returned latents.
            do_classifier_free_guidance: if True, the packed tensor is duplicated along dim 0.
            camera_model: optional; its ``calibrate`` is run on the first image.
                The result is not used further here.

        Returns:
            ``(packed_control_image, height, width)``.
        """
        image_is_tensor = isinstance(image, torch.Tensor)
        mask_is_tensor = isinstance(mask, torch.Tensor)
        # Raw (non-tensor) inputs are kept for the multi-focal mask check below.
        image_pil = None if image_is_tensor else image
        mask_pil = None if mask_is_tensor else mask

        if not image_is_tensor:
            image = self.image_processor.preprocess(image, height=height, width=width)

        repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)

        if camera_model is not None:
            # Calibrate on the first image before the cast to `dtype`.
            img_for_calib = image[0].clone().detach().to(device=device)
            results = camera_model.calibrate(img_for_calib)
            camera = results["camera"]
            # NOTE(review): `focal` is computed but not consumed in this method.
            focal = camera.f[0, 1].item()

        image = image.to(device=device, dtype=dtype)

        # The focal-version helper asserts PIL inputs, so only run it for PIL images.
        # Previously tensor inputs crashed here with a NameError. Its results are not
        # used here; the call only prints the mask-containment diagnostics.
        if isinstance(image_pil, Image.Image) and isinstance(mask_pil, Image.Image):
            self.check_mask_and_generate_focal_versions(
                mask_pil=mask_pil,
                input_image_pil=image_pil,
                tolerance=0.1
            )

        if mask_is_tensor:
            # Previously `masks` was left unbound for tensor masks.
            masks = mask
        else:
            masks = self.mask_processor.preprocess(mask, height=height, width=width)
        masks = masks.repeat_interleave(repeat_by, dim=0)
        masks = masks.to(device=device, dtype=dtype)

        masked_image = image.clone()
        masked_image[(masks > 0.5).repeat(1, 3, 1, 1)] = -1
        foreground_images = image.clone()
        foreground_images[(masks < 0.5).repeat(1, 3, 1, 1)] = -1

        # NOTE(review): image_latents is unused below. It is kept because
        # latent_dist.sample() may draw from the global RNG, so removing it could change
        # the noise of the following encodes — confirm before deleting.
        image_latents = self.vae.encode(image.to(self.vae.dtype)).latent_dist.sample()
        image_latents = (
            image_latents - self.vae.config.shift_factor
        ) * self.vae.config.scaling_factor
        image_latents = image_latents.to(dtype)

        foreground_image_latents = self.vae.encode(foreground_images.to(self.vae.dtype)).latent_dist.sample()
        foreground_image_latents = (
            foreground_image_latents - self.vae.config.shift_factor
        ) * self.vae.config.scaling_factor
        foreground_image_latents = foreground_image_latents.to(dtype)

        masked_image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
        masked_image_latents = (
            masked_image_latents - self.vae.config.shift_factor
        ) * self.vae.config.scaling_factor
        masked_image_latents = masked_image_latents.to(dtype)

        # Downsample the mask to latent resolution, invert it, and broadcast it over the latent channels.
        masks = torch.nn.functional.interpolate(
            masks, size=(masks.shape[2] // self.vae_scale_factor, masks.shape[3] // self.vae_scale_factor)
        ).to(dtype)
        masks = 1 - masks
        masks = masks.repeat(1, masked_image_latents.shape[1], 1, 1)

        control_image = torch.cat([masked_image_latents, foreground_image_latents, masks], dim=1)

        # Pack with the actual batch size. The previous `batch_size * num_images_per_prompt`
        # only matched when the repeat above produced exactly that many rows.
        packed_control_image = self._pack_latents(
            control_image,
            control_image.shape[0],
            control_image.shape[1],
            control_image.shape[2],
            control_image.shape[3],
        )

        if do_classifier_free_guidance:
            packed_control_image = torch.cat([packed_control_image] * 2)

        return packed_control_image, height, width
    

    
    def prepare_image_with_mask_focal(
        self,
        image,
        mask,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        camera_model=None,
    ):
        """Build packed control tensors for the 1x/2x/3x/4x zoomed versions of an image/mask pair.

        ``check_mask_and_generate_focal_versions`` supplies the zoomed PIL images.
        For each version, the image is preprocessed and VAE-encoded twice:
        - a "masked" copy (mask > 0.5 set to -1);
        - a "foreground" copy (mask < 0.5 set to -1).

        Both latents are concatenated with the inverted latent-resolution mask and
        packed into 2x2 patch tokens.

        Args:
            image: PIL image (tensors are rejected).
            mask: PIL mask image (tensors are rejected).
            width, height: size passed to the preprocessors; returned unchanged.
            batch_size: every preprocessed tensor is repeated this many times along dim 0.
            num_images_per_prompt: accepted for signature compatibility; not used here.
            device, dtype: placement of the tensors and returned latents.
            do_classifier_free_guidance: if True, each packed tensor is duplicated along dim 0.
            camera_model: accepted for signature compatibility; not used here.

        Returns:
            ``(focal_control_images, height, width, image_latents_1x, mask_latent_1x)``:
            - a dict mapping ``"1x"``..``"4x"`` to packed control tensors;
            - the unchanged height and width;
            - the shifted/scaled VAE latents of the 1x image;
            - the 1x (binarised) mask at latent resolution, repeated across the
              latent channels.

        Raises:
            NotImplementedError: if ``image`` or ``mask`` is a ``torch.Tensor``.
        """
        # Error text means: "only PIL input is supported, to make multi-focal processing convenient".
        if isinstance(image, torch.Tensor):
            raise NotImplementedError("只支持 PIL 输入方便做 multi-focal")
        if isinstance(mask, torch.Tensor):
            raise NotImplementedError("只支持 PIL 输入方便做 multi-focal")

        mask_focals, image_focals, _ = self.check_mask_and_generate_focal_versions(
            mask_pil=mask,
            input_image_pil=image,
            tolerance=0.1
        )

        focal_control_images = {}
        image_latents_1x = None
        mask_latent_1x = None

        for focal_key, mask_version_pil in mask_focals.items():
            image_version_pil = image_focals[focal_key]

            image_tensor = self.image_processor.preprocess(image_version_pil, height=height, width=width)
            mask_tensor = self.mask_processor.preprocess(mask_version_pil, height=height, width=width)

            image_tensor = image_tensor.repeat_interleave(batch_size, dim=0).to(device=device, dtype=dtype)
            mask_tensor = mask_tensor.repeat_interleave(batch_size, dim=0).to(device=device, dtype=dtype)

            masked_image = image_tensor.clone()
            masked_image[(mask_tensor > 0.5).repeat(1, 3, 1, 1)] = -1

            foreground_images = image_tensor.clone()
            foreground_images[(mask_tensor < 0.5).repeat(1, 3, 1, 1)] = -1

            # NOTE(review): image_latents is only consumed for "1x". The encode is kept for
            # every key because latent_dist.sample() may draw from the global RNG — confirm
            # before removing.
            image_latents = self.vae.encode(image_tensor.to(self.vae.dtype)).latent_dist.sample()
            image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
            image_latents = image_latents.to(dtype)

            masked_image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
            masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
            masked_image_latents = masked_image_latents.to(dtype)

            foreground_image_latents = self.vae.encode(foreground_images.to(self.vae.dtype)).latent_dist.sample()
            foreground_image_latents = (
                foreground_image_latents - self.vae.config.shift_factor
            ) * self.vae.config.scaling_factor
            foreground_image_latents = foreground_image_latents.to(dtype)

            # Downsample the mask to latent resolution, invert it, and broadcast it over the latent channels.
            masks = torch.nn.functional.interpolate(
                mask_tensor, size=(mask_tensor.shape[2] // self.vae_scale_factor, mask_tensor.shape[3] // self.vae_scale_factor)
            ).to(dtype)
            masks = 1 - masks
            masks = masks.repeat(1, masked_image_latents.shape[1], 1, 1)

            control_image = torch.cat([masked_image_latents, foreground_image_latents, masks], dim=1)

            packed_control_image = self._pack_latents(
                control_image,
                batch_size,
                control_image.shape[1],
                control_image.shape[2],
                control_image.shape[3],
            )

            if do_classifier_free_guidance:
                packed_control_image = torch.cat([packed_control_image] * 2)

            focal_control_images[focal_key] = packed_control_image

            if focal_key == "1x":
                image_latents_1x = image_latents  # [B, C_latent, h, w]
                # Undo the inversion above. The result is 1 inside the binarised input mask
                # and has shape [B, C_latent, h, w] because of the channel repeat.
                mask_latent_1x = 1 - masks

        return focal_control_images, height, width, image_latents_1x, mask_latent_1x

    
    def compute_crop_region(self,height,width,ratio):
        """Return the centred crop box ``(left, top, right, bottom)`` covering ``ratio`` of each side.

        Start coordinates are rounded up and end coordinates are rounded down.
        """
        def _span(extent):
            inner = extent * ratio
            start = int(np.ceil((extent - inner) / 2.))
            stop = int(np.floor((extent + inner) / 2.))
            return start, stop

        left, right = _span(width)
        top, bottom = _span(height)
        return left, top, right, bottom
    
    def fuse_inverse_zoom_latents_with_mask(
        self,
        latent_1x: torch.Tensor,
        latent_4x: torch.Tensor,
        mask: torch.Tensor,
        zoom: int = 4,
    ):
        """Paste a zoomed latent back into the centre of a 1x latent and blend it under a mask.

        Steps:
        1. Bilinearly resize ``latent_4x`` and ``mask`` to the centre crop that a
           ``1/zoom`` zoom covers.
        2. Place them into that region.
        3. Blend: ``fused = latent_1x * (1 - m) + pad_latent_4x * m``, where ``m``
           is zero outside the crop.

        Args:
            latent_1x: base latent ``(B, C, H, W)``.
            latent_4x: latent of the zoomed view (any spatial size; it is resized).
            mask: blend weights; presumably ``(B, 1 or C, H, W)`` at ``latent_1x``'s
                spatial size — TODO confirm against callers.
            zoom: zoom factor that defines the crop.

        Returns:
            ``(fused_latent, pad_latent_4x)``. ``pad_latent_4x`` is ``latent_1x``
            with the resized zoomed latent pasted in.
        """
        _, channels, lat_h, lat_w = latent_1x.shape

        left, top, right, bottom = self.compute_crop_region(lat_h, lat_w, 1.0 / zoom)
        crop_size = (bottom - top, right - left)

        resized_4x = F.interpolate(latent_4x, size=crop_size, mode="bilinear", align_corners=False)
        resized_mask = F.interpolate(mask, size=crop_size, mode="bilinear", align_corners=False)

        pad_latent_4x = latent_1x.clone()
        pad_latent_4x[:, :, top:bottom, left:right] = resized_4x

        pad_mask = torch.zeros_like(mask)
        pad_mask[:, :, top:bottom, left:right] = resized_mask

        weight = pad_mask.repeat(1, channels, 1, 1) if mask.shape[1] == 1 else pad_mask

        fused_latent = latent_1x * (1 - weight) + pad_latent_4x * weight
        return fused_latent, pad_latent_4x
    
    
    def fuse_inverse_zoom_images_with_mask_PIL(
        self,
        image_1x: Image.Image,
        image_4x: Image.Image,
        mask_1x: Image.Image,
        zoom: int = 4,
        debug_save_path: Optional[str] = "fuse_4x_with_1x.png",
    ):
        """Paste a zoomed image back into the 1x image, restricted to a mask.

        Steps:
        1. Shrink ``image_4x`` (Lanczos) to the ``1/zoom`` centre crop of
           ``image_1x`` and paste it there.
        2. Where ``mask_1x`` is > 127, take pixels from that pasted layer.
        3. Everywhere else, keep ``image_1x``.

        Args:
            image_1x: base image.
            image_4x: zoomed image to paste back.
            mask_1x: mask at ``image_1x``'s size. If it is multi-channel, each
                channel is thresholded independently.
            zoom: zoom factor that defines the crop.
            debug_save_path: where to write a 960x540 preview of the result.
                Defaults to the historical ``"fuse_4x_with_1x.png"`` in the current
                working directory; pass ``None`` to skip writing.

        Returns:
            The fused PIL image at ``image_1x``'s size.
        """
        W, H = image_1x.size

        left, top, right, bottom = self.compute_crop_region(H, W, 1.0 / zoom)
        image_4x_crop = image_4x.resize((right - left, bottom - top), Image.LANCZOS)

        paste_layer = image_1x.copy()
        paste_layer.paste(image_4x_crop, (left, top))

        inpainting_np = np.array(paste_layer)
        mask_binary = np.array(mask_1x) > 127

        fused_np = np.array(image_1x)  # np.array copies, so image_1x is not modified
        fused_np[mask_binary] = inpainting_np[mask_binary]
        fused_img = Image.fromarray(fused_np)

        if debug_save_path is not None:
            fused_img.resize((960, 540), Image.LANCZOS).save(debug_save_path)

        return fused_img

    
    @property
    def guidance_scale(self):
        """Guidance scale stored by ``__call__`` for the current run."""
        scale = self._guidance_scale
        return scale

    @property
    def joint_attention_kwargs(self):
        """Attention kwargs passed to ``__call__`` (may be ``None``)."""
        kwargs = self._joint_attention_kwargs
        return kwargs

    @property
    def num_timesteps(self):
        """Number of denoising timesteps; presumably set later in ``__call__`` (not visible here)."""
        count = self._num_timesteps
        return count

    @property
    def interrupt(self):
        """Interrupt flag; ``__call__`` resets it to ``False`` at the start of a run."""
        flag = self._interrupt
        return flag

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        control_image: PipelineImageInput = None,
        control_mask: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        camera_model=None,
        test_zomm_ratio=None,
    ):
        r"""
        Run camera-controlled Flux generation conditioned on a control image and mask.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                Prompt(s) to encode. When omitted, `prompt_embeds` must be provided and its
                first dimension defines the batch size.
            prompt_2 (`str` or `List[str]`, *optional*):
                Secondary prompt(s) forwarded to `encode_prompt`.
            control_image (`PipelineImageInput`):
                Conditioning image forwarded to `prepare_image_with_mask_focal`.
            control_mask (`PipelineImageInput`):
                Conditioning mask forwarded to `prepare_image_with_mask_focal`.
            height (`int`, *optional*):
                Requested output height. It defaults to `default_sample_size * vae_scale_factor`
                and may be replaced by the height returned from `prepare_image_with_mask_focal`.
            width (`int`, *optional*):
                Requested output width. Same defaulting and replacement rules as `height`.
            num_inference_steps (`int`, defaults to 28):
                Number of denoising steps.
            sigmas (`List[float]`, *optional*):
                Custom sigma schedule. Defaults to `np.linspace(1.0, 1 / num_inference_steps,
                num_inference_steps)`.
            guidance_scale (`float`, defaults to 3.5):
                Fed to the transformer as a guidance embedding, only when
                `transformer.config.guidance_embeds` is set.
            num_images_per_prompt (`int`, defaults to 1):
                Number of images generated per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                RNG(s) forwarded to `prepare_latents`.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated latents forwarded to `prepare_latents`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed text embeddings.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed pooled text embeddings.
            output_type (`str`, defaults to `"pil"`):
                `"latent"` returns the packed latents as-is. Any other value decodes them
                with the VAE and post-processes them with `image_processor`.
            return_dict (`bool`, defaults to `True`):
                Return a `FluxPipelineOutput` instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                Forwarded to the transformer. An optional `"scale"` entry is used as the LoRA
                scale when encoding prompts.
            callback_on_step_end (`Callable`, *optional*):
                Called as `callback(self, step_index, timestep, callback_kwargs)` after every
                step. It may return a dict with replacement `"latents"` / `"prompt_embeds"`.
            callback_on_step_end_tensor_inputs (`List[str]`, defaults to `["latents"]`):
                Names of local variables placed into `callback_kwargs`.
            max_sequence_length (`int`, defaults to 512):
                Maximum text sequence length forwarded to `encode_prompt`.
            camera_model (*optional*):
                Forwarded unchanged to `prepare_image_with_mask_focal`.
            test_zomm_ratio (*optional*):
                Key selecting which prepared control image to condition on. Defaults to `"1x"`.

        Examples:

        Returns:
            `FluxPipelineOutput` when `return_dict` is true, otherwise a one-element tuple
            containing the generated image(s) (or latents for `output_type="latent"`).

        Raises:
            ValueError: if the selected zoom-ratio key is not among the prepared control images.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Prepare text embeddings
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare latent variables
        # NOTE(review): in_channels presumably covers the packed noisy latents plus the
        # concatenated conditioning channels, hence `// 16` (expected to yield 16) rather
        # than the usual 2x2-packing `// 4` -- confirm against the transformer config.
        num_channels_latents = self.transformer.config.in_channels // 16

        # The helper may adjust the working resolution, so height/width are taken from it.
        control_images_dict, height, width, _image_latents_1x, _mask = self.prepare_image_with_mask_focal(
            image=control_image,
            mask=control_mask,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=prompt_embeds.dtype,
            camera_model=camera_model,
        )

        # Select which prepared control image conditions the denoiser ("1x" by default).
        zoom_key = "1x" if test_zomm_ratio is None else test_zomm_ratio
        if zoom_key not in control_images_dict:
            raise ValueError(
                f"Unknown zoom ratio {zoom_key!r}; available control images: {list(control_images_dict)}"
            )
        if test_zomm_ratio is not None:
            logger.info("test zoom ratio: %s", test_zomm_ratio)
        control_image = control_images_dict[zoom_key]

        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.16),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # Guidance embedding is only used by guidance-distilled transformers.
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # Condition by concatenating the control latents onto the feature axis (dim=2).
                latent_model_input = torch.cat([latents, control_image], dim=2)

                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # Restore the original dtype when the scheduler step changed it
                        # (only done when MPS is available, mirroring upstream diffusers).
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
