from PIL import Image
import copy
import json
import math
import torch
import yaml, os
import base64
from tqdm import tqdm
from diffusers import DiffusionPipeline
from diffusers.pipelines import FluxPipeline
from typing import List, Union, Optional, Dict, Any, Callable
import torchvision.transforms as T
from .transformer import tranformer_forward
from .condition import Condition
from .pipeline_tools import process_entity_masks, prepare_text_input_eligen
from .pipeline_tools import encode_images, decode_images, prepare_text_input, prepare_text_input_eligen, encode_poses
from .pipeline_tools import visualize_masks

from diffusers.pipelines.flux.pipeline_flux import (
    FluxPipelineOutput,
    calculate_shift,
    retrieve_timesteps,
    np,
)

from ..utils.scene import DiffusionScene
from ..utils.prompt import gen_prompt, edit_prompt, identity_prompt, gen_prompt_2d, gen_prompt_new
from ..utils.prompt_optimize import system_prompt, task_prompt, user_prompt
from ..utils.vlm import vlm_request, extract_and_parse_json, extract_and_parse_list



def get_config(config_path: str = None):
    """Load the YAML configuration used by the generation helpers.

    Args:
        config_path: Path to a YAML file. When falsy, the ``XFL_CONFIG``
            environment variable is consulted instead.

    Returns:
        The parsed YAML document, or an empty dict when no path is available
        or the file contains an empty document.
    """
    config_path = config_path or os.environ.get("XFL_CONFIG")
    if not config_path:
        return {}
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    # An empty YAML document parses to None; callers immediately call
    # ``.get(...)`` on the result, so hand back an empty mapping instead.
    return config if config is not None else {}


def prepare_params(
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = 512,
    width: Optional[int] = 512,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
    max_sequence_length: int = 512,
    **kwargs: dict,
):
    """Normalise pipeline call arguments into a fixed-order tuple.

    Every named parameter is returned in declaration order, followed by a
    dict holding any extra keyword arguments. Missing parameters take the
    defaults declared in the signature.

    ``callback_on_step_end_tensor_inputs`` defaults to a fresh ``["latents"]``
    list on every call, so a caller mutating the returned list cannot affect
    later calls.

    Returns:
        A 19-tuple: the 18 named parameters in declaration order, then the
        ``kwargs`` dict of unrecognised keyword arguments.
    """
    if callback_on_step_end_tensor_inputs is None:
        callback_on_step_end_tensor_inputs = ["latents"]
    return (
        prompt,
        prompt_2,
        height,
        width,
        num_inference_steps,
        timesteps,
        guidance_scale,
        num_images_per_prompt,
        generator,
        latents,
        prompt_embeds,
        pooled_prompt_embeds,
        output_type,
        return_dict,
        joint_attention_kwargs,
        callback_on_step_end,
        callback_on_step_end_tensor_inputs,
        max_sequence_length,
        kwargs,
    )


def seed_everything(seed: int = 42):
    """Seed the NumPy and torch global RNGs and turn on deterministic cuDNN.

    Args:
        seed: Value passed to both ``np.random.seed`` and ``torch.manual_seed``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

@torch.no_grad()
def generate(
    pipeline: FluxPipeline,
    conditions: List[Condition] = None,
    config_path: str = None,
    model_config: Optional[Dict[str, Any]] = {},
    condition_scale: float = 1.0,
    default_lora: bool = False,
    cfg_scale=1.0,
    negative_prompt="",
    optimize=False,
    eligen_entity_masks_pil=None,
    layout=None,
    start=0,
    end=5,
    optim_step=1,
    inv_step=0,
    save_tmp_image=None,
    **params: dict,
):
    """Run the Flux denoising loop with optional condition tokens, EliGen
    entity control and VLM-driven layout refinement.

    Args:
        pipeline: Flux pipeline whose text encoders, transformer, scheduler
            and VAE are used. It is aliased to ``self`` inside the function.
        conditions: Optional list holding at most one ``Condition``. ``None``
            or an empty list disables condition tokens.
        config_path: YAML config path, consulted only when ``model_config``
            is empty.
        model_config: Forwarded to ``tranformer_forward``. When empty, the
            ``"model"`` section of the config file is used.
        condition_scale: When not 1, stored as ``c_factor`` on every
            transformer sub-module whose name ends with ``".attn"`` and
            removed again before returning.
        default_lora: When False, ``pipeline.set_adapters`` is called with
            the condition type of the first condition.
        cfg_scale: When not 1.0, a second forward pass with the negative
            prompt is run and combined as
            ``neg + cfg_scale * (pos - neg)``.
        negative_prompt: Prompt encoded for the negative pass.
        optimize: Enables layout re-optimisation. It is only evaluated when
            ``save_tmp_image`` is set, because the VLM is fed the preview
            image written to disk.
        eligen_entity_masks_pil: Masks handed to ``visualize_masks`` for the
            per-step preview overlay.
        layout: Layout dict with an ``"entity_layout"`` entry and optional
            ``"total_move"`` / ``"total_move_y"``. Its ``"entity_layout"``
            is overwritten in place by each successful optimisation.
        start, end, optim_step: Optimisation runs at step ``i`` when
            ``start <= i <= end`` and ``i % optim_step == 0``.
        inv_step: Currently unused.
        save_tmp_image: Directory; when set, a decoded x0 preview is written
            to ``<dir>/tmp`` at every step.
        **params: Standard Flux call arguments parsed by ``prepare_params``.
            Unrecognised keys (e.g. ``eligen_entity_prompts``,
            ``eligen_entity_masks``, ``orient``) are forwarded to
            ``tranformer_forward``.

    Returns:
        ``FluxPipelineOutput`` with the generated images, or a 1-tuple
        ``(image,)`` when ``return_dict`` is False. With
        ``output_type="latent"`` the packed latents are returned instead.
    """
    model_config = model_config or get_config(config_path).get("model", {})
    if condition_scale != 1:
        for name, module in pipeline.transformer.named_modules():
            if not name.endswith(".attn"):
                continue
            module.c_factor = torch.ones(1, 1) * condition_scale

    self = pipeline
    (
        prompt,
        prompt_2,
        height,
        width,
        num_inference_steps,
        timesteps,
        guidance_scale,
        num_images_per_prompt,
        generator,
        latents,
        prompt_embeds,
        pooled_prompt_embeds,
        output_type,
        return_dict,
        joint_attention_kwargs,
        callback_on_step_end,
        callback_on_step_end_tensor_inputs,
        max_sequence_length,
        kwargs,
    ) = prepare_params(**params)

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode the positive (and, for CFG, the negative) prompt
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None)
        if self.joint_attention_kwargs is not None
        else None
    )
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )
    if cfg_scale != 1.0:
        (
            negative_prompt_embeds,
            negative_pooled_prompt_embeds,
            _,
        ) = self.encode_prompt(
            prompt=negative_prompt,
            prompt_2=None,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 4.1. Prepare conditions
    condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
    # None and [] both mean "no condition". The previous expression
    # (`conditions is not None or []`) treated [] as enabled and then failed
    # on `conditions[0]`.
    use_condition = bool(conditions)
    if use_condition:
        assert len(conditions) <= 1, "Only one condition is supported for now."
        if not default_lora:
            pipeline.set_adapters(conditions[0].condition_type)
        for condition in conditions:
            tokens, ids, type_id = condition.encode(self, ids=latent_image_ids)
            condition_latents.append(tokens)  # [batch_size, token_n, token_dim]
            condition_ids.append(ids)  # [token_n, id_dim(3)]
            condition_type_ids.append(type_id)  # [token_n, 1]
        condition_latents = torch.cat(condition_latents, dim=1)
        condition_ids = torch.cat(condition_ids, dim=0) if ids is not None else None
        condition_type_ids = torch.cat(condition_type_ids, dim=0)

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    num_warmup_steps = max(
        len(timesteps) - num_inference_steps * self.scheduler.order, 0
    )
    self._num_timesteps = len(timesteps)

    # 5.1. Prepare EliGen entity inputs. The names are pre-bound so the
    # preview/optimisation code below never hits an unbound local when the
    # caller did not pass entity prompts and masks.
    eligen_entity_prompts = None
    eligen_entity_masks = None
    if kwargs.get("eligen_entity_prompts", None) and kwargs.get("eligen_entity_masks", None):
        eligen_entity_prompts = kwargs["eligen_entity_prompts"]
        eligen_entity_masks = kwargs["eligen_entity_masks"]
        eligen_kwargs = prepare_text_input_eligen(
            self, eligen_entity_prompts, eligen_entity_masks, orients=kwargs.get("orient", None),
        )
        kwargs["eligen_kwargs"] = eligen_kwargs

    # Conversation history shared across successive VLM optimisation turns.
    messages = None
    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            # handle guidance
            if self.transformer.config.guidance_embeds:
                guidance = torch.tensor([guidance_scale], device=device)
                guidance = guidance.expand(latents.shape[0])
            else:
                guidance = None

            noise_pred = tranformer_forward(
                self.transformer,
                model_config=model_config,
                # Inputs of the condition (new feature)
                condition_latents=condition_latents if use_condition else None,
                condition_ids=condition_ids if use_condition else None,
                condition_type_ids=condition_type_ids if use_condition else None,
                # Inputs to the original transformer
                hidden_states=latents,
                # The timestep is divided by 1000 because the transformer
                # scales it by 1000 internally.
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
                **kwargs,
            )[0]

            if cfg_scale != 1.0:
                # Negative side: same latents, negative prompt, no condition tokens
                noise_pred_nega = tranformer_forward(
                    self.transformer,
                    model_config=model_config,
                    # Inputs of the condition (new feature)
                    condition_latents=None,
                    condition_ids=None,
                    condition_type_ids=None,
                    # Inputs to the original transformer
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=negative_pooled_prompt_embeds,
                    encoder_hidden_states=negative_prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    **kwargs,
                )[0]
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred - noise_pred_nega)

            if save_tmp_image is not None:
                tmp_dir = f'{save_tmp_image}/tmp'
                os.makedirs(tmp_dir, exist_ok=True)

                # Decode the current x0 estimate so it can be inspected and
                # shown to the VLM.
                latents_x0 = latents - noise_pred * timestep / 1000
                image_tmp = self._unpack_latents(latents_x0, height, width, self.vae_scale_factor).to(self.dtype)
                image_tmp = (image_tmp / self.vae.config.scaling_factor) + self.vae.config.shift_factor
                image_tmp = self.vae.decode(image_tmp, return_dict=False)[0]
                image_tmp = self.image_processor.postprocess(image_tmp, output_type='pil')
                preview_path = f'{tmp_dir}/timestep_{i}.png'
                image_tmp[0].save(preview_path)
                # The mask overlay needs entity prompts; skip it when none
                # were supplied (this used to raise NameError).
                if eligen_entity_prompts is not None:
                    image_tmp_mask = visualize_masks(image_tmp[0], eligen_entity_masks_pil, eligen_entity_prompts)
                    image_tmp_mask.save(f'{tmp_dir}/timestep_visual_{i}.png')

                if optimize and start <= i <= end and i % optim_step == 0:
                    cache_path = f'{tmp_dir}/data_{i}.json'
                    if os.path.exists(cache_path):
                        # Reuse the VLM answer from a previous run
                        with open(cache_path, 'r') as f:
                            data = json.load(f)
                    else:
                        data = json_generation(prompt, [preview_path], eligen_entity_prompts, layout['entity_layout'], messages=messages)
                        # Only cache usable answers so a transient VLM
                        # failure is retried on the next run.
                        if data['ans_json'] is not None:
                            with open(cache_path, 'w') as f:
                                json.dump(data, f, indent=4)
                    messages = data['messages']

                    if data['ans_json'] is None:
                        # No parsable layout came back: keep the current
                        # layout and the prediction already computed.
                        print(f"Layout optimization skipped at step {i}: no valid layout returned by the VLM.")
                    else:
                        layout['entity_layout'] = data['ans_json']
                        total_move = layout.get('total_move', None)
                        total_move_y = layout.get('total_move_y', None)
                        depth_all, total_move = generate_scene(layout, total_move=total_move, total_move_y=total_move_y)
                        Image.fromarray(depth_all[-1]).save(f'{tmp_dir}/render_{i}.png')
                        condition_, eligen_entity_prompts, eligen_entity_masks, eligen_entity_masks_pil = prepare_data(self, layout, depth_all, height)

                        # NOTE(review): `condition_` built from the new layout
                        # is never used; the loop below re-encodes the
                        # original `conditions`. Confirm whether
                        # `conditions = [condition_]` was intended.
                        if use_condition:
                            condition_latents = []
                            for condition in conditions:
                                tokens, ids, type_id = condition.encode(self, ids=latent_image_ids)
                                condition_latents.append(tokens)  # [batch_size, token_n, token_dim]
                            condition_latents = torch.cat(condition_latents, dim=1)

                        eligen_kwargs = prepare_text_input_eligen(
                            self, eligen_entity_prompts, eligen_entity_masks, orients=kwargs.get("orient", None),
                        )
                        kwargs["eligen_kwargs"] = eligen_kwargs

                        # Re-predict the noise with the refined entity layout.
                        # NOTE(review): this pass replaces the CFG-combined
                        # prediction without a negative pass; confirm that
                        # is intended when cfg_scale != 1.
                        noise_pred = tranformer_forward(
                            self.transformer,
                            model_config=model_config,
                            # Inputs of the condition (new feature)
                            condition_latents=condition_latents if use_condition else None,
                            condition_ids=condition_ids if use_condition else None,
                            condition_type_ids=condition_type_ids if use_condition else None,
                            # Inputs to the original transformer
                            hidden_states=latents,
                            timestep=timestep / 1000,
                            guidance=guidance,
                            pooled_projections=pooled_prompt_embeds,
                            encoder_hidden_states=prompt_embeds,
                            txt_ids=text_ids,
                            img_ids=latent_image_ids,
                            joint_attention_kwargs=self.joint_attention_kwargs,
                            return_dict=False,
                            **kwargs,
                        )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # advance the progress bar
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()

    # 7. Decode the final latents unless the caller asked for them raw
    if output_type == "latent":
        image = latents
    else:
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (
            latents / self.vae.config.scaling_factor
        ) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    # Remove the temporary condition scale set at the top of the function
    if condition_scale != 1:
        for name, module in pipeline.transformer.named_modules():
            if not name.endswith(".attn"):
                continue
            del module.c_factor

    if not return_dict:
        return (image,)

    return FluxPipelineOutput(images=image)

# Function to encode the image
def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as a base64 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

# def json_generation(caption, image_path=None, entities=None, messages=None):
#     base64_image = encode_image(image_path)
#     messages=[
#         {
#             "role": "user",
#             "content": [
#                 {"type": "text", "text": f"{edit_prompt.replace('<caption>', caption).replace('<entities>', json.dumps(entities)).replace('<layout>', json.dumps(layout))}"},
#                 {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
#             ]
#         }
#     ]
#     tries = 0
#     ans_json = None
#     while ans_json is None and tries < 3:
#         content = vlm_request(messages, model='chatgpt-4o-latest')
#         answer = content.split('</think>')[-1]
#         ans_json = extract_and_parse_json(answer)
#         if ans_json is None:
#             ans_json = extract_json_from_string(answer)
#         tries = tries +1
#     if isinstance(ans_json, dict) and 'optimized_layout' in ans_json.keys(): ans_json = ans_json['optimized_layout']

#     data = {
#         'ans_json': ans_json,
#         'content': content,
#     }

#     return data

def json_generation(caption, current_image_paths=None, entities=None, layout=None, messages=None):
    """Ask the VLM for an optimised layout and return it with the chat history.

    The conversation history stores images as local file paths. A deep copy
    with those paths replaced by base64 data URLs is what is sent to the VLM,
    so the history stays small and JSON-serialisable.

    Args:
        caption (str): Caption substituted for ``<caption>`` in the user prompt.
        current_image_paths (list[str], optional): Image files attached to
            this turn.
        entities (list, optional): Entities substituted for ``<entities>``
            (serialised with ``json.dumps``; ``"[]"`` when None).
        layout (optional): Current layout substituted for ``<layout>``
            (serialised with ``json.dumps``; ``"{}"`` when None).
        messages (list, optional): Previous history. When given, it is
            extended in place; when None a new history starting with the
            system prompt is created.

    Returns:
        dict: ``'ans_json'`` (parsed JSON, unwrapped from ``'optimized_layout'``
        when that key is present, or None), ``'content'`` (raw text of the
        last VLM attempt, or None) and ``'messages'`` (the updated history).
    """
    # First turn: open the conversation with the system prompt.
    if messages is None:
        messages = [{"role": "system", "content": system_prompt}]

    # Fill the prompt template for this turn.
    entities_str = "[]" if entities is None else json.dumps(entities)
    layout_str = "{}" if layout is None else json.dumps(layout)
    user_text = user_prompt.replace('<caption>', caption)
    user_text = user_text.replace('<entities>', entities_str)
    user_text = user_text.replace('<layout>', layout_str)

    user_content_parts = [{"type": "text", "text": task_prompt + user_text}]
    if current_image_paths:
        # History format: keep only the path, never the encoded image.
        user_content_parts.extend(
            {"type": "image_url", "image_url": {"local_path": path}}
            for path in current_image_paths
        )
        print(f"Stored {len(current_image_paths)} image paths in the current user message (history format).")

    messages.append({"role": "user", "content": user_content_parts})

    # Build the request payload on a copy so the history keeps its paths.
    messages_for_vlm = copy.deepcopy(messages)
    for msg in messages_for_vlm:
        if msg.get("role") != "user" or not isinstance(msg.get("content"), list):
            # Only user messages with multi-part content can carry images.
            continue

        request_parts = []
        for part in msg["content"]:
            is_local_image = (
                part.get("type") == "image_url"
                and isinstance(part.get("image_url"), dict)
                and "local_path" in part["image_url"]
            )
            if not is_local_image:
                request_parts.append(part)
                continue

            local_path = part["image_url"]["local_path"]
            print(f"Encoding image '{local_path}' for VLM request...")
            try:
                base64_data = encode_image(local_path)
                if base64_data:
                    request_parts.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_data}"}})
                    print(f"Successfully encoded '{local_path}'.")
                else:
                    # Empty payload: drop the image from this request only.
                    print(f"Warning: Could not encode image '{local_path}'. Skipping this image for the VLM request.")
            except Exception as e:
                # Unreadable image: drop it from this request and carry on.
                print(f"Error during encoding image '{local_path}' for VLM: {e}")
        msg["content"] = request_parts

    # Query the VLM, allowing up to three attempts to obtain parsable JSON.
    ans_json = None
    content = None
    for attempt in range(1, 4):
        print(f"Attempt {attempt} to get valid JSON from VLM...")
        try:
            content = vlm_request(messages_for_vlm, optim=True)
            # Discard any reasoning section preceding the final answer.
            answer_part = content.split('</think>')[-1]
            ans_json = extract_and_parse_json(answer_part)
            if ans_json is None:
                print("Primary JSON extraction failed, trying alternative method...")
                ans_json = extract_json_from_string(answer_part)
        except Exception as e:
            print(f"An error occurred during VLM request or response processing: {e}")
            ans_json = None
            content = None
        if ans_json is not None:
            break

    # Record the assistant reply so the next turn sees the full conversation.
    if content is None:
        print("VLM request failed to return content after multiple retries. No assistant message added to history.")
    else:
        messages.append({"role": "assistant", "content": content})
        print("Appended assistant message to history.")

    # Unwrap the layout when the answer nests it under 'optimized_layout'.
    if ans_json is None:
        print("Failed to extract/parse any valid JSON.")
    elif isinstance(ans_json, dict) and 'optimized_layout' in ans_json:
        print("Found 'optimized_layout' key in the parsed JSON.")
        ans_json = ans_json['optimized_layout']
    else:
        print("Parsed JSON is not a dict with 'optimized_layout' key, using the raw parsed JSON.")

    return {
        'ans_json': ans_json,
        'content': content,
        'messages': messages,
    }

def find_nonzero_bounding_box(vector):
    """Locate the bounding box of the non-zero region of a 2-D NumPy array.

    Args:
        vector: A two-dimensional NumPy array.

    Returns:
        A tuple ``(x_min, y_min, x_max, y_max)`` where x values are column
        indices and y values are row indices, or None when every element
        is zero.

    Raises:
        TypeError: If ``vector`` is not a NumPy array.
        ValueError: If ``vector`` is not two-dimensional.
    """
    if not isinstance(vector, np.ndarray):
        raise TypeError("输入必须是 NumPy 数组")
    if vector.ndim != 2:
        raise ValueError("输入数组必须是二维的")

    # np.nonzero yields one index array per axis: rows (y) first, columns (x) second.
    rows, cols = np.nonzero(vector)
    if rows.size == 0:
        return None

    return (cols.min(), rows.min(), cols.max(), rows.max())

def entity_center(x_min, y_min, x_max, y_max, shape, step=0.1, c_max=0.05, c_min=0.0):
    """Return a signed step telling how a bounding box sits inside an image.

    x coordinates are column indices and are compared against the image
    width; y coordinates are row indices and are compared against the height.
    (Previously the two axes were swapped, which only gave correct answers
    for square images.)

    Args:
        x_min, y_min, x_max, y_max: Bounding box in pixel indices.
        shape: ``(height, width)`` of the image.
        step: Magnitude of the returned step.
        c_max: Fractional margin that the box must clear on every side to
            count as comfortably inside the frame.
        c_min: Fractional margin that the box must not cross on any side.

    Returns:
        ``step`` when the box lies strictly inside the ``c_max`` margin on
        all sides, ``-step * 5`` when it crosses the ``c_min`` margin on any
        side, and ``0`` otherwise.
    """
    h, w = shape

    clears_inner_margin = (
        x_min > w * c_max
        and y_min > h * c_max
        and x_max < w * (1 - c_max)
        and y_max < h * (1 - c_max)
    )
    if clears_inner_margin:
        return step

    crosses_outer_margin = (
        x_min < w * c_min
        or y_min < h * c_min
        or x_max > w * (1 - c_min)
        or y_max > h * (1 - c_min)
    )
    if crosses_outer_margin:
        return -step * 5

    return 0

def layout_normalize(ans_json):
    """Rescale a layout in place so that its scene size becomes 1.

    Every entity's ``position`` and ``size`` is divided by the original
    ``scene_size`` and ``scene_parameters['scene_size']`` is set to 1.

    Args:
        ans_json: Layout dict with ``scene_parameters`` (containing
            ``scene_size`` and ``camera_pitch_angle``) and ``entity_layout``.

    Returns:
        The same ``ans_json`` object, modified in place.
    """
    scene_params = ans_json['scene_parameters']
    scene_size = scene_params['scene_size']
    # NOTE(review): the clamped pitch angle is computed but never written
    # back or used; confirm whether it should update the layout.
    cam_pitch_angle = max(10, scene_params['camera_pitch_angle'])

    scene_params['scene_size'] = 1
    for entity in ans_json['entity_layout']:
        entity['position'] = [coord / scene_size for coord in entity['position']]
        entity['size'] = [extent / scene_size for extent in entity['size']]
    return ans_json

def generate_scene(ans_json_for_scene, total_move=None, total_move_y=None):
    """Build a DiffusionScene from a layout and render its depth maps.

    The layout is deep-copied and normalised first, so the caller's dict is
    not modified. One box is added per entry of ``entity_layout``.

    Args:
        ans_json_for_scene: Layout dict with ``scene_parameters``
            (``scene_size``, ``camera_pitch_angle``) and ``entity_layout``
            (entries with ``position``, ``size`` and ``entity_name``).
        total_move: Camera translation applied along the second translation
            component. When None, the camera is instead stepped iteratively
            (at most 40 times) until ``entity_center`` returns 0.
        total_move_y: Optional camera translation applied along the third
            translation component before any other camera adjustment.

    Returns:
        ``(depth_all, total_move)``: the result of the final ``scene.render``
        call and the camera offset (the given value, or the one accumulated
        by the search).
    """
    import copy
    # Work on a private copy: layout_normalize mutates its argument.
    ans_json = copy.deepcopy(ans_json_for_scene)
    ans_json = layout_normalize(ans_json)

    # layout_normalize sets scene_size to 1, so this is half the unit scene.
    scene_size = ans_json['scene_parameters']['scene_size'] / 2
    cam_pitch_angle = 90 - ans_json['scene_parameters']['camera_pitch_angle']
    # cam_pitch_angle = 90
    floor_scale_x = 1
    floor_scale_y = 1

    # Extent of all entities along position[1], using size[2] as the extent
    # on that axis; the floor is offset by the negated midpoint.
    y_min = 100
    y_max = 0
    for i, entity in enumerate(ans_json['entity_layout']):
        y_min = min(y_min, entity['position'][1] - entity['size'][2]/2)
        y_max = max(y_max, entity['position'][1] + entity['size'][2]/2)
    floor_offset = - (y_max + y_min) / 2

    # x_min = 100
    # x_max = 0
    # for i, entity in enumerate(ans_json['entity_layout']):
    #     x_min = min(x_min, entity['position'][0] - entity['size'][0]/2)
    #     x_max = max(x_max, entity['position'][0] + entity['size'][0]/2)
    # x_mean = (x_max + x_min) / 2
    # for i, entity in enumerate(ans_json['entity_layout']):
    #     entity['position'][0] -= x_mean

    # Build the scene
    scene = DiffusionScene(scene_size=scene_size, fov=(60,60))
    scene.move_camera(rotation_angle=cam_pitch_angle,rotation_axis=[1,0,0], translation=[0,0,0])# rotation_axis(x,z,y), translation(x, z, y)
    # scene.move_camera(rotation_angle=0,rotation_axis=[1,0,0], translation=[0,-2*scene_size,0])# rotation_axis(x,z,y), translation(x, z, y)
    scene.build_floor(scale_x=floor_scale_x, scale_y=floor_scale_y, floor_offset=floor_offset)

    # One box per entity, labelled with the entity name.
    for i, entity in enumerate(ans_json['entity_layout']):
        scene.add_box(id=f"box_{i}", size=entity['size'], origin=entity['position'], prompt=entity['entity_name'])
        # scene.box(f"box_{i}").rotate_left(entity['orient'])
        # mask_b2, latent_mask_b2, p_image_b2 = scene.get_box_masks(box_id="box_2")

    if total_move_y is not None: scene.move_camera(rotation_angle=0,rotation_axis=[1,0,0], translation=[0,0,total_move_y])# rotation_axis(x,z,y), translation(x, z, y)
    if total_move is None:
        # No offset supplied: step the camera until entity_center reports 0
        # for the rendered bounding box, or 40 iterations have been made.
        num = 0
        # total_move = 0
        # NOTE(review): the accumulator starts at this offset although the
        # camera itself has not been moved by it here — confirm intent.
        total_move = -0.68 * scene_size - scene_size
        depth_all = scene.render(single=True, floor=False, render_floor=False, depth_max=4*scene_size)
        # NOTE(review): find_nonzero_bounding_box returns None for an
        # all-zero render, which would make this unpacking raise TypeError.
        x_min, y_min, x_max, y_max = find_nonzero_bounding_box(depth_all[-1])
        move = entity_center(x_min, y_min, x_max, y_max, depth_all[-1].shape)
        while move != 0 and num < 40:
            scene.move_camera(rotation_angle=0,rotation_axis=[1,0,0], translation=[0,move,0])# rotation_axis(x,z,y), translation(x, z, y)
            depth_all = scene.render(single=True, floor=False, render_floor=False, depth_max=4*scene_size)
            x_min, y_min, x_max, y_max = find_nonzero_bounding_box(depth_all[-1])
            move = entity_center(x_min, y_min, x_max, y_max, depth_all[-1].shape)
            num += 1
            # NOTE(review): this adds the step computed for the *next*
            # iteration, not the one just applied to the camera — verify.
            total_move += move
    else:
        # Reuse a previously determined camera offset.
        scene.move_camera(rotation_angle=0,rotation_axis=[1,0,0], translation=[0,total_move,0])# rotation_axis(x,z,y), translation(x, z, y)

    # Final render from the settled camera position.
    depth_all = scene.render(single=True, floor=False, depth_max=4*scene_size)
    return depth_all, total_move

def prepare_data(pipe, data, depth_all, condition_size=512):
    """Turn rendered depth maps into an EliGen condition plus entity masks.

    Args:
        pipe: Pipeline object; only its ``device`` and ``dtype`` are read.
        data: Layout dict whose ``entity_layout`` entries carry
            ``entity_name``.
        depth_all: Sequence of depth arrays. Every element except the last
            is converted into a condition image.
        condition_size: Side length of the square condition images. Model
            masks are produced at ``condition_size // 8``.

    Returns:
        ``(condition_, eligen_entity_prompts, eligen_entity_masks,
        eligen_entity_masks_pil)``: the ``Condition`` object, the entity
        names, the downsampled binary mask tensors and the full-resolution
        binary masks as PIL images.
    """
    eligen_entity_prompts = [entity['entity_name'] for entity in data["entity_layout"]]

    full_res = (condition_size, condition_size)
    mask_res = (condition_size//8, condition_size//8)

    # One RGB condition image per depth map; the last element is skipped.
    condition_imgs = [
        Image.fromarray(depth).convert("RGB").resize(full_res)
        for depth in depth_all[:-1]
    ]

    eligen_entity_masks = []
    eligen_entity_masks_pil = []
    for img in condition_imgs:
        # Downsampled binary mask used as model input.
        small = np.array(img.resize(mask_res))
        small_binary = np.where(small > 0, 1, 0).astype(np.uint8)
        small_tensor = torch.from_numpy(small_binary).to(device=pipe.device, dtype=pipe.dtype)
        eligen_entity_masks.append(small_tensor.unsqueeze(0))

        # Full-resolution binary mask (0/255) kept for visualisation.
        full_binary = np.where(np.array(img) > 0, 1, 0).astype(np.uint8)
        eligen_entity_masks_pil.append(Image.fromarray(full_binary*255))

    # Stack the condition images into a single tensor.
    to_tensor = T.ToTensor()
    condition_imgs = torch.stack([to_tensor(img) for img in condition_imgs])

    condition_ = Condition(
        condition_type='eligen_loose',
        condition={
            "condition": condition_imgs,
            "eligen_entity_prompts": eligen_entity_prompts,
            "eligen_entity_masks": eligen_entity_masks,
            'eligen_entity_masks_pil': eligen_entity_masks_pil,
        },
        position_delta=[0, 0],
    )
    return condition_, eligen_entity_prompts, eligen_entity_masks, eligen_entity_masks_pil

def extract_json_from_string(text_content):
    """Extract and parse the first fenced JSON block found in a string.

    The block must be enclosed in a Markdown json code fence. ``//`` line
    comments inside the block are removed before parsing, but ``//`` that
    appears inside a JSON string value (for example in a URL) is left
    untouched.

    Args:
        text_content (str): Text containing the fenced JSON block.

    Returns:
        The parsed JSON value, or None when no fenced block is found or its
        content is not valid JSON.
    """
    import re

    # Non-greedy match of everything between the opening json fence and the
    # next closing fence; DOTALL lets '.' span newlines.
    match = re.search(r"```json\s*(.*?)\s*```", text_content, re.DOTALL)
    if not match:
        return None
    json_string = match.group(1)

    # Match either a complete JSON string literal or a // comment running to
    # the end of the line. String literals are put back unchanged, so only
    # real comments are dropped.
    string_or_comment = re.compile(r'"(?:\\.|[^"\\])*"|//[^\n]*')

    def _drop_comment(token):
        text = token.group(0)
        return text if text.startswith('"') else ""

    json_string_no_comments = string_or_comment.sub(_drop_comment, json_string)
    try:
        return json.loads(json_string_no_comments)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        return None